fix rope; update moegate

This commit is contained in:
Azure 2025-02-01 18:05:45 +00:00
parent f873558a89
commit f748cd29f0
5 changed files with 54 additions and 21 deletions

View file

@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule):
org_device = input_ids.device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids = input_ids.to("cpu")
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids).to(org_device)
input_ids = input_ids.to(org_device)
if per_layer_prefill_flag: