support qwen3, dont speak human language

This commit is contained in:
djw 2025-04-28 08:44:47 +00:00
parent f3d842a0ca
commit 3f9bbf1181
30 changed files with 3696 additions and 290 deletions

View file

@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
self.max_seq_len_cached = max_position_embeddings
class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
config,
)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
self.orig_module.__init__(
self.orig_module.config
)

View file

@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
wrapper: BatchMLAPagedAttentionWrapper,
num_tokens_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states, num_tokens_tensors)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
q = q.view(q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.contiguous()
compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
q_pe = q_pe.squeeze(0)
if kv_cache is not None:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(0, 1)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
attn_output = attn_output.transpose(0, 1)
attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, num_tokens_tensors)
return attn_output

View file

@ -0,0 +1,287 @@
'''
Description :
Author : Boxin Zhang
Version : 0.2.5
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from flashinfer import BatchMLAPagedAttentionWrapper
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
logger = logging.getLogger("attention")
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
wrapper: BatchMLAPagedAttentionWrapper,
num_tokens_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states, num_tokens_tensors)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
q = q.view(q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.contiguous()
compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
q_pe = q_pe.squeeze(0)
if kv_cache is not None:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(0, 1)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
attn_output = attn_output.transpose(0, 1)
attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, num_tokens_tensors)
return attn_output
class KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
position_ids: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, bsz_tensors)
key_states = self.k_proj(hidden_states, bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
key_states = key_states.view(
q_len, self.num_key_value_heads, self.head_dim
)
value_states = value_states.view(
q_len, self.num_key_value_heads, self.head_dim
)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)
return attn_output
class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KGQACache,
position_ids: torch.Tensor,
wrapper: flashInferAttn,
bsz_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
value_states = self.v_proj(hidden_states, bsz_tensors)
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
key_states = key_states.view(
q_len, self.num_key_value_heads, self.head_dim
)
value_states = value_states.view(
q_len, self.num_key_value_heads, self.head_dim
)
k_cache = kv_cache.get_k_cache(self.layer_idx)
v_cache = kv_cache.get_v_cache(self.layer_idx)
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)
return attn_output

View file

@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
router_logits = self.gate(hidden_states, bsz_tensor)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
y_ = (
F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
router_logits = self.gate(hidden_states, bsz_tensor)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
# y += y_
y.resize_(*orig_shape)
return y
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
# y_ = (
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
# )
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
self.moe_infer(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO may bugs here
y = (
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
.view(*orig_shape)
.to(device=hidden_states.device)
)
# y += y_
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO may bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out

View file

@ -0,0 +1,324 @@
import torch
import flashinfer
import gc
try:
from flash_attn import flash_attn_with_kvcache
print("found flash_attn")
except ImportError:
print("flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.")
from typing import Union, Optional
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
setup_seed(998244353)
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
global_dtype=torch.bfloat16
global_device=torch.device("cuda",0)
torch.cuda.set_device(0)
torch.backends.cudnn.enabled =True
torch.backends.cudnn.benchmark = True
class flashInferAttn():
float_workspace_buffer = None
def __init__(self,
max_batch_token,
max_batch_size,
max_pages,
device = "cuda:0",
kv_layout: str = "NHD",
use_cuda_graph: bool = False,
) -> None:
self.device = device
self.max_batch_token = max_batch_token
self.kv_layout = kv_layout
self.use_cuda_graph = use_cuda_graph
if flashInferAttn.float_workspace_buffer is None:
flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device)
self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
self.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)
self.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)
self.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device)
# TODO: custom mask
self.custom_mask_buf = None
self.qk_indptr_buf = None
self.warpper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
flashInferAttn.float_workspace_buffer,
self.kv_layout,
use_cuda_graph=self.use_cuda_graph,
qo_indptr_buf=self.qo_indptr_buf,
paged_kv_indptr_buf=self.paged_kv_indptr_buf,
paged_kv_indices_buf=self.paged_kv_indices_buf,
paged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf,
backend = "fa2",
)
def plan(self,
qo_indptr: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
batch_size_tensor: torch.Tensor,
num_tokens_tensor: torch.Tensor,
num_qo_heads: int,
num_kv_heads: int,
head_dim: int,
page_size: int,
causal: bool = True,
pos_encoding_mode: str = "NONE",
q_data_type: Union[str, torch.dtype] = torch.bfloat16,
kv_data_type: Optional[Union[str, torch.dtype]] = None):
self.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True)
self.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True)
self.page_size = page_size
self.warpper.plan(
qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal = causal,
pos_encoding_mode = pos_encoding_mode,
q_data_type = q_data_type,
kv_data_type = kv_data_type
)
def calc_batch_indices(self, ragged_size = None):
if self.use_cuda_graph:
self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
self.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token)
else:
self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
self.warpper._qo_indptr_buf, flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)
def forward(self, q, k_cache, v_cache, k, v):
if self.use_cuda_graph:
flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
return self.warpper.run(q, (k_cache, v_cache))
else:
flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
return self.warpper.run(q, (k_cache, v_cache))
def testCudaGraph():
# use max batch to create buffer
batch_decode = 8
prefill_chunk = 48
past_kv_0 = 4090
past_kv_1 = 4096
raged_size = prefill_chunk + batch_decode
num_key_value_heads = 8
head_dim = 128
num_attention_heads = 64
page_size = 256
num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
attn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True)
batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
k_caches = []
v_caches = []
ks = []
vs = []
qs = []
for layer_idx in range(3):
k_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
v_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
ks.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
vs.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))
# warmup and capture small batch
past_kv_0 = 250
past_kv_1 = 256
num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
q_indptr[0] = 0
q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
print(q_indptr)
print(kv_indptr)
print(kv_indices)
print(kv_last_page_len)
attn.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
batch_size_tensor,
num_attention_heads,
num_key_value_heads,
head_dim,
page_size,
causal = True,
pos_encoding_mode="NONE",
q_data_type=torch.bfloat16)
attn.calc_batch_indices(raged_size)
for layer_idx in range(3):
attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])
torch.cuda.synchronize()
outs = []
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for layer_idx in range(3):
outs.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]))
g.replay()
kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
for layer_idx in range(3):
for i in range(batch_decode + 1):
qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
o_ref_i = flash_attn_with_kvcache(
qi.unsqueeze(0),
k_caches[layer_idx],
v_caches[layer_idx],
causal=True,
block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
)
o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
print(layer_idx, i)
torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
# run another batch size use capture cuda graph
past_kv_0 = 4090
past_kv_1 = 4096
prefill_chunk = 24
batch_decode = 4
num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
num_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)
q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
q_indptr[0] = 0
q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
attn.plan(q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
batch_size_tensor,
num_attention_heads,
num_key_value_heads,
head_dim,
page_size,
causal = True,
pos_encoding_mode="NONE",
q_data_type=torch.bfloat16)
attn.calc_batch_indices(raged_size)
g.replay()
kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
for layer_idx in range(3):
for i in range(batch_decode + 1):
qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
o_ref_i = flash_attn_with_kvcache(
qi.unsqueeze(0),
k_caches[layer_idx],
v_caches[layer_idx],
causal=True,
block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
)
o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
print(layer_idx, i)
torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
def testAttentionFlashInfer(
):
batch_decode = 32
prefill_chunk = 64
past_kv_0 = 510
past_kv_1 = 512
raged_size = prefill_chunk + batch_decode
num_key_value_heads = 8
head_dim = 128
num_attention_heads = 64
cases = 1
page_size = 32
num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
qs = []
kvs = []
q_indptrs = []
kv_indptrs = []
kv_indicess = []
kv_last_page_lens = []
wrappers = []
for case_id in range(cases):
kvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))
q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
q_indptr[0] = 0
q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
q_indptrs.append(q_indptr)
kv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq)
kv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32))
kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
kv_last_page_lens.append(kv_last_page_len)
wrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
"NHD",
use_cuda_graph=True,
qo_indptr_buf=q_indptrs[case_id],
paged_kv_indptr_buf=kv_indptrs[case_id],
paged_kv_indices_buf=kv_indicess[case_id],
paged_kv_last_page_len_buf=kv_last_page_lens[case_id],
))
wrappers[case_id].plan(
q_indptrs[case_id],
kv_indptrs[case_id],
kv_indicess[case_id],
kv_last_page_lens[case_id],
num_attention_heads,
num_key_value_heads,
head_dim,
page_size,
causal = True,
pos_encoding_mode="ROPE_LLAMA",
q_data_type=torch.bfloat16
)
def custom_forward(case_id):
out = wrappers[case_id].run(qs[case_id], kvs[case_id])
custom_forward(0)
# testCudaGraph()
# pass

View file

@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.e_score_correction_bias = None
class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str| None = "KLinearMarlin",
use_quant: bool = False,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
self.generate_op = generate_op
self.prefill_op = prefill_op
self.is_windows = os.name == 'nt'
self.use_quant = use_quant
if not self.is_windows and use_quant:
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
gguf_loader, config, self.gate_linear, #orig_module
generate_device, generate_op, prefill_device, prefill_op)
else:
self.gate_linear = None
def forward(self, hidden_states) -> torch.Tensor:
if self.is_windows:
return self.orig_module.forward(hidden_states)
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
if self.use_quant:
logits = self.gate_linear.forward(logits)
else:
logits = F.linear(
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
)
return grouped_topk(hidden_states, logits,
self.top_k, self.norm_topk_prob,
self.n_group, self.topk_group)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
if not self.is_windows and self.use_quant:
self.gate_linear.load(self.orig_module.weight)
def unload(self):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None

View file

@ -26,6 +26,8 @@ from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
return self.weight * hidden_states.to(input_dtype)
class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(config.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x: torch.Tensor,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
#return self.forward_native(x, residual)
if batch_size_tensor is None:
return self.forward_native(x)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return x, residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
return out
def forward_native(
self, hidden_states
):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x: torch.Tensor,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
#return self.forward_native(x, residual)
bsz, hidden_size = x.shape
x = x.view(-1, self.orig_module.hidden_size)
if batch_size_tensor is None:
return self.forward_native(x)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return x, residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
out = out.view(bsz, hidden_size)
return out
def forward_native(
self, hidden_states
):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

View file

@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
def __init__(self,
key: str,
@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.hidden_size, orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj
class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj