Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 15:29:39 +00:00

support qwen3

commit 0da3792b27 (parent 3f9bbf1181)
5 changed files with 9 additions and 3 deletions
@@ -1420,6 +1420,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
     int* locks  // extra global storage for barrier synchronization
 ) {
   int prob_m = *prob_m_ptr;
+  prob_m = min(prob_m, 1024);
   const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);
   if (prob_m > 16 * thread_m_blocks)
     prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));
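The hunk above caps prob_m at 1024 before the existing tile rounding runs. As a hedged sketch (mine, not the repo's), here is the same sizing logic transcribed into Python, assuming div_ceil is integer ceiling division:

def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b

def padded_prob_m(prob_m: int, template_thread_m_blocks: int) -> int:
    prob_m = min(prob_m, 1024)  # clamp added by this commit
    thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks)
    if prob_m > 16 * thread_m_blocks:
        # round m up to a whole number of 16-row tiles per thread block
        prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, 16 * thread_m_blocks)
    return prob_m

assert padded_prob_m(100, 4) == 128  # 100 rows pad up to two 64-row passes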
@@ -255,8 +255,11 @@ class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
     ):
         q_len, _ = hidden_states.size()

-        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
-        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
+        bsz_tensors_q = bsz_tensors * self.num_heads
+        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads
+
+        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)
+        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)
         value_states = self.v_proj(hidden_states, bsz_tensors)
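The reason the counts change: Qwen3's q_norm/k_norm are per-head RMSNorms over head_dim, so once the projection output is viewed head by head, the fused norm sees tokens * num_heads rows rather than tokens, and the row-count tensor must be scaled to match. A minimal plain-PyTorch sketch of that shape accounting (illustrative sizes, not the repo's fused kernels):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # plain RMSNorm over the last dimension
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

tokens, num_heads, head_dim = 4, 8, 128
q = torch.randn(tokens, num_heads * head_dim)  # q_proj output, heads flattened
w = torch.ones(head_dim)

# Viewed per head, the norm operates on tokens * num_heads rows, which is why
# the commit passes bsz_tensors * num_heads (and * num_key_value_heads for k).
q_normed = rms_norm(q.view(tokens * num_heads, head_dim), w).view_as(q)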
@@ -56,6 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
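For readers new to these injection rules: the added backend key chooses which CPU kernel KExpertsCPU uses for the experts. A hedged illustration of where the key sits in a full rule, loadable with PyYAML; the match pattern and class path are assumptions for the example, and only the kwargs mirror the diff:

import yaml

rule_text = r'''
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"  # assumed pattern, for illustration
  replace:
    class: ktransformers.operators.experts.KTransformersExperts  # assumed class path
    kwargs:
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
      backend: "AMXInt8"  # or "AMXBF16" or "llamafile" (default)
'''
rules = yaml.safe_load(rule_text)
print(rules[0]["replace"]["kwargs"]["backend"])  # -> AMXInt8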
@@ -56,6 +56,7 @@
       generate_device: "cpu"
       generate_op: "KExpertsCPU"
       out_device: "cuda"
+      backend: "AMXInt8" # or "AMXBF16" or "llamafile" (default)
   recursive: False # don't recursively inject submodules of this module
 - match:
     name: "^model\\.layers\\..*\\.self_attn$"
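The same backend option lands in a second, apparently identical rule file (the two hunks match line for line). A hedged usage note: the AMX backends should be expected to require a CPU with Intel AMX instructions (4th-gen Xeon Scalable or newer), while llamafile remains the portable default, as the inline comment suggests.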
@@ -85,7 +85,7 @@ class ModelRunner:
         elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
             self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                              num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
-                                             head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
+                                             head_dim=128,
                                              page_size=self.model.cache.page_size, causal=True,
                                              q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
         else:
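The hard-coded value fixes the head size rather than deriving it, and the distinction matters for Qwen3-MoE, which sets head_dim explicitly in its config: for Qwen3-30B-A3B, hidden_size // num_attention_heads would give 2048 // 32 = 64, while the actual head dimension is 128 (for Qwen2-57B-A14B the old derivation happened to yield 128 as well, so only Qwen3 was affected). A hedged, more general sketch, assuming an HF-style config object rather than the repo's code:

def resolve_head_dim(config) -> int:
    # Prefer an explicit head_dim when the config carries one (Qwen3 does),
    # falling back to the legacy derivation this commit replaces.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return config.hidden_size // config.num_attention_heads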