support deepseekv3; runnable but has a precision problem

This commit is contained in:
Azure 2025-01-31 08:27:24 +00:00
parent de7e892f72
commit 476b1d8dc6
13 changed files with 2178 additions and 24 deletions

View file

@ -641,6 +641,7 @@ class KDeepseekV2Model(BaseInjectedModule):
if inputs_embeds is None:
org_device = input_ids.device
# TODO: move input_ids to embed_tokens's device instead of hard-coding "cpu"
input_ids = input_ids.to("cpu")
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.to(org_device)
@ -737,8 +738,9 @@ class KDeepseekV2Model(BaseInjectedModule):
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
# @@@@@@@ TODO: re-enable the two commented-out lines below; temporarily disabled to fit deepseekv3
# if use_cache:
# next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)