Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 15:29:39 +00:00
Commit c6aa379de2: support safetensor load, delete architectures argument
Parent: 900a7f7c3e
30 changed files with 1075 additions and 328 deletions
@@ -22,8 +22,7 @@ from transformers import (
     EtaLogitsWarper,
 )
 
-from ktransformers.util.custom_gguf import translate_name_to_gguf
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf
 from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
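
The two ad-hoc imports from custom_gguf are consolidated into the new custom_loader module. Below is a minimal sketch of the interface this diff implies; only the class, attribute, and method names visible in the diff come from the source, while the signatures and the factory's selection logic are assumptions:

# Hypothetical sketch of ktransformers.util.custom_loader as implied by this
# diff; signatures and bodies are assumptions, only the names appear above.
from abc import ABC, abstractmethod
import torch

class ModelLoader(ABC):
    tensor_file_map: dict       # tensor name -> backing checkpoint file
    tensor_device_map: dict     # module prefix -> target device

    @abstractmethod
    def has_tensor(self, name: str) -> bool:
        """Whether the checkpoint contains a tensor under this name."""

class SafeTensorLoader(ModelLoader):
    def load_dequantized_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        ...  # read an HF-named tensor from .safetensors shards

class GGUFLoader(ModelLoader):
    def load_gguf_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        ...  # dequantize a GGUF-named tensor from .gguf files

class ModelLoaderFactory:
    @staticmethod
    def create_loader(path: str) -> ModelLoader:
        ...  # assumption: picks SafeTensorLoader or GGUFLoader from the files at path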
@@ -98,25 +97,24 @@ def get_all_used_cuda_device(device_map:dict):
     all_device_list = list(all_device_list)
     return all_device_list
 
-def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
     local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
     local_state = {k: v for k, v in local_name_params if v is not None}
     for name, param in local_state.items():
         key = prefix + name
-        translated_key = translate_name_to_gguf(key)
+        translated_key = key
 
         # TODO: Merge all loader.
         # I know this is ugly but lets do it for now.
-        if gguf_loader.safetensor_loader is not None:
-            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
-            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        if isinstance(gguf_loader, SafeTensorLoader):
+            load_dequantized_tensor = gguf_loader.load_dequantized_tensor
         else:
             load_dequantized_tensor = gguf_loader.load_gguf_tensor
             tensor_file_map = gguf_loader.tensor_file_map
 
-        if translated_key in tensor_file_map:
+        if gguf_loader.has_tensor(translated_key):
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
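
Taken together, this hunk drops the GGUF name translation (safetensor checkpoints keep their Hugging Face tensor names), picks the load function with an isinstance check on the loader itself, and replaces the tensor_file_map membership test with has_tensor(). A standalone sketch of the resulting dispatch, using the loader sketch above; load_one and the example key are illustrative, not part of the ktransformers API:

# Minimal sketch of the new per-tensor dispatch path.
def load_one(loader: ModelLoader, key: str, device: str = "cpu") -> torch.Tensor:
    if isinstance(loader, SafeTensorLoader):
        load_fn = loader.load_dequantized_tensor   # HF names used as-is
    else:
        load_fn = loader.load_gguf_tensor          # GGUF path unchanged
    if loader.has_tensor(key):
        return load_fn(key, device=device)
    raise Exception(f"can't find {key} in GGUF file!")

# e.g. load_one(loader, "model.layers.0.self_attn.q_proj.weight", "cuda:0")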
@@ -128,7 +126,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
         #print(load_config.tensor_file_map.keys())
         raise Exception(f"can't find {translated_key} in GGUF file!")
 
-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
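
load_weights itself only changes its annotation from GGUFLoader to the broader ModelLoader. For orientation, a sketch of the surrounding function; the recursion over child modules is an assumption based on the "recursively loading weights" comment, since this hunk only shows the entry check:

# Sketch only: nn, base_operator, and load_cur_state_dict come from the file
# being patched; the child-module recursion is assumed, not shown in this hunk.
def load_weights(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
    if not isinstance(module, base_operator.BaseInjectedModule):
        load_cur_state_dict(module, gguf_loader, prefix)
        for name, child in module.named_children():
            load_weights(child, gguf_loader, prefix + name + ".")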