mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-13 08:39:42 +00:00
support npu
This commit is contained in:
parent
a641aa8063
commit
b982815325
22 changed files with 162 additions and 1562 deletions
|
@ -14,9 +14,9 @@ from ktransformers.util.ascend.ascend_utils import (
|
|||
get_tensor_parallel_group
|
||||
)
|
||||
from ktransformers.util import utils
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
from ktransformers.util.utils import InferenceState
|
||||
|
||||
from ktransformers.util.custom_loader import translate_name_to_gguf
|
||||
|
||||
class KLinearW8A8(KLinearBase):
|
||||
def __init__(
|
||||
|
@ -39,6 +39,11 @@ class KLinearW8A8(KLinearBase):
|
|||
for key in keys:
|
||||
if device is None:
|
||||
device = utils.CUR_DEVICE
|
||||
|
||||
key = translate_name_to_gguf(key)
|
||||
if key == "lm_head":
|
||||
key = "output"
|
||||
|
||||
if key + ".weight" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
if key + ".deq_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
qweight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
|
||||
|
@ -47,25 +52,25 @@ class KLinearW8A8(KLinearBase):
|
|||
input_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_scale")
|
||||
input_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.input_offset")
|
||||
tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset)
|
||||
print(f"Loading {key} with shape {qweight.shape}, {deq_scale.shape}, {quant_bias.shape}, {input_scale.shape}, {input_offset.shape}")
|
||||
print(tensors)
|
||||
return tensors
|
||||
elif key + ".weight_scale" in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
if key.endswith("ffn_gate_shexp"):
|
||||
parts = key.split(".")
|
||||
layer = parts[1]
|
||||
gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight")
|
||||
gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t()
|
||||
up_weight = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight")
|
||||
up_weight = get_safetensors_cut_weight(self.key, up_weight).t()
|
||||
gate_up_weight = torch.cat((gate_weight, up_weight), 0)
|
||||
gate_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_scale")
|
||||
gate_scale = get_safetensors_cut_weight(self.key, gate_scale)
|
||||
up_scale = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_scale")
|
||||
up_scale = get_safetensors_cut_weight(self.key, up_scale)
|
||||
gate_up_weight = torch.cat((gate_weight, up_weight), 1)
|
||||
gate_up_scale = torch.cat((gate_scale, up_scale), 0)
|
||||
gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_gate_shexp.weight_offset")
|
||||
up_offset = self.gguf_loader.safetensor_loader.load_tensor(f"blk.{layer}.ffn_up_shexp.weight_offset")
|
||||
gate_up_offset = torch.cat((gate_offset, up_offset), 0)
|
||||
tensors = (gate_up_weight, gate_up_scale, gate_up_offset)
|
||||
print(f"Loading {key} as ffn_gate_shexp with shape {gate_up_weight.shape}, {gate_up_scale.shape}, {gate_up_offset.shape}")
|
||||
print(tensors)
|
||||
elif key.endswith("ffn_up_shexp"):
|
||||
return fake_tensor
|
||||
else:
|
||||
|
@ -73,10 +78,11 @@ class KLinearW8A8(KLinearBase):
|
|||
weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_scale")
|
||||
weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight_offset")
|
||||
tensors = (qweight, weight_scale, weight_offset)
|
||||
print(f"Loading {key} with shape {qweight.shape}, {weight_scale.shape}, {weight_offset.shape}")
|
||||
print(tensors)
|
||||
return tensors
|
||||
else:
|
||||
weight = self.gguf_loader.safetensor_loader.load_tensor(f"{key}.weight")
|
||||
weight = get_safetensors_cut_weight(self.key, weight)
|
||||
return weight
|
||||
else:
|
||||
raise FileNotFoundError(f"Weight file not found for key {key}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue