smallthinker right

This commit is contained in:
qiyuxinlin 2025-07-25 12:46:14 +00:00
parent f8719ee7b9
commit 712ad1fa3c
7 changed files with 48 additions and 108 deletions

View file

@ -83,7 +83,7 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
with torch.cuda.stream(current_stream):
residual = torch.zeros_like(hidden_states)
for i, decode_layer in enumerate(self.model.layers):
router_input = hidden_states.clone()
router_input = hidden_states
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
freqs_cis if self.model.rope_layout[i] else None,