support npu

Dongjw 2025-07-23 09:54:55 +00:00
parent a641aa8063
commit b982815325
22 changed files with 162 additions and 1562 deletions


@@ -14,9 +14,9 @@ from ktransformers.util.ascend.ascend_utils import (
    get_tensor_parallel_group
)
from ktransformers.util import utils
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.custom_loader import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.util.custom_loader import translate_name_to_gguf
class KLinearW8A8(KLinearBase):
    def __init__(
@@ -39,6 +39,11 @@ class KLinearW8A8(KLinearBase):
        for key in keys:
            if device is None:
                device = utils.CUR_DEVICE
            key = translate_name_to_gguf(key)
            if key == "lm_head":
                key = "output"
            if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map:
                if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
                    qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
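
The added lines route every lookup through `translate_name_to_gguf`, which rewrites HuggingFace-style parameter names into GGUF tensor names before probing `tensor_file_map`, with `lm_head` additionally special-cased to GGUF's `output` tensor. A minimal sketch of that kind of renaming, using a hypothetical rename table (the real one lives in `ktransformers.util.custom_loader` and is far larger):

```python
import re

# Hypothetical subset of an HF -> GGUF rename table (illustrative only).
_HF_TO_GGUF = {
    "model.embed_tokens": "token_embd",
    "mlp.gate_proj": "ffn_gate",
    "mlp.up_proj": "ffn_up",
    "mlp.down_proj": "ffn_down",
}

def translate_name_sketch(key: str) -> str:
    # "model.layers.3.mlp.up_proj" -> "blk.3.mlp.up_proj"
    key = re.sub(r"^model\.layers\.(\d+)", r"blk.\1", key)
    for hf_name, gguf_name in _HF_TO_GGUF.items():
        key = key.replace(hf_name, gguf_name)
    return key

assert translate_name_sketch("model.layers.3.mlp.up_proj") == "blk.3.ffn_up"
# "lm_head" has no entry in this sketch table, which is why the hunk
# follows the call with an explicit `if key == "lm_head": key = "output"`.
```
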
@@ -47,25 +52,25 @@ class KLinearW8A8(KLinearBase):
                    input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale")
                    input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset")
                    tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset)
                    print(f"Loading {key} with shape {qweight.shape}, {deq_scale.shape}, {quant_bias.shape}, {input_scale.shape}, {input_offset.shape}")
                    print(tensors)
                    return tensors
                elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
                    if key.endswith("ffn_gate_shexp"):
                        parts = key.split(".")
                        layer = parts[1]
                        gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight")
                        gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t()
                        up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight")
                        up_weight = get_safetensors_cut_weight(self.key, up_weight).t()
                        gate_up_weight = torch.cat((gate_weight, up_weight), 0)
                        gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale")
                        gate_scale = get_safetensors_cut_weight(self.key, gate_scale)
                        up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale")
                        up_scale = get_safetensors_cut_weight(self.key, up_scale)
                        gate_up_weight = torch.cat((gate_weight, up_weight), 1)
                        gate_up_scale = torch.cat((gate_scale, up_scale), 0)
                        gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset")
                        up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset")
                        gate_up_offset = torch.cat((gate_offset, up_offset), 0)
                        tensors = (gate_up_weight, gate_up_scale, gate_up_offset)
                        print(f"Loading {key} as ffn_gate_shexp with shape {gate_up_weight.shape}, {gate_up_scale.shape}, {gate_up_offset.shape}")
                        print(tensors)
                    elif key.endswith("ffn_up_shexp"):
                        return fake_tensor
                    else:
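
In the weight-only branch, the `ffn_gate_shexp` key triggers a fused load: the shared expert's gate and up projections, plus their per-channel scales and offsets, are concatenated so that a single matmul produces both halves, and the later `ffn_up_shexp` key returns only a placeholder (`fake_tensor`) because its data was already consumed here. A toy sketch of why the fusion is lossless, assuming a SiLU-gated MLP as in DeepSeek-style models (shapes are placeholders):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)        # toy activations: hidden size 8
w_gate = torch.randn(16, 8)  # toy gate projection: intermediate size 16
w_up = torch.randn(16, 8)    # toy up projection

# Concatenate along the output dimension: one (32, 8) weight
# replaces two (16, 8) weights in a single matmul.
w_gate_up = torch.cat((w_gate, w_up), 0)

gate_up = x @ w_gate_up.t()          # (2, 32)
gate, up = gate_up.chunk(2, dim=-1)  # split back into two (2, 16) halves

fused = F.silu(gate) * up
unfused = F.silu(x @ w_gate.t()) * (x @ w_up.t())
assert torch.allclose(fused, unfused, atol=1e-6)
```
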
@@ -73,10 +78,11 @@ class KLinearW8A8(KLinearBase):
                        weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale")
                        weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset")
                        tensors = (qweight, weight_scale, weight_offset)
                        print(f"Loading {key} with shape {qweight.shape}, {weight_scale.shape}, {weight_offset.shape}")
                        print(tensors)
                        return tensors
                else:
                    weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
                    weight = get_safetensors_cut_weight(self.key, weight)
                    return weight
            else:
                raise FileNotFoundError(f"Weight file not found for key {key}")
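
Taken together, the loader now tries three layouts per key: a full W8A8 tuple (`qweight`, `deq_scale`, `quant_bias`, `input_scale`, `input_offset`) when `.deq_scale` exists, a weight-only tuple (`qweight`, `weight_scale`, `weight_offset`) when only `.weight_scale` exists, and a plain float weight (cut for tensor parallelism) otherwise. For reference, a sketch of one common asymmetric per-channel int8 convention the weight-only tuple could encode; whether the Ascend kernels use exactly this formula is not shown in the diff:

```python
import torch

def dequantize_w8_sketch(qweight, weight_scale, weight_offset):
    # Assumed convention (illustrative): w_fp = (q - offset) * scale,
    # applied per output channel.
    return (qweight.to(torch.float32) - weight_offset) * weight_scale

qweight = torch.randint(-128, 128, (4, 8), dtype=torch.int8)       # toy int8 weight
weight_scale = torch.rand(4, 1) * 0.01                             # per-row scale
weight_offset = torch.zeros(4, 1)                                  # per-row zero-point
w_fp = dequantize_w8_sketch(qweight, weight_scale, weight_offset)  # (4, 8) float32
```
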