mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-11 15:54:37 +00:00
smallthink run
This commit is contained in:
parent
590fcb41cd
commit
71c1d4eed7
7 changed files with 123 additions and 32 deletions
|
@ -83,23 +83,6 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
|
|||
with torch.cuda.stream(current_stream):
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
for i, decode_layer in enumerate(self.model.layers):
|
||||
if self.model.transfer_map is not None and i in self.model.transfer_map:
|
||||
prev_stream = torch.cuda.current_stream()
|
||||
cur_device = self.model.transfer_map[i]
|
||||
if cur_device not in self.model.stream_device_map:
|
||||
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
||||
torch.cuda.set_device(cur_device)
|
||||
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
|
||||
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
|
||||
hidden_states = hidden_states.to(
|
||||
self.model.transfer_map[i], non_blocking=True
|
||||
)
|
||||
|
||||
batch.minibatch.position_ids = (
|
||||
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
|
||||
if batch.minibatch.position_ids is not None
|
||||
else None
|
||||
)
|
||||
router_input = hidden_states.clone()
|
||||
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
|
||||
|
@ -110,9 +93,9 @@ class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
|
|||
|
||||
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
if not self.config.moe_layer_layout[i]:
|
||||
hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
|
||||
hidden_states = decode_layer.block_sparse_moe(hidden_states, num_tokens_tensors)
|
||||
else:
|
||||
hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
hidden_states = decode_layer.block_sparse_moe(router_input, hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
# hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue