smallthink run

This commit is contained in:
qiyuxinlin 2025-07-24 15:08:29 +00:00
parent 590fcb41cd
commit 71c1d4eed7
7 changed files with 123 additions and 32 deletions

View file

@@ -83,23 +83,6 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
with torch.cuda.stream(current_stream):
residual = torch.zeros_like(hidden_states)
for i, decode_layer in enumerate(self.model.layers):
if self.model.transfer_map is not None and i in self.model.transfer_map:
prev_stream = torch.cuda.current_stream()
cur_device = self.model.transfer_map[i]
if cur_device not in self.model.stream_device_map:
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
torch.cuda.set_device(cur_device)
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
hidden_states = hidden_states.to(
self.model.transfer_map[i], non_blocking=True
)
batch.minibatch.position_ids = (
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
if batch.minibatch.position_ids is not None
else None
)
router_input = hidden_states.clone()
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
@@ -110,9 +93,9 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
if not self.config.moe_layer_layout[i]:
hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
hidden_states = decode_layer.block_sparse_moe(hidden_states, num_tokens_tensors)
else:
hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
hidden_states = decode_layer.block_sparse_moe(router_input, hidden_states, num_tokens_tensors, cuda_graph_idx)
# hidden_states = hidden_states.squeeze(0)
forward_batch_output = ForwardBatchOutput()