support safetensor load, delete architectures argument

This commit is contained in:
qiyuxinlin 2025-05-09 10:38:29 +00:00
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions

View file

@ -12,7 +12,7 @@ from torch import nn
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
# from operators import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
from ktransformers.util.utils import set_module, load_weights
import itertools
import copy
@ -54,7 +54,7 @@ def del_meta(module:nn.Module):
def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
module_name = prefix[:-1]
translated_name = translate_name_to_gguf(prefix)[:-1]
# translated_name = translate_name_to_gguf(prefix)[:-1]
#print("gen_optimize_config", prefix, module_name, translated_name)
recursive = True
for rule in rule_list:
@ -76,7 +76,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if "replace" in rule:
replace_meta = rule["replace"]
if module_name not in out_data:
out_data[module_name]={"key": translated_name,
out_data[module_name]={"key": module_name,
"class": replace_meta["class"] if "class" in replace_meta else "default",
# "device": replace_meta["device"] if "device" in replace_meta else default_device,
"kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
@ -91,7 +91,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if module_name not in out_data:
out_data[module_name]= {
"class": "default",
"key": translated_name,
"key": module_name,
"kwargs": {"generate_device": default_device,
"prefill_device": default_device}
}
@ -123,12 +123,12 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
model_config = translate_model_config(model_config)
gguf_loader=GGUFLoader(gguf_path)
weights_loader = ModelLoaderFactory.create_loader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
inject(module, optimize_config, model_config, weights_loader)
# pre load lm_head because its big inter result
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
load_weights(module.lm_head, weights_loader, "lm_head.")
load_weights(module, weights_loader)
module.gguf_loader = weights_loader
del_meta(module)
torch.cuda.empty_cache()