mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
support safetensor load, delete architectures argument
This commit is contained in:
parent
900a7f7c3e
commit
c6aa379de2
30 changed files with 1075 additions and 328 deletions
|
@ -128,10 +128,7 @@ class ArgumentParser:
|
|||
else:
|
||||
args.model_dir = self.cfg.model_dir
|
||||
args.model_path = self.cfg.model_path
|
||||
# set config from args
|
||||
for key, value in vars(args).items():
|
||||
if value is not None and hasattr(self.cfg, key):
|
||||
setattr(self.cfg, key, value)
|
||||
|
||||
# we add the name not match args individually
|
||||
self.cfg.model_device = args.device
|
||||
self.cfg.mount_web = args.web
|
||||
|
@ -140,10 +137,15 @@ class ArgumentParser:
|
|||
self.cfg.user_force_think = args.force_think
|
||||
|
||||
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
if args.architectures == "Qwen3MoeForCausalLM" or args.architectures == "Qwen2MoeForCausalLM" :
|
||||
if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" :
|
||||
args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
|
||||
args.architectures = model_config.architectures[0]
|
||||
else:
|
||||
args.gpu_memory_size = args.cache_lens*2*576*61
|
||||
# set config from args
|
||||
for key, value in vars(args).items():
|
||||
if value is not None and hasattr(self.cfg, key):
|
||||
setattr(self.cfg, key, value)
|
||||
self.cfg.gpu_memory_size = args.gpu_memory_size
|
||||
free_ports = get_free_ports(3, [args.port])
|
||||
args.sched_port = free_ports[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue