support smt and glm4

This commit is contained in:
djw 2025-07-25 15:03:27 +00:00
parent 48bc6185b5
commit 17246bf84f
7 changed files with 129 additions and 16 deletions

View file

@ -24,7 +24,7 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
@ -138,6 +138,7 @@ class Engine:
elif args.model_name == "SmallThinkerForCausalLM":
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
config._attn_implementation = "eager"
config.moe_intermediate_size = config.moe_ffn_hidden_size
else:
try:
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
@ -164,7 +165,7 @@ class Engine:
self.model = KQwen3MoeForCausalLM(config, self.cache)
elif config.architectures[0] == "SmallThinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallthinkerForCausalLM(config, self.cache)
self.model = KSmallThinkerForCausalLM(config, self.cache)
elif config.architectures[0] == "Glm4MoeForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KGlm4MoeForCausalLM(config, self.cache)
@ -462,8 +463,8 @@ class BalanceServeInterface(BackendInterfaceBase):
profiler.create_and_start_timer("prefill")
query_add = sched_ext.QueryAdd()
input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
99052, 101052, 11314]], device='cuda')
# input_ids = torch.tensor([[151331, 151333, 98964, 117392, 103408, 99668, 3837, 99073, 99444,
# 99052, 101052, 11314]], device='cuda')
query_add.query_token = input_ids[0].tolist()
query_length = input_ids[0].shape[0]
query_add.query_length = query_length