support qwen3

Author: djw
Date: 2025-04-28 14:05:24 +00:00
Parent: 3f9bbf1181
Commit: 0da3792b27
5 changed files with 9 additions and 3 deletions

File 1 of 5

@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     int* locks  // extra global storage for barrier synchronization
 ) {
   int prob_m = *prob_m_ptr;
+  prob_m = min(prob_m, 1024);
   const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
   if (prob_m > 16 * thread_m_blocks)
     prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
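For reference, a minimal Python sketch of the arithmetic in this prologue (div_ceil and the new 1024 cap are taken from the hunk; the function names are illustrative):

    def div_ceil(a: int, b: int) -> int:
        return (a + b - 1) // b

    def effective_prob_m(prob_m: int, template_thread_m_blocks: int) -> int:
        # Cap the dynamic m dimension read from device memory (the line this
        # commit adds), then pad it up to a whole number of 16-row m-blocks,
        # mirroring the kernel prologue above.
        prob_m = min(prob_m, 1024)
        thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks)
        if prob_m > 16 * thread_m_blocks:
            prob_m = 16 * thread_m_blocks * div_ceil(prob_m, 16 * thread_m_blocks)
        return prob_m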

File 2 of 5

@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
     ):
         q_len, _ = hidden_states.size()
-        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
-        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
+        bsz_tensors_q = bsz_tensors * self.num_heads
+        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads
+        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)
+        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)
         value_states = self.v_proj(hidden_states, bsz_tensors)
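Why the scaled row counts: Qwen3, unlike Qwen2, applies q_norm and k_norm per attention head over head_dim, so a variable-length norm kernel driven by a row-count tensor sees tokens * heads rows rather than tokens. A small shape sketch under assumed dimensions:

    import torch

    tokens, num_heads, num_kv_heads, head_dim = 4, 32, 4, 128  # assumed values
    q = torch.randn(tokens, num_heads * head_dim)     # q_proj output
    k = torch.randn(tokens, num_kv_heads * head_dim)  # k_proj output
    # The per-head RMSNorm normalizes one row per (token, head) pair:
    q_rows = q.view(tokens * num_heads, head_dim)
    k_rows = k.view(tokens * num_kv_heads, head_dim)
    assert q_rows.shape[0] == tokens * num_heads      # hence bsz_tensors_q
    assert k_rows.shape[0] == tokens * num_kv_heads   # hence bsz_tensors_kv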

File 3 of 5

@@ -56,6 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
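The same backend knob is added to two rule files (this hunk and the next). As a hypothetical helper, not part of this repo: on Linux, AMX support is reported through the amx_int8 and amx_bf16 flags in /proc/cpuinfo, so a value for the field could be chosen like this:

    def pick_experts_backend(cpuinfo_path: str = "/proc/cpuinfo") -> str:
        # Hypothetical helper: map CPU capability flags to the backend
        # strings the rule file accepts; "llamafile" is the documented default.
        try:
            flags = open(cpuinfo_path).read()
        except OSError:
            return "llamafile"
        if "amx_int8" in flags:
            return "AMXInt8"
        if "amx_bf16" in flags:
            return "AMXBF16"
        return "llamafile"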

File 4 of 5

@@ -56,6 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"

File 5 of 5

@@ -85,7 +85,7 @@ class ModelRunner:
         elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                 num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
-                head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
+                head_dim=128,
                 page_size=self.model.cache.page_size, causal=True,
                 q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
         else:
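The literal 128 matches Qwen3, whose configs publish an explicit head_dim (128 on the MoE checkpoints) that need not equal hidden_size // num_attention_heads. A hedged, more general alternative, assuming an HF-style config object:

    def resolve_head_dim(config) -> int:
        # Prefer the explicit head_dim that Qwen3 configs carry; fall back to
        # the quotient, which reproduces the pre-commit behavior for configs
        # without the field (the published Qwen2 MoE checkpoints also come
        # out to 128 this way).
        return getattr(config, "head_dim",
                       config.hidden_size // config.num_attention_heads)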