mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
support safetensor load, delete architectures argument
This commit is contained in:
parent
900a7f7c3e
commit
c6aa379de2
30 changed files with 1075 additions and 328 deletions
|
@@ -16,7 +16,7 @@ import torch
|
|||
from torch import Tensor, nn
|
||||
import KTransformersOps
|
||||
import vLLMMarlin
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader
|
||||
from ktransformers.util.utils import InferenceState
|
||||
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
|
||||
MarlinWorkspace,
|
||||
|
@@ -83,15 +83,15 @@ class KLinearBase(ABC):
|
|||
keys = [self.key]
|
||||
|
||||
for key in keys:
|
||||
if self.gguf_loader.safetensor_loader is not None:
|
||||
if isinstance(self.gguf_loader, SafeTensorLoader):
|
||||
# using safetensor_loader
|
||||
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
|
||||
if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
|
||||
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
|
||||
tensor = self.gguf_loader.load_tensor(key+'.weight')
|
||||
if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):
|
||||
weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')
|
||||
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
|
||||
return nn.Parameter(tensor)
|
||||
|
||||
elif key + ".weight" in self.gguf_loader.tensor_file_map:
|
||||
elif self.gguf_loader.has_tensor(key + ".weight"):
|
||||
if key + ".bias" in self.gguf_loader.tensor_file_map:
|
||||
tensors = self.load_multi(key, ["weight", "bias"], device=device)
|
||||
tensor = tensors["weight"]
|
||||
|
@@ -760,7 +760,7 @@ class KLinearCPUInfer(KLinearBase):
|
|||
self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)
|
||||
|
||||
def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
|
||||
if self.key + ".weight" in self.gguf_loader.tensor_info:
|
||||
if self.gguf_loader.has_tensor(self.key + ".weight"):
|
||||
if self.key + ".bias" in self.gguf_loader.tensor_file_map:
|
||||
self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
|
||||
self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue