Fix SmallThinker architecture name casing (SmallthinkerForCausalLM → SmallThinkerForCausalLM)

This commit is contained in:
qiyuxinlin 2025-07-25 12:46:14 +00:00
parent f8719ee7b9
commit 712ad1fa3c
7 changed files with 48 additions and 108 deletions

View file

@@ -64,7 +64,7 @@ default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
"SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
"SmallThinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
"Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
}
@@ -135,7 +135,7 @@ class Engine:
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "Glm4MoeForCausalLM":
config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "SmallthinkerForCausalLM":
elif args.model_name == "SmallThinkerForCausalLM":
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
config._attn_implementation = "eager"
else:
@@ -162,7 +162,7 @@ class Engine:
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
elif config.architectures[0] == "SmallthinkerForCausalLM":
elif config.architectures[0] == "SmallThinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallthinkerForCausalLM(config, self.cache)
elif config.architectures[0] == "Glm4MoeForCausalLM":
@@ -219,7 +219,7 @@ class Engine:
self.block_num = inference_context.k_cache[0].size(1)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)