mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
use marlin for lm_head, lm_head only calc last token for prefill extend context window to 19K for DeepSeek-V3/R1 within 24GB VRAM
134 lines
6.6 KiB
Python
134 lines
6.6 KiB
Python
'''
|
|
Description :
|
|
Author : Boxin Zhang, Azure-Tang
|
|
Version : 0.1.0
|
|
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
|
|
'''
|
|
from typing import Mapping, List
|
|
import torch
|
|
import yaml
|
|
import re
|
|
from torch import nn
|
|
from transformers import AutoConfig
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
# from operators import BaseInjectedModule
|
|
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
|
|
from ktransformers.util.utils import set_module, load_weights
|
|
import itertools
|
|
import copy
|
|
|
|
def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
|
|
for name, child in module._modules.items():
|
|
if child is not None:
|
|
child_prefix = prefix + name
|
|
if child_prefix in local_optimization_dict:
|
|
inject_module_meta=local_optimization_dict[child_prefix]
|
|
if inject_module_meta["class"] != "default":
|
|
import_path = inject_module_meta["class"].split(".")
|
|
import_module_name = ".".join(import_path[:-1])
|
|
gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict()
|
|
import_class_name = import_path[-1]
|
|
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
|
|
print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name)
|
|
inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta["kwargs"])
|
|
set_module(module, name, inject_module)
|
|
elif inject_module_meta["class"] == "default":
|
|
print(f"Injecting {child_prefix} as default")
|
|
gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict()
|
|
else:
|
|
raise Exception("inject_module_meta[\"class\"] must be \"default\" or a class path")
|
|
child_prefix += "."
|
|
child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)}
|
|
inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix)
|
|
|
|
def del_meta(module:nn.Module):
|
|
#print("default loading weights", prefix)
|
|
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
|
|
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
|
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
|
for name, param in local_state.items():
|
|
if param.device == "meta" or param.device == torch.device("meta"):
|
|
module.__delattr__(name)
|
|
for name, child in module._modules.items():
|
|
del_meta(child)
|
|
|
|
def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
|
|
module_name = prefix[:-1]
|
|
translated_name = translate_name_to_gguf(prefix)[:-1]
|
|
#print("gen_optimize_config", prefix, module_name, translated_name)
|
|
recursive = True
|
|
for rule in rule_list:
|
|
match_meta = rule["match"]
|
|
if "class" not in match_meta and "name" not in match_meta:
|
|
raise Exception("match must have at least one of \"class\" and \"name\"")
|
|
if "class" in match_meta:
|
|
import_path = match_meta["class"].split(".")
|
|
import_module_name = ".".join(import_path[:-1])
|
|
import_class_name = import_path[-1]
|
|
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
|
|
if not isinstance(module, module_cls):
|
|
continue
|
|
if "name" in match_meta:
|
|
if re.search(match_meta["name"], module_name) is None:
|
|
continue
|
|
if "replace" not in rule:
|
|
raise Exception("replace must be in rule")
|
|
if "replace" in rule:
|
|
replace_meta = rule["replace"]
|
|
if module_name not in out_data:
|
|
out_data[module_name]={"key": translated_name,
|
|
"class": replace_meta["class"] if "class" in replace_meta else "default",
|
|
# "device": replace_meta["device"] if "device" in replace_meta else default_device,
|
|
"kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
|
|
else:
|
|
if out_data[module_name]["class"] == "default":
|
|
out_data[module_name]["class"] = replace_meta["class"] if "class" in replace_meta else "default"
|
|
out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict())
|
|
if "recursive" in rule:
|
|
recursive = bool(rule["recursive"])
|
|
break
|
|
|
|
if module_name not in out_data:
|
|
out_data[module_name]= {
|
|
"class": "default",
|
|
"key": translated_name,
|
|
"kwargs": {"generate_device": default_device,
|
|
"prefill_device": default_device}
|
|
}
|
|
|
|
#print(out_data[module_name])
|
|
#input()
|
|
|
|
if recursive:
|
|
for name, child in module._modules.items():
|
|
if child is not None:
|
|
child_prefix = prefix + name + "."
|
|
gen_optimize_config(child, out_data, rule_list, child_prefix)
|
|
|
|
|
|
def translate_model_config(model_config: PretrainedConfig):
|
|
# for supporting some special model
|
|
if model_config.model_type == "mixtral":
|
|
model_config.moe_intermediate_size = model_config.intermediate_size
|
|
|
|
return model_config
|
|
|
|
|
|
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
|
|
with open(rule_file, 'r', encoding='utf-8') as f:
|
|
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
|
|
|
|
optimize_config = dict()
|
|
gen_optimize_config(module, optimize_config, rule_list, default_device = default_device)
|
|
|
|
model_config = translate_model_config(model_config)
|
|
|
|
gguf_loader=GGUFLoader(gguf_path)
|
|
with torch.device("meta"):
|
|
inject(module, optimize_config, model_config, gguf_loader)
|
|
# pre load lm_head because its big inter result
|
|
load_weights(module.lm_head, gguf_loader, "lm_head.")
|
|
load_weights(module, gguf_loader)
|
|
module.gguf_loader = gguf_loader
|
|
del_meta(module)
|
|
torch.cuda.empty_cache()
|