mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 06:14:58 +00:00)
support safetensor load, delete architectures argument
This commit is contained in:
parent 900a7f7c3e
commit c6aa379de2
30 changed files with 1075 additions and 328 deletions
@@ -12,7 +12,7 @@ from torch import nn
 from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 # from operators import BaseInjectedModule
-from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory
 from ktransformers.util.utils import set_module, load_weights
 import itertools
 import copy
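
The loader import moves from ktransformers.util.custom_gguf to ktransformers.util.custom_loader, which also exposes the ModelLoaderFactory used further down. Below is a minimal sketch of what such a factory could look like, assuming it dispatches on the checkpoint files it finds; SafeTensorLoader here is a hypothetical stand-in, and the detection logic is an assumption, not the real implementation:

    import os
    from ktransformers.util.custom_loader import GGUFLoader  # real import, per the hunk above

    class SafeTensorLoader:  # hypothetical stand-in for the safetensors-backed loader
        def __init__(self, path: str):
            self.path = path

    class ModelLoaderFactory:  # sketch only; the real factory lives in custom_loader
        @staticmethod
        def create_loader(path: str):
            # Assumption: choose the loader from the file types present under `path`.
            files = os.listdir(path) if os.path.isdir(path) else [os.path.basename(path)]
            if any(f.endswith(".safetensors") for f in files):
                return SafeTensorLoader(path)
            if any(f.endswith(".gguf") for f in files):
                return GGUFLoader(path)
            raise FileNotFoundError(f"no .safetensors or .gguf files under {path}")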
@@ -54,7 +54,7 @@ def del_meta(module:nn.Module):
 
 def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
     module_name = prefix[:-1]
-    translated_name = translate_name_to_gguf(prefix)[:-1]
+    # translated_name = translate_name_to_gguf(prefix)[:-1]
     #print("gen_optimize_config", prefix, module_name, translated_name)
     recursive = True
     for rule in rule_list:
@@ -76,7 +76,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
         if "replace" in rule:
             replace_meta = rule["replace"]
             if module_name not in out_data:
-                out_data[module_name]={"key": translated_name,
+                out_data[module_name]={"key": module_name,
                                        "class": replace_meta["class"] if "class" in replace_meta else "default",
                                        # "device": replace_meta["device"] if "device" in replace_meta else default_device,
                                        "kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
@@ -91,7 +91,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
     if module_name not in out_data:
         out_data[module_name]= {
             "class": "default",
-            "key": translated_name,
+            "key": module_name,
             "kwargs": {"generate_device": default_device,
                        "prefill_device": default_device}
         }
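
This hunk and the previous one make the same change: the generated config now stores the raw module name as "key" instead of the GGUF-translated name, leaving any GGUF-specific renaming to the loader. An illustrative before/after; the exact renaming rules live in translate_name_to_gguf:

    # Illustrative only; the exact mapping is defined by translate_name_to_gguf.
    module_name = "model.layers.0.self_attn.q_proj"
    # before: out_data[module_name]["key"] held a GGUF-style name, roughly "blk.0.attn_q"
    # after:  out_data[module_name]["key"] == module_name, so safetensors checkpoints,
    #         whose tensor names keep the module names, resolve keys directly, while
    #         the GGUF loader can still translate internally.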
@@ -123,12 +123,12 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
 
     model_config = translate_model_config(model_config)
 
-    gguf_loader=GGUFLoader(gguf_path)
+    weights_loader = ModelLoaderFactory.create_loader(gguf_path)
     with torch.device("meta"):
-        inject(module, optimize_config, model_config, gguf_loader)
+        inject(module, optimize_config, model_config, weights_loader)
     # pre load lm_head because its big inter result
-    load_weights(module.lm_head, gguf_loader, "lm_head.")
-    load_weights(module, gguf_loader)
-    module.gguf_loader = gguf_loader
+    load_weights(module.lm_head, weights_loader, "lm_head.")
+    load_weights(module, weights_loader)
+    module.gguf_loader = weights_loader
     del_meta(module)
     torch.cuda.empty_cache()
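
With the factory in place, optimize_and_load_gguf should accept either checkpoint format through the same path argument. A hedged usage sketch; the module path in the import is assumed, and the model directory, rule YAML, and checkpoint directory below are placeholders, not part of this commit:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    from ktransformers.optimize.optimize import optimize_and_load_gguf  # assumed module path

    # Placeholders: substitute a real model directory, rule YAML, and checkpoint dir.
    model_config = AutoConfig.from_pretrained("/path/to/model", trust_remote_code=True)
    with torch.device("meta"):
        module = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True)

    # The same call should now work whether /path/to/checkpoint holds .gguf or
    # .safetensors files, since ModelLoaderFactory.create_loader picks the loader.
    optimize_and_load_gguf(module, "optimize_rules.yaml", "/path/to/checkpoint", model_config)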