support smt and glm4

This commit is contained in:
djw 2025-07-24 08:40:58 +00:00
parent 1677e90092
commit b66d96db97
18 changed files with 3519 additions and 16 deletions

View file

@ -24,7 +24,11 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@ -60,6 +64,8 @@ default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
"SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
"Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
}
@ -123,15 +129,24 @@ class Engine:
self.sched_client = SchedulerClient(args.sched_port)
self.updates = []
try:
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
except:
if args.model_name == "Qwen3Moe":
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
else:
assert False, f"model {args.model_name} not supported"
print(f"args.model_name: {args.model_name}")
if args.model_name == "Qwen3MoeForCausalLM":
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "Glm4MoeForCausalLM":
config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
elif args.model_name == "SmallthinkerForCausalLM":
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
config._attn_implementation = "eager"
else:
try:
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
except:
raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
self.gen_queue = generated_token_queue
with torch.device("meta"):
@ -147,6 +162,13 @@ class Engine:
self.model = KQwen2MoeForCausalLM(config, self.cache)
else:
self.model = KQwen3MoeForCausalLM(config, self.cache)
elif config.architectures[0] == "SmallthinkerForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KSmallthinkerForCausalLM(config, self.cache)
elif config.architectures[0] == "Glm4MoeForCausalLM":
self.cache = KGQACache(config, self.args.page_size)
self.model = KGlm4MoeForCausalLM(config, self.cache)
context = zmq.Context()
@ -197,7 +219,7 @@ class Engine:
self.block_num = inference_context.k_cache[0].size(1)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
else:
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

View file

@ -29,6 +29,8 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
@ -53,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list:
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallthinkerForCausalLM | KGlm4MoeForCausalLM
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
@ -93,7 +95,7 @@ class ModelRunner:
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
@ -124,7 +126,7 @@ class ModelRunner:
num_tokens = self.features_buf[i][0].size(0)
print("capturing cuda graph", batch_size, num_tokens)
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)
self.bsz_tensor_buf[0] = batch_size