diff --git a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
index 3ecaeb0..73ba3dd 100644
--- a/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
+++ b/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu
@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     int* locks  // extra global storage for barrier synchronization
 ) {
   int prob_m = *prob_m_ptr;
+  prob_m = min(prob_m, 1024);
   const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
   if(prob_m > 16 * thread_m_blocks)
     prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py
index 4a24fc9..a785413 100644
--- a/ktransformers/operators/balance_serve_attention.py
+++ b/ktransformers/operators/balance_serve_attention.py
@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
     ):
         q_len, _ = hidden_states.size()

-        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
-        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
+        bsz_tensors_q = bsz_tensors * self.num_heads
+        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads
+
+        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)
+        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)
         value_states = self.v_proj(hidden_states, bsz_tensors)
diff --git a/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml b/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
index 41b41a7..27dba2b 100644
--- a/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
@@ -56,6 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
diff --git a/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
index 63f67da..fb9d125 100644
--- a/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
@@ -56,6 +56,7 @@
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py
index 03e18d1..0193576 100644
--- a/ktransformers/server/balance_serve/inference/model_runner.py
+++ b/ktransformers/server/balance_serve/inference/model_runner.py
@@ -85,7 +85,7 @@ class ModelRunner:
         elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                              num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
-                                             head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
+                                             head_dim=128,
                                              page_size=self.model.cache.page_size, causal=True,
                                              q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
         else:
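For reference, a minimal sketch of how the new `backend` key sits inside a complete experts rule in the serve YAML files. Only the keys visible in the hunks above (`generate_device`, `generate_op`, `out_device`, `backend`, `recursive`) are confirmed by the patch; the `match` pattern and `replace` class below are placeholders, not the actual values from Qwen2-serve.yaml or Qwen3Moe-serve.yaml.

```yaml
# Hypothetical rule entry showing where the new `backend` key goes.
# The match pattern and class name are placeholders; only the kwargs
# appearing in the diff are taken from the patch.
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"   # placeholder pattern
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # placeholder class
    kwargs:
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8"  # or "AMXBF16" or "llamafile" (default)
  recursive: False  # don't recursively inject submodules of this module
```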