Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-13 00:29:59 +00:00
support npu

commit b982815325, parent a641aa8063
22 changed files with 162 additions and 1562 deletions
@@ -135,10 +135,47 @@ def get_all_used_cuda_device(device_map:dict):
    all_device_list = list(all_device_list)
    return all_device_list

def get_current_device():
    return f"npu:{torch.npu.current_device()}"


# TODO: support NPU
def load_cur_state_dict_npu(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="npu"):
    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
    local_state = {k: v for k, v in local_name_params if v is not None}
    for name, param in local_state.items():
        key = prefix + name
        translated_key = translate_name_to_gguf(key)

        # TODO: Merge all loaders.
        # I know this is ugly, but let's do it for now.
        if gguf_loader.safetensor_loader is not None:
            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
        else:
            load_dequantized_tensor = gguf_loader.load_gguf_tensor
            tensor_file_map = gguf_loader.tensor_file_map

        if translated_key in tensor_file_map:
            target_dtype = torch.get_default_dtype()
            device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
            # TODO: needs a proper fix
            device = "cpu" if "embd" in translated_key else get_current_device()
            print(f"loading layer {translated_key} to {device}")
            torch.cuda.empty_cache()
            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
            set_param(module, name, weights)
            del weights
        else:
            # print(load_config.tensor_file_map.keys())
            raise Exception(f"can't find {translated_key} in GGUF file!")


def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
    if use_torch_npu:
        load_cur_state_dict_npu(module, gguf_loader, prefix, device)
        return

    prefix = prefix.replace("orig_module.", "")
    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
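Reviewer note: the new dispatch in load_cur_state_dict relies on a use_torch_npu flag whose definition is not visible in this hunk. A minimal sketch of how such a flag is typically derived with Ascend's torch_npu plugin, assuming only that the plugin is installed; the actual definition in this commit may differ.

import torch

try:
    import torch_npu  # Ascend's PyTorch plugin; registers the npu backend
    use_torch_npu = torch_npu.npu.is_available()
except ImportError:
    use_torch_npu = False  # fall back to the CUDA path (assumption, not repo code)

With the flag set, get_current_device() above resolves to a string such as "npu:0".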
@@ -214,7 +251,7 @@ def xpu_fp16_model(config):
    return False

def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
    # print(f"recursively loading weights {prefix}")
    if not isinstance(module, base_operator.BaseInjectedModule):
        load_cur_state_dict(module, gguf_loader, prefix, device=device)
        for name, child in module._modules.items():
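For context, the hunk above is part of a recursive walk over the module tree: load the current module's tensors unless it is an injected wrapper, then descend into children. A self-contained sketch of that pattern with a stand-in for load_cur_state_dict; everything except module._parameters and module._modules is an assumption.

import torch.nn as nn

def load_weights_sketch(module: nn.Module, state: dict, prefix: str = ""):
    # Stand-in for load_cur_state_dict: copy tensors whose keys match.
    for name, param in module._parameters.items():
        key = prefix + name
        if param is not None and key in state:
            param.data.copy_(state[key])
    # Recurse into children, extending the key prefix as load_weights does.
    for name, child in module._modules.items():
        if child is not None:
            load_weights_sketch(child, state, prefix + name + ".")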
@@ -314,6 +351,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                   cache_position=cache_position,
                   past_key_values=past_key_values,
                   return_dict=False, use_cache=True)[0]
    print(logits)
    if past_key_values != None:
        past_key_values.change_seq_length(1)
    all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
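The new all_cuda_device line enumerates one NPU per distributed rank. A hedged, runnable equivalent with a fallback for the non-distributed case; the fallback is an assumption, not repo code.

import torch.distributed as dist

world_size = dist.get_world_size() if dist.is_initialized() else 1
all_cuda_device = [f"npu:{index}" for index in range(world_size)]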
@@ -361,7 +399,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        if past_key_values != None and isinstance(past_key_values, StaticCache):
            past_key_values.change_seq_length(1)
        sync_all_device(all_cuda_device)
        #print(logits)
        next_token_scores = logits_warper(inputs, logits[:, -1, :])
        if generation_config.do_sample:
            probs = nn.functional.softmax(next_token_scores, dim=-1)
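sync_all_device above is a repo helper whose body this diff does not show. A minimal sketch of what such a helper could look like for "npu:"/"cuda:" device strings; its real signature and behavior are assumptions.

import torch

def sync_all_device_sketch(device_list):
    for dev in device_list:
        if dev.startswith("npu"):
            # torch.npu is available once torch_npu has been imported
            torch.npu.synchronize(torch.device(dev))
        elif dev.startswith("cuda"):
            torch.cuda.synchronize(torch.device(dev))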
@@ -410,6 +447,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
                                           cache_position, past_key_values, logits_warper, generation_config,
                                           use_cuda_graph).to(torch_device)

            print(next_token)

            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
            generated_ids[:, cache_position] = next_token.int()
            tokens.append(int(next_token))
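For context: each decode step yields a next_token of shape (1,), which is appended to the running input ids exactly as in the hunk above. A tiny standalone illustration; the literal token ids are made up.

import torch

inputs = torch.tensor([[1, 2, 3]])   # running input ids, batch size 1
next_token = torch.tensor([7])       # id produced by decode_one_tokens
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
print(inputs)                        # tensor([[1, 2, 3, 7]])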
@@ -596,8 +636,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
        position_ids = cache_position.unsqueeze(0)
        seq_length += 1
        if use_torch_npu:
            past_key_values.position += 1


        cuda_graph_runner = None
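The final hunk advances the decode cursor after each token; on NPU the static cache apparently keeps its own position counter as well (read off this diff, not a documented API). A standalone illustration of the cursor update, with placeholder values.

import torch

torch_device = "cpu"   # stand-in for the real npu:X device
seq_length = 8         # prompt length, for illustration only
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)   # shape (1, 1), fed to the next step
seq_length += 1                              # the next token writes one slot further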