support smt and glm4

This commit is contained in:
djw 2025-07-25 15:03:27 +00:00
parent 48bc6185b5
commit 17246bf84f
7 changed files with 129 additions and 16 deletions

View file

@ -143,7 +143,7 @@ class ArgumentParser:
model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "Glm4MoeForCausalLM":
model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "SmallthinkerForCausalLM":
elif args.model_name == "SmallThinkerForCausalLM":
model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
model_config._attn_implementation = "eager"
else:
@ -153,7 +153,7 @@ class ArgumentParser:
raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallthinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM":
if model_config.architectures[0] == "Qwen3MoeForCausalLM" or model_config.architectures[0] == "Qwen2MoeForCausalLM" or model_config.architectures[0] == "SmallThinkerForCausalLM" or model_config.architectures[0] == "Glm4MoeForCausalLM":
args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim
args.architectures = model_config.architectures[0]
else: