support safetensor load, delete architectures argument

qiyuxinlin 2025-05-09 10:38:29 +00:00
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions


@@ -6,7 +6,7 @@ import os
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.operators.linear import KTransformersLinear
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
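
The loader classes now come from a single ktransformers.util.custom_loader module, which exports GGUFLoader, ModelLoader, and SafeTensorLoader. Below is a minimal sketch of how a caller might pick a loader for a checkpoint directory; the make_loader helper, the file-extension check, and the constructor arguments are assumptions for illustration, not the module's documented API.

# Hedged sketch: the helper name, directory scan, and constructor
# arguments are assumptions, not the actual custom_loader API.
import os
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader

def make_loader(checkpoint_dir: str):
    # Prefer the safetensor loader when .safetensors shards are present,
    # otherwise fall back to the GGUF loader.
    if any(f.endswith(".safetensors") for f in os.listdir(checkpoint_dir)):
        return SafeTensorLoader(checkpoint_dir)  # assumed constructor signature
    return GGUFLoader(checkpoint_dir)            # assumed constructor signature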
@@ -55,24 +55,20 @@ class KMoEGateBase(ABC):
         down_type = None
         for key in keys:
             key = ".".join(key.split(".")[:-1])
-            if self.gguf_loader.safetensor_loader is not None:
-                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
-                weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
-                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
-                weight_type = weight.dtype
-                e_score_correction_bias_type = e_score_correction_bias.dtype
-                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
-            elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
-                targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
-                # key = ".".join(key.split(".")[:-1])
+            if isinstance(self.gguf_loader, SafeTensorLoader):
+                res = self.gguf_loader.load_gate(key, device=device)
+            elif self.gguf_loader.has_tensor(key+".weight"):
+                # targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
+                targets = [".weight", ".e_score_correction_bias"]
                 tensors = self.load_multi(key, targets, device=device)
-                weight = tensors[".ffn_gate_inp.weight"]
-                e_score_correction_bias = tensors[".exp_probs_b.bias"]
-                weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
-                e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
+                weight = tensors[".weight"]
+                e_score_correction_bias = tensors[".e_score_correction_bias"]
+                # weight_type = self.gguf_loader.tensor_info[key + ".weight"]["ggml_type"]
+                res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias}
             else:
                 raise ValueError(f"Experts {key} not found in gguf_loader")
-            res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
         return res

     def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
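
The net effect on load_weights: safetensor checkpoints delegate the whole gate lookup to the loader's load_gate, GGUF checkpoints go through load_multi with the plain ".weight" / ".e_score_correction_bias" suffixes, and neither path records ggml_type metadata any more. A standalone sketch of that dispatch follows; load_multi is passed in as a callable because in the class it is a method of KMoEGateBase, and the function name itself is only illustrative.

from typing import Callable
from ktransformers.util.custom_loader import SafeTensorLoader

def load_gate_weights(loader, key: str, load_multi: Callable, device: str = "cpu") -> dict:
    # Safetensor checkpoints: the loader assembles the gate dict itself.
    if isinstance(loader, SafeTensorLoader):
        return loader.load_gate(key, device=device)
    # GGUF checkpoints: fetch the two tensors by suffix and build the dict here.
    if loader.has_tensor(key + ".weight"):
        tensors = load_multi(key, [".weight", ".e_score_correction_bias"], device=device)
        return {
            "weight": tensors[".weight"],
            "e_score_correction_bias": tensors[".e_score_correction_bias"],
        }
    raise ValueError(f"Experts {key} not found in gguf_loader")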
@@ -106,8 +102,6 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
         if w is None: w = self.load_weights(device=device)
         if isinstance(w, dict):
-            self.weight_type = w["weight_type"]
-            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
             self.orig_module.weight = nn.Parameter(w["weight"])
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
@@ -175,8 +169,6 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
         if w is None: w = self.load_weights(device=device)
         if isinstance(w, dict):
-            self.weight_type = w["weight_type"]
-            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
             self.orig_module.weight = nn.Parameter(w["weight"])
             self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
         else:
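
On the consuming side, both KMoEGate and KMoEGateQwen2Moe now treat the dict the same way: the weight_type / e_score_correction_bias_type bookkeeping is gone and the loaded tensors are wrapped directly as parameters. A minimal sketch, assuming w is the dict returned by load_weights above (the helper name is illustrative only):

import torch.nn as nn

def apply_gate_weights(module: nn.Module, w: dict) -> None:
    # No dtype / ggml_type bookkeeping any more: wrap the tensors directly.
    module.weight = nn.Parameter(w["weight"])
    module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])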