mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00

support qwen3, dont speak human language

This commit is contained in:
parent f3d842a0ca
commit 3f9bbf1181

30 changed files with 3696 additions and 290 deletions
@@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings


class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config,
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config
        )
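Note (not part of the commit): the inverse-frequency table in the context line above follows the standard RoPE schedule. Restated, with d the rotary dimension and b the base,

    \theta_i = b^{-2i/d}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1,

and the rotate_half convention used elsewhere in this commit applies it to a token at position m as

    q' = q \odot \cos(m\theta) + \operatorname{rotate\_half}(q) \odot \sin(m\theta)

(and identically for k).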
@@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):
        attn_weights = None

        return attn_output, attn_weights, past_key_value


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            #del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output
ktransformers/operators/balance_serve_attention.py (new file, 287 lines)
@@ -0,0 +1,287 @@
'''
Description  :
Author       : Boxin Zhang
Version      : 0.2.5
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from flashinfer import BatchMLAPagedAttentionWrapper
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
logger = logging.getLogger("attention")


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            #del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output


class KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                position_ids: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, bsz_tensors)
        key_states = self.k_proj(hidden_states, bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)

        return attn_output


class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                position_ids: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)

        return attn_output
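A short aside on get_absorbed above (a restatement of standard MLA weight-absorption algebra, not text from the commit): kv_b_proj holds the per-head up-projections W^{UK}_h and W^{UV}_h that expand the compressed KV vector c^{KV}_t. Folding them into the query and output sides lets attention run directly over the compressed cache:

    \text{score}_{h,t} \propto \big(q^{\text{nope}}_h W^{UK}_h\big)\, c^{KV}_t \;+\; q^{\text{pe}}_h \cdot k^{\text{pe}}_t,
    \qquad
    o_h = \Big(\sum_t a_{h,t}\, c^{KV}_t\Big) \big(W^{UV}_h\big)^{\top},

which is what the q_absorb matmul before wrapper.run and the out_absorb matmul after it compute.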
@@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")


class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        router_logits = self.gate(hidden_states, bsz_tensor)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():  # TODO: this branch cause jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            y += y_
            y.resize_(*orig_shape)
            return y

        y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        y_ = (
            F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO may bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = torch.empty_like(x)
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out


class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):

        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        router_logits = self.gate(hidden_states, bsz_tensor)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # only for generate phase
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():  # TODO: this branch cause jit bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            # y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO may bugs here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO may bugs here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = torch.empty_like(x)
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO may bugs here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
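For readability, the routing in both sparse-MoE blocks above can be summarized as follows (a restatement of the code, not an addition to it):

    p = \operatorname{softmax}(g(x)), \qquad \{(w_i, e_i)\}_{i=1}^{k} = \operatorname{topk}_k(p), \qquad w_i \leftarrow \frac{w_i}{\sum_j w_j} \ \text{if norm\_topk\_prob},

    y = \sum_{i=1}^{k} w_i\, E_{e_i}(x) \;\;\big[\; +\; \sigma\!\big(g_s(x)\big)\, E_{\text{shared}}(x) \ \text{in the Qwen2-MoE block only}\;\big].

The Qwen3-MoE variant is otherwise identical; the shared-expert term is commented out there, which matches Qwen3-MoE not having a shared expert.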
ktransformers/operators/flashinfer_batch_prefill_wrapper.py (new file, 324 lines)
@@ -0,0 +1,324 @@
import torch
import flashinfer
import gc
try:
    from flash_attn import flash_attn_with_kvcache
    print("found flash_attn")

except ImportError:
    print("flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.")

from typing import Union, Optional

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

setup_seed(998244353)

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
global_dtype = torch.bfloat16
global_device = torch.device("cuda", 0)
torch.cuda.set_device(0)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

class flashInferAttn():

    float_workspace_buffer = None
    def __init__(self,
                 max_batch_token,
                 max_batch_size,
                 max_pages,
                 device = "cuda:0",
                 kv_layout: str = "NHD",
                 use_cuda_graph: bool = False,
                 ) -> None:
        self.device = device
        self.max_batch_token = max_batch_token
        self.kv_layout = kv_layout
        self.use_cuda_graph = use_cuda_graph
        if flashInferAttn.float_workspace_buffer is None:
            flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device)
        self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
        self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
        self.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)
        self.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)
        self.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device)

        # TODO: custom mask
        self.custom_mask_buf = None
        self.qk_indptr_buf = None
        self.warpper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            flashInferAttn.float_workspace_buffer,
            self.kv_layout,
            use_cuda_graph=self.use_cuda_graph,
            qo_indptr_buf=self.qo_indptr_buf,
            paged_kv_indptr_buf=self.paged_kv_indptr_buf,
            paged_kv_indices_buf=self.paged_kv_indices_buf,
            paged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf,
            backend = "fa2",
        )

    def plan(self,
             qo_indptr: torch.Tensor,
             paged_kv_indptr: torch.Tensor,
             paged_kv_indices: torch.Tensor,
             paged_kv_last_page_len: torch.Tensor,
             batch_size_tensor: torch.Tensor,
             num_tokens_tensor: torch.Tensor,
             num_qo_heads: int,
             num_kv_heads: int,
             head_dim: int,
             page_size: int,
             causal: bool = True,
             pos_encoding_mode: str = "NONE",
             q_data_type: Union[str, torch.dtype] = torch.bfloat16,
             kv_data_type: Optional[Union[str, torch.dtype]] = None):

        self.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True)
        self.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True)
        self.page_size = page_size
        self.warpper.plan(
            qo_indptr,
            paged_kv_indptr,
            paged_kv_indices,
            paged_kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal = causal,
            pos_encoding_mode = pos_encoding_mode,
            q_data_type = q_data_type,
            kv_data_type = kv_data_type
        )

    def calc_batch_indices(self, ragged_size = None):
        if self.use_cuda_graph:
            self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
                self.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token)
        else:
            self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
                self.warpper._qo_indptr_buf, flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)

    def forward(self, q, k_cache, v_cache, k, v):
        if self.use_cuda_graph:
            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
            return self.warpper.run(q, (k_cache, v_cache))
        else:
            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
            return self.warpper.run(q, (k_cache, v_cache))


def testCudaGraph():

    # use max batch to create buffer
    batch_decode = 8
    prefill_chunk = 48
    past_kv_0 = 4090
    past_kv_1 = 4096
    raged_size = prefill_chunk + batch_decode
    num_key_value_heads = 8
    head_dim = 128
    num_attention_heads = 64
    page_size = 256
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    attn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True)

    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)

    k_caches = []
    v_caches = []
    ks = []
    vs = []
    qs = []
    for layer_idx in range(3):
        k_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        v_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        ks.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        vs.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))

    # warmup and capture small batch
    past_kv_0 = 250
    past_kv_1 = 256
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
    q_indptr[0] = 0
    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
    kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
    kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)

    print(q_indptr)
    print(kv_indptr)
    print(kv_indices)
    print(kv_last_page_len)
    attn.plan(q_indptr,
              kv_indptr,
              kv_indices,
              kv_last_page_len,
              batch_size_tensor,
              num_attention_heads,
              num_key_value_heads,
              head_dim,
              page_size,
              causal = True,
              pos_encoding_mode="NONE",
              q_data_type=torch.bfloat16)

    attn.calc_batch_indices(raged_size)
    for layer_idx in range(3):
        attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])
    torch.cuda.synchronize()

    outs = []
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for layer_idx in range(3):
            outs.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]))
    g.replay()

    kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
    kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
    for layer_idx in range(3):
        for i in range(batch_decode + 1):

            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            o_ref_i = flash_attn_with_kvcache(
                qi.unsqueeze(0),
                k_caches[layer_idx],
                v_caches[layer_idx],
                causal=True,
                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
                cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
            )
            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            print(layer_idx, i)
            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)

    # run another batch size use capture cuda graph
    past_kv_0 = 4090
    past_kv_1 = 4096
    prefill_chunk = 24
    batch_decode = 4
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
    num_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)

    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
    q_indptr[0] = 0
    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
    kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
    kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
    attn.plan(q_indptr,
              kv_indptr,
              kv_indices,
              kv_last_page_len,
              batch_size_tensor,
              num_attention_heads,
              num_key_value_heads,
              head_dim,
              page_size,
              causal = True,
              pos_encoding_mode="NONE",
              q_data_type=torch.bfloat16)
    attn.calc_batch_indices(raged_size)
    g.replay()

    kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
    kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
    for layer_idx in range(3):
        for i in range(batch_decode + 1):

            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            o_ref_i = flash_attn_with_kvcache(
                qi.unsqueeze(0),
                k_caches[layer_idx],
                v_caches[layer_idx],
                causal=True,
                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
                cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
            )
            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            print(layer_idx, i)
            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)


def testAttentionFlashInfer(
):
    batch_decode = 32
    prefill_chunk = 64
    past_kv_0 = 510
    past_kv_1 = 512
    raged_size = prefill_chunk + batch_decode
    num_key_value_heads = 8
    head_dim = 128
    num_attention_heads = 64
    cases = 1
    page_size = 32
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    qs = []
    kvs = []
    q_indptrs = []
    kv_indptrs = []
    kv_indicess = []
    kv_last_page_lens = []
    wrappers = []
    for case_id in range(cases):
        kvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
        q_indptr[0] = 0
        q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
        q_indptrs.append(q_indptr)
        kv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq)
        kv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32))
        kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
        kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
        kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
        kv_last_page_lens.append(kv_last_page_len)
        wrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer,
            "NHD",
            use_cuda_graph=True,
            qo_indptr_buf=q_indptrs[case_id],
            paged_kv_indptr_buf=kv_indptrs[case_id],
            paged_kv_indices_buf=kv_indicess[case_id],
            paged_kv_last_page_len_buf=kv_last_page_lens[case_id],
        ))
        wrappers[case_id].plan(
            q_indptrs[case_id],
            kv_indptrs[case_id],
            kv_indicess[case_id],
            kv_last_page_lens[case_id],
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            page_size,
            causal = True,
            pos_encoding_mode="ROPE_LLAMA",
            q_data_type=torch.bfloat16
        )

    def custom_forward(case_id):
        out = wrappers[case_id].run(qs[case_id], kvs[case_id])

    custom_forward(0)

# testCudaGraph()
# pass
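A minimal decode-only usage sketch for flashInferAttn follows. This is an illustration, not part of the commit: it fills in plausible shapes and follows the plan() signature as declared above (note the built-in tests call plan() positionally, so verify the argument order against your flashinfer version before reusing this).

    # Hypothetical sketch: single attention layer, one decode token per request.
    import torch
    from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn

    device = torch.device("cuda", 0)
    batch, page_size, heads, kv_heads, head_dim = 4, 256, 64, 8, 128
    pages_per_seq = 2

    attn = flashInferAttn(max_batch_token=batch, max_batch_size=batch,
                          max_pages=batch * pages_per_seq)

    # Ragged layout: one query token per request, a short paged KV history each.
    qo_indptr = torch.arange(0, batch + 1, dtype=torch.int32, device=device)
    kv_indptr = torch.arange(0, batch + 1, dtype=torch.int32, device=device) * pages_per_seq
    kv_indices = torch.arange(0, batch * pages_per_seq, dtype=torch.int32, device=device)
    kv_last_page_len = torch.full((batch,), 1, dtype=torch.int32, device=device)
    batch_size_tensor = torch.tensor([batch], dtype=torch.int32, device=device)
    num_tokens_tensor = torch.tensor([batch], dtype=torch.int32, device=device)

    attn.plan(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
              batch_size_tensor, num_tokens_tensor,
              heads, kv_heads, head_dim, page_size)
    attn.calc_batch_indices(batch)

    k_cache = torch.randn(batch * pages_per_seq, page_size, kv_heads, head_dim,
                          device=device, dtype=torch.bfloat16)
    v_cache = torch.randn_like(k_cache)
    q = torch.randn(batch, heads, head_dim, device=device, dtype=torch.bfloat16)
    k = torch.randn(batch, kv_heads, head_dim, device=device, dtype=torch.bfloat16)
    v = torch.randn_like(k)

    # Appends k/v to the paged cache, then runs paged prefill/decode attention.
    out = attn.forward(q, k_cache, v_cache, k, v)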
@@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.e_score_correction_bias = None


class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearMarlin",
        use_quant: bool = False,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.generate_op = generate_op
        self.prefill_op = prefill_op
        self.is_windows = os.name == 'nt'
        self.use_quant = use_quant
        if not self.is_windows and use_quant:
            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                   gguf_loader, config, self.gate_linear,  # orig_module
                                                   generate_device, generate_op, prefill_device, prefill_op)
        else:
            self.gate_linear = None

    def forward(self, hidden_states) -> torch.Tensor:
        if self.is_windows:
            return self.orig_module.forward(hidden_states)

        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        if self.use_quant:
            logits = self.gate_linear.forward(hidden_states)  # quantized gate projection of the flattened hidden states
        else:
            logits = F.linear(
                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
            )

        return grouped_topk(hidden_states, logits,
                            self.top_k, self.norm_topk_prob,
                            self.n_group, self.topk_group)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
        if not self.is_windows and self.use_quant:
            self.gate_linear.load(self.orig_module.weight)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
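One note on the gating path above: grouped_topk is imported from the surrounding ktransformers code rather than defined in this hunk. As commonly implemented (DeepSeek-style group-limited gating; stated here as an assumption about that helper, not taken from this diff), the experts are partitioned into n_group groups, only the topk_group best-scoring groups survive, and the final top_k experts are chosen within them:

    s_g = \max_{e \in g} p_e, \qquad G^* = \operatorname{topk}_{\text{topk\_group}}(s), \qquad \{(w_i, e_i)\} = \operatorname{topk}_{\text{top\_k}}\{\, p_e : e \in G^* \,\};

with n_group = topk_group = 1 this reduces to ordinary top-k routing.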
@@ -26,6 +26,8 @@ from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
@@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(config.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # return self.forward_native(x, residual)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            # residual = x + residual
            # out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # return self.forward_native(x, residual)
        bsz, hidden_size = x.shape
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            # residual = x + residual
            # out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
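Both norm classes above share the same native computation; restated for readability (with d the hidden size and w the learned scale):

    \operatorname{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^2 + \varepsilon}},

and the fused_add_rmsnorm path applies the same normalization to x + residual in place, as the commented-out reference lines indicate.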
@@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP

from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
@@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.hidden_size, orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
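Both forward methods above implement the same gated feed-forward; in formula form,

    \operatorname{MLP}(x) = W_{\text{down}}\big(\sigma(W_{\text{gate}}\,x) \odot W_{\text{up}}\,x\big),

where sigma is the configured activation (typically SiLU for these model families) and bsz_tensor only forwards the active batch size to the injected linear kernels.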