kvcache-ai-ktransformers/ktransformers/optimize/optimize.py
2024-07-27 16:06:58 +08:00

102 lines
4.9 KiB
Python

'''
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from typing import Mapping, List
import torch
import yaml
import re
from torch import nn
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
# from operators import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
from ktransformers.util.utils import set_module, load_weights
import itertools
def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name
if child_prefix in local_optimization_dict:
inject_module_meta=local_optimization_dict[child_prefix]
if isinstance(inject_module_meta, Mapping):
import_path = inject_module_meta["class"].split(".")
import_module_name = ".".join(import_path[:-1])
import_class_name = import_path[-1]
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name)
inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, device = inject_module_meta["device"], **inject_module_meta["kwargs"])
set_module(module, name, inject_module)
elif isinstance(inject_module_meta, str):
assert inject_module_meta=="default", "for str inject_module_meta, only support \"default\"."
else:
raise Exception("inject_module_meta must be a dict or str")
child_prefix += "."
child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)}
inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix)
def del_meta(module:nn.Module):
#print("default loading weights", prefix)
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
if param.device == "meta" or param.device == torch.device("meta"):
module.__delattr__(name)
for name, child in module._modules.items():
del_meta(child)
def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
module_name = prefix[:-1]
translated_name = translate_name_to_gguf(prefix)[:-1]
#print("gen_optimize_config", prefix, module_name, translated_name)
recursive = True
for rule in rule_list:
#print(rule)
match_meta = rule["match"]
if "class" in match_meta:
import_path = match_meta["class"].split(".")
import_module_name = ".".join(import_path[:-1])
import_class_name = import_path[-1]
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
if not isinstance(module, module_cls):
continue
if "name" in match_meta:
if re.search(match_meta["name"], module_name) is None:
continue
replace_meta = rule["replace"]
out_data[module_name]={"key": translated_name,
"class": replace_meta["class"],
"device": replace_meta["device"] if "device" in replace_meta else default_device,
"kwargs": replace_meta["kwargs"] if "kwargs" in replace_meta else dict()}
if "recursive" in rule:
recursive = bool(rule["recursive"])
if module_name not in out_data:
out_data[module_name]="default"
#print(out_data[module_name])
#input()
if recursive:
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
gen_optimize_config(child, out_data, rule_list, child_prefix)
def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = "cuda:0"):
with open(rule_file, 'r', encoding='utf-8') as f:
rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)
optimize_config = dict()
gen_optimize_config(module, optimize_config, rule_list, default_device = default_device)
gguf_loader=GGUFLoader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
load_weights(module, gguf_loader)
del_meta(module)