mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
support safetensor load, delete architectures argument
This commit is contained in:
parent
900a7f7c3e
commit
c6aa379de2
30 changed files with 1075 additions and 328 deletions
|
@ -6,7 +6,7 @@ import os
|
|||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.operators.linear import KTransformersLinear
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
@ -55,24 +55,20 @@ class KMoEGateBase(ABC):
|
|||
down_type = None
|
||||
|
||||
for key in keys:
|
||||
key = ".".join(key.split(".")[:-1])
|
||||
if self.gguf_loader.safetensor_loader is not None:
|
||||
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
|
||||
weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
|
||||
e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
|
||||
weight_type = weight.dtype
|
||||
e_score_correction_bias_type = e_score_correction_bias.dtype
|
||||
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
|
||||
elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
|
||||
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
|
||||
# key = ".".join(key.split(".")[:-1])
|
||||
if isinstance(self.gguf_loader, SafeTensorLoader):
|
||||
res = self.gguf_loader.load_gate(key, device=device)
|
||||
elif self.gguf_loader.has_tensor(key+".weight"):
|
||||
# targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
|
||||
targets = [".weight", ".e_score_correction_bias"]
|
||||
tensors = self.load_multi(key, targets, device=device)
|
||||
weight = tensors[".ffn_gate_inp.weight"]
|
||||
e_score_correction_bias = tensors[".exp_probs_b.bias"]
|
||||
weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
|
||||
e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
|
||||
weight = tensors[".weight"]
|
||||
e_score_correction_bias = tensors[".e_score_correction_bias"]
|
||||
# weight_type = self.gguf_loader.tensor_info[key + ".weight"]["ggml_type"]
|
||||
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias}
|
||||
else:
|
||||
raise ValueError(f"Experts {key} not found in gguf_loader")
|
||||
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
|
||||
|
||||
return res
|
||||
|
||||
def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
|
||||
|
@ -106,8 +102,6 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
|
|||
if w is None: w = self.load_weights(device=device)
|
||||
|
||||
if isinstance(w, dict):
|
||||
self.weight_type = w["weight_type"]
|
||||
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
|
||||
self.orig_module.weight = nn.Parameter(w["weight"])
|
||||
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
|
||||
else:
|
||||
|
@ -175,8 +169,6 @@ class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
|
|||
if w is None: w = self.load_weights(device=device)
|
||||
|
||||
if isinstance(w, dict):
|
||||
self.weight_type = w["weight_type"]
|
||||
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
|
||||
self.orig_module.weight = nn.Parameter(w["weight"])
|
||||
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue