support npu

Dongjw 2025-07-23 09:54:55 +00:00
parent a641aa8063
commit b982815325
22 changed files with 162 additions and 1562 deletions


@@ -135,10 +135,47 @@ def get_all_used_cuda_device(device_map:dict):
all_device_list = list(all_device_list)
return all_device_list
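# On Ascend builds, the active NPU is reported as the current device.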
def get_current_device():
return f"npu:{torch.npu.current_device()}"
# NPU-specific counterpart of load_cur_state_dict below; TODO: merge the two load paths.
def load_cur_state_dict_npu(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="npu"):
prefix = prefix.replace("orig_module.", "")
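# Gather this module's own parameters and persistent buffers; these are the tensors that need weights loaded.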
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
# TODO: Merge all loaders.
# I know this is ugly, but let's do it for now.
if gguf_loader.safetensor_loader is not None:
load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if translated_key in tensor_file_map:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
# TODO: needs a proper fix; for now embedding weights stay on CPU and everything else goes to the current NPU
device = "cpu" if "embd" in translated_key else get_current_device()
print(f"loading layer {translated_key} to {device}")
torch.npu.empty_cache()
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
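# Attach the loaded tensor to the module, then drop the temporary reference.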
set_param(module, name, weights)
del weights
else:
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't find {translated_key} in GGUF file!")
def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = "", device="cuda"):
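# On Ascend NPUs, delegate to the NPU-specific loader above and return early.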
if use_torch_npu:
load_cur_state_dict_npu(module, gguf_loader, prefix, device)
return
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
@@ -214,7 +251,7 @@ def xpu_fp16_model(config):
return False
def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device="cuda"):
#print(f"recursively loading weights {prefix}")
# print(f"recursively loading weights {prefix}")
if not isinstance(module, base_operator.BaseInjectedModule):
load_cur_state_dict(module, gguf_loader, prefix, device=device)
for name, child in module._modules.items():
@@ -314,6 +351,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
# print(logits)
if past_key_values != None:
past_key_values.change_seq_length(1)
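# Despite the name, this lists one NPU per rank in the distributed world so every device can be synchronized.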
all_cuda_device = ['npu:' + str(index) for index in range(torch.distributed.get_world_size())]
@@ -361,7 +399,6 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
if past_key_values != None and isinstance(past_key_values, StaticCache):
past_key_values.change_seq_length(1)
sync_all_device(all_cuda_device)
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
@@ -410,6 +447,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids,
cache_position, past_key_values, logits_warper, generation_config,
use_cuda_graph).to(torch_device)
# print(next_token)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(int(next_token))
@@ -596,8 +636,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
position_ids = cache_position.unsqueeze(0)
seq_length += 1
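# The NPU static cache tracks its write position explicitly, so advance it for the token just appended.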
if use_torch_npu:
past_key_values.position += 1
cuda_graph_runner = None
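
Note: the hunks above branch on a module-level use_torch_npu flag whose definition is outside this diff. A minimal sketch of how such a flag is commonly derived with the torch_npu plugin, shown as an illustration only (the probing below is an assumption, not necessarily what this commit does):

try:
    # Ascend plugin; registers the torch.npu backend on import (assumed detection logic)
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
except ImportError:
    use_torch_npu = False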