mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
support glm4moe
This commit is contained in:
parent
1677e90092
commit
d03d92ba53
31 changed files with 2265 additions and 74 deletions
|
@ -29,6 +29,8 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
|
|||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
|
||||
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
|
||||
|
@ -53,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list:
|
|||
class ModelRunner:
|
||||
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
|
||||
|
||||
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
|
||||
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM
|
||||
input: ForwardBatchInput | list[ForwardBatchInput]
|
||||
output: ForwardBatchOutput
|
||||
|
||||
|
@ -93,7 +95,7 @@ class ModelRunner:
|
|||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
|
||||
head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
|
||||
|
@ -124,7 +126,7 @@ class ModelRunner:
|
|||
num_tokens = self.features_buf[i][0].size(0)
|
||||
print("capturing cuda graph", batch_size, num_tokens)
|
||||
|
||||
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
|
||||
self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)
|
||||
|
||||
self.bsz_tensor_buf[0] = batch_size
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue