mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
Fix cannot offload whole layer in cpu
This commit is contained in:
parent
35d7aed207
commit
6735beb5b6
4 changed files with 14 additions and 11 deletions
|
@ -67,6 +67,7 @@ def local_chat(
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||||
if mode == 'long_context':
|
if mode == 'long_context':
|
||||||
|
assert config.architectures[0] == "LlamaForCausalLM", "only LlamaForCausalLM support long_context mode"
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
else:
|
else:
|
||||||
torch.set_default_dtype(config.torch_dtype)
|
torch.set_default_dtype(config.torch_dtype)
|
||||||
|
@ -143,8 +144,9 @@ def local_chat(
|
||||||
input_tensor = tokenizer.apply_chat_template(
|
input_tensor = tokenizer.apply_chat_template(
|
||||||
messages, add_generation_prompt=True, return_tensors="pt"
|
messages, add_generation_prompt=True, return_tensors="pt"
|
||||||
)
|
)
|
||||||
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
if mode == 'long_context':
|
||||||
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
|
||||||
|
"please change max_seq_len in ~/.ktransformers/config.yaml"
|
||||||
torch.set_default_dtype(
|
torch.set_default_dtype(
|
||||||
torch.bfloat16
|
torch.bfloat16
|
||||||
) # TODO: Remove this, replace dtype using config
|
) # TODO: Remove this, replace dtype using config
|
||||||
|
|
|
@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang, chenht2022
|
||||||
Date : 2024-07-25 11:25:24
|
Date : 2024-07-25 11:25:24
|
||||||
Version : 0.1.0
|
Version : 0.1.0
|
||||||
LastEditors : Azure
|
LastEditors : Azure
|
||||||
LastEditTime : 2024-08-27 03:50:23
|
LastEditTime : 2024-08-29 09:41:10
|
||||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ class KExpertsCPU(KExpertsBase):
|
||||||
def forward(self, input_tensor, expert_ids, weights):
|
def forward(self, input_tensor, expert_ids, weights):
|
||||||
# generate, capture and run cuda graph
|
# generate, capture and run cuda graph
|
||||||
# print(expert_ids)
|
# print(expert_ids)
|
||||||
if input_tensor.size(0)==1:
|
if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
|
||||||
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
|
# TODO: this branch is unreachable, but the shape of input_tensor([1,hidden_size]) and input_tensor_cpu([hidden_size]) is not compatible
|
||||||
#print("capturing experts")
|
#print("capturing experts")
|
||||||
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
|
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
|
||||||
|
@ -636,7 +636,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
|
||||||
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
|
||||||
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"):
|
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
|
||||||
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
y_ = self.shared_experts(identity).squeeze(0)
|
y_ = self.shared_experts(identity).squeeze(0)
|
||||||
|
|
|
@ -6,7 +6,7 @@ Author : Azure-Tang, Boxin Zhang
|
||||||
Date : 2024-07-25 11:25:24
|
Date : 2024-07-25 11:25:24
|
||||||
Version : 0.1.0
|
Version : 0.1.0
|
||||||
LastEditors : Azure
|
LastEditors : Azure
|
||||||
LastEditTime : 2024-08-14 14:57:04
|
LastEditTime : 2024-08-29 09:11:16
|
||||||
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
@ -277,7 +277,7 @@ class KLinearCPUInfer(KLinearBase):
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
origin_shape = x.shape # [batch_size, q_len, hidden_size]
|
origin_shape = x.shape # [batch_size, q_len, hidden_size]
|
||||||
if origin_shape[1] == 1:
|
if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():
|
||||||
out_device = x.device
|
out_device = x.device
|
||||||
self.input_tensor_cpu.copy_(x, non_blocking=True)
|
self.input_tensor_cpu.copy_(x, non_blocking=True)
|
||||||
qlen = origin_shape[1]
|
qlen = origin_shape[1]
|
||||||
|
|
|
@ -670,11 +670,12 @@ class KDeepseekV2Model(BaseInjectedModule):
|
||||||
if self.transfer_map is not None and i in self.transfer_map:
|
if self.transfer_map is not None and i in self.transfer_map:
|
||||||
prev_stream = torch.cuda.current_stream()
|
prev_stream = torch.cuda.current_stream()
|
||||||
cur_device = self.transfer_map[i]
|
cur_device = self.transfer_map[i]
|
||||||
if cur_device not in self.stream_device_map:
|
if cur_device not in self.stream_device_map and cur_device.lower() != "cpu":
|
||||||
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
||||||
torch.cuda.set_device(cur_device)
|
if cur_device.lower() != "cpu":
|
||||||
self.stream_device_map[cur_device].wait_stream(prev_stream)
|
torch.cuda.set_device(cur_device)
|
||||||
torch.cuda.set_stream(self.stream_device_map[cur_device])
|
self.stream_device_map[cur_device].wait_stream(prev_stream)
|
||||||
|
torch.cuda.set_stream(self.stream_device_map[cur_device])
|
||||||
hidden_states = hidden_states.to(
|
hidden_states = hidden_states.to(
|
||||||
self.transfer_map[i], non_blocking=True
|
self.transfer_map[i], non_blocking=True
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue