Merge branch 'fix_precision_MLA' of https://github.com/kvcache-ai/ktransformers into server-prefix-cache

ceerrep 2025-02-17 18:08:04 +08:00
commit bb1cadfff3
11 changed files with 479 additions and 46 deletions

View file

@@ -30,6 +30,7 @@ from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
 from ktransformers.util.utils import prefill_and_generate
 from ktransformers.server.config.config import Config
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 
 custom_models = {
     "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM,
@@ -170,9 +171,16 @@ def local_chat(
         torch.set_default_dtype(
             torch.bfloat16
         )  # TODO: Remove this, replace dtype using config
-    generated = prefill_and_generate(
-        model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode, force_think
-    )
+    if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+            use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
+        )
+    else:
+        generated = prefill_and_generate(
+            model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
+        )
 
 if __name__ == "__main__":
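
For context, a minimal sketch (not part of the diff) of how the new keyword arguments line up with the MLA geometry that prefill_and_generate and the flashinfer wrapper expect; the concrete numbers are the DeepSeek-V3 defaults used in the wrapper's self-test below and are illustrative only:

    num_heads    = config.num_attention_heads                         # MLA query heads, e.g. 128
    head_dim_ckv = config.kv_lora_rank                                 # compressed-KV ("ckv") width, e.g. 512
    head_dim_kpe = config.qk_rope_head_dim                             # rotary ("pe") part of the key, e.g. 64
    q_head_dim   = config.qk_rope_head_dim + config.qk_nope_head_dim   # e.g. 64 + 128 = 192
    sm_scale     = q_head_dim ** (-0.5)                                # softmax scale later planned into the wrapper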

View file

@@ -138,8 +138,6 @@ class StaticCache(transformers.StaticCache):
         page_idx = cache_position // self.page_size
         page_offset = cache_position % self.page_size
         # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
-        #print("page_idx", page_idx)
-        #print("page_offset", page_offset)
         k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
         k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
         return k_out, self.page_table_list[layer_idx]

View file

@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
 #         **kwargs,
 #     ):
 #         BaseInjectedModule.__init__(
-#             self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
 #         )
 #         self.generate_device = generate_device
 #         self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,

View file

@@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_with_kvcache, flash_attn_func
+from flash_attn import flash_attn_func
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
 import os
+from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
+if flashinfer_enabled:
+    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
 
 logger = logging.getLogger("attention")
 
 # Copied from transformers.models.llama.modeling_llama.rotate_half
@@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  gguf_loader : GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module,
-                 device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  chunck_size: int = 1000,
-                 use_triton: bool = False,
                  **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.
-        self.use_triton = use_triton
+        self.mla_wrapper = None
 
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
@@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             #print(compressed_kv.shape)
             attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
             #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
             compressed_kv = compressed_kv.squeeze(1)
             """
@@ -168,6 +173,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             attn_weights = nn.functional.dropout(
                 attn_weights, p=self.attention_dropout, training=self.training
             )
             attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
             attn_output = torch.matmul(attn_output, out_absorb.mT)
@@ -186,7 +192,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
 
         return attn_output, None, past_key_value
 
-    def forward_linux(
+    def forward_linux_triton(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -232,7 +238,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
 
         # decode
-        if self.use_triton and q_len == 1:
+        if q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -277,7 +283,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # use triton attention kernel adapted from vLLM and SGLang for MQA
             decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                          page_table,
-                                         position_ids.squeeze(0).to(torch.int32), attn_logits,
+                                         position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
                                          4, #num_kv_splits # follow vLLM, fix it TODO
                                          self.softmax_scale,
                                          past_key_value.page_size)
@@ -338,6 +344,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        compressed_kv = self.kv_a_layernorm(compressed_kv)
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
+        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+        cos, sin = self.rotary_emb(q_pe, position_ids)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
+        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
+
+        # decode
+        if q_len == 1:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
+                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
+                # k_pe [max_pages, page_size, self.qk_rope_head_dim]
+                # compressed_kv [max_pages, page_size, self.kv_lora_rank]
+
+            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
+            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
+            q_absorb, out_absorb = self.get_absorbed()
+            q_nope = q_nope.transpose(1, 2)  # q_len is 1, no GPU overhead, same below
+            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
+            q_nope = q_nope.transpose(1, 2)
+            assert q_nope.is_contiguous()
+
+            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
+            q_nope.squeeze_(1)
+            q_pe.squeeze_(1)
+
+            # flash attn doesn't support head_dim bigger than 256, use flashinfer
+            if self.mla_wrapper is None:
+                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
+            if self.mla_wrapper.need_plan:
+                self.mla_wrapper.need_plan = False
+                self.mla_wrapper.plan(None, None, None,
+                                      position_ids.squeeze(1)+1,
+                                      self.num_heads,
+                                      self.kv_lora_rank,
+                                      self.qk_rope_head_dim,
+                                      past_key_value.page_size,
+                                      self.softmax_scale,
+                                      q_nope.dtype,
+                                      compressed_kv.dtype)
+            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
+
+            """
+            k = (
+                torch.cat([compressed_kv, k_pe], dim=-1)
+                .view(-1, 1, 512 + 64)
+                .repeat_interleave(self.num_heads, dim=1)
+            )
+            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
+            lens = position_ids.item() + 1
+            #print("lens", lens)
+            attn_ref, lse_ref = attention_ref(
+                1,
+                torch.cat([q_nope, q_pe], dim=-1),
+                k[:lens],
+                v[:lens],
+                False,
+                self.softmax_scale
+            )
+            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
+            """
+
+            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
+            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, past_key_value
+        else:
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+                k_pe.squeeze(0)
+                compressed_kv.squeeze(0)
+                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                k_pe.unsqueeze(0)
+                compressed_kv.unsqueeze(0)
+            k_pe = k_pe[:, :q_len]
+            compressed_kv = compressed_kv[:, :q_len]
+            kv = (
+                self.kv_b_proj(compressed_kv)
+                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            )
+            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
+
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states_padded,
+                softmax_scale=self.softmax_scale,
+                causal=True,
+            )
+
+            if self.q_head_dim != self.v_head_dim:
+                attn_output = attn_output[:, :, :, : self.v_head_dim]
+            attn_output = attn_output.reshape(
+                bsz, q_len, self.num_heads * self.v_head_dim
+            ).contiguous()
+            attn_output = self.o_proj(attn_output)
+            return attn_output, None, past_key_value
 
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
@@ -415,7 +569,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
+        if os.name == 'nt':
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
@@ -427,16 +581,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 **kwargs,
             )
         else:
-            return self.forward_linux(
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position,
-                **kwargs,
-            )
+            if flashinfer_enabled:
+                return self.forward_linux_flashinfer(
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    **kwargs,
+                )
+            else:
+                return self.forward_linux_triton(
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                    **kwargs,
+                )
 
 class KLlamaAttention(BaseInjectedModule):
@@ -447,9 +613,10 @@ class KLlamaAttention(BaseInjectedModule):
                  gguf_loader : GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module,
-                 device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
 
     def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):

View file

@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
                  gguf_loader : GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module,
-                 device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  **kwargs):
         nn.Module.__init__(self)
         nn.Module.__setattr__(self, "orig_module", orig_module)
         object.__setattr__(self, "key", key)
         object.__setattr__(self, "gguf_loader", gguf_loader)
         object.__setattr__(self, "config", config)
-        object.__setattr__(self, "device", device)
+        object.__setattr__(self, "prefill_device", prefill_device)
+        object.__setattr__(self, "generate_device", generate_device)
+        object.__setattr__(self, "device", generate_device)
 
     def __getattr__(self, name: str) -> Any:
         # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,

View file

@@ -119,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
     output_cpu:Tensor = None
     output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
     #stream_map:dict = {} # Manage cuda stream on different gpu
+    #gguf_loader:GGUFLoader = None
     CPU_INFER = CPUInfer(Config().cpu_infer)
     def __init__(
         self,
@@ -132,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
         **kwargs
     ):
         super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
+        #if KExpertsCPU.gguf_loader is None:
+        #    KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
+        self.gguf_loader = gguf_loader
         assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
         self.n_routed_experts = n_routed_experts
         self.out_device = out_device
@@ -532,7 +536,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
                  generate_device: str = "cpu",
                  generate_op: str | None = "KExpertsCPU",
                  **kwargs):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         if generate_op is not None:
             self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)

View file

@@ -0,0 +1,240 @@ (new file)
'''
Description : flashinfer MLA wrapper
Author : Boxin Zhang
Version : 0.2.2
'''
import torch

flashinfer_enabled = False

try:
    import flashinfer
    flashinfer_enabled = True
    print("found flashinfer")
except ImportError:
    print("flashinfer not found, use triton for linux")

import math

def attention_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> torch.Tensor:
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]
    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )
    #print("attn weights", logits)
    if causal:
        mask = (
            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)
            >= torch.arange(0, kv_len).unsqueeze(0)
        ).to(q.device)
    else:
        mask = torch.ones(qo_len, kv_len).to(q.device)
    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )
    return o_ref, lse_ref * math.log2(math.e)

class MLAWrapper():
    def __init__(self,
                 max_batch_size,
                 max_pages,
                 use_cuda_graph = True,
                 device = "cuda",
                 ):
        self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)
        self.max_batch_size = max_batch_size
        self.max_pages = max_pages
        if use_cuda_graph:
            if self.max_batch_size == 1:
                self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            else:
                self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
        else:
            self.qo_indptr_buf = None
            self.kv_indptr_buf = None
            self.kv_indices_buf = None
            self.kv_len_arr_buf = None
        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
            self.float_workspace_buffer,
            use_cuda_graph=False,
            qo_indptr=self.qo_indptr_buf,
            kv_indptr=self.kv_indptr_buf,
            kv_indices=self.kv_indices_buf,
            kv_len_arr=self.kv_len_arr_buf,
        )
        self.need_plan = True

    def plan(self,
             qo_indptr,
             kv_indptr,
             kv_indices,
             kv_len_arr,
             num_heads,
             head_dim_ckv,
             head_dim_kpe,
             page_size,
             sm_scale,
             q_data_type,
             kv_data_type,
             ):
        if qo_indptr is None:
            assert self.max_batch_size == 1
            qo_indptr = self.qo_indptr_buf
        if kv_indptr is None:
            assert self.max_batch_size == 1
            kv_indptr = self.kv_indptr_buf
        if kv_indices is None:
            assert self.max_batch_size == 1
            kv_indices = self.kv_indices_buf

        self.wrapper.plan(
            qo_indptr,
            kv_indptr,
            kv_indices,
            kv_len_arr,
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            False, # causal is False for decoding
            sm_scale,
            q_data_type,
            kv_data_type,
        )

    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)

class MLAWrapperSingleton():
    wrappers:dict = {}

    @classmethod
    def get_instance(cls, device, *args, **kwargs)->MLAWrapper:
        if device not in cls.wrappers:
            cls.make_instance(device, *args, **kwargs)
        return cls.wrappers[device]

    @classmethod
    def make_instance(cls, device, *args, **kwargs):
        cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)

    @classmethod
    def plan_all(cls, qo_indptr,
                 kv_indptr,
                 kv_indices,
                 kv_len_arr,
                 num_heads,
                 head_dim_ckv,
                 head_dim_kpe,
                 page_size,
                 sm_scale,
                 q_data_type,
                 kv_data_type,):
        for device, wrapper in cls.wrappers.items():
            kv_len_arr_cur_device = kv_len_arr.to(device)
            wrapper.plan(qo_indptr,
                         kv_indptr,
                         kv_indices,
                         kv_len_arr_cur_device,
                         num_heads,
                         head_dim_ckv,
                         head_dim_kpe,
                         page_size,
                         sm_scale,
                         q_data_type,
                         kv_data_type,)

if __name__ == "__main__":
    max_batch_size = 1
    max_pages = 1
    page_size = 64
    num_heads = 128

    q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
    ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
    k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")

    wrapper = MLAWrapperSingleton.get_instance(
        "cuda",
        max_batch_size,
        max_pages,
    )

    kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
    wrapper.plan(
        None,
        None,
        None,
        kv_len_arr,
        128,
        512,
        64,
        page_size,
        192 ** (-0.5),
        torch.bfloat16,
        torch.bfloat16,
    )

    attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)

    k = (
        torch.cat([ckv, k_pe], dim=-1)
        .view(-1, 1, 512 + 64)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)

    print(k[:10].shape)
    print(v[:10].shape)

    attn_ref, lse_ref = attention_ref(
        max_batch_size,
        torch.cat([q_nope, q_pe], dim=-1),
        k[:10],
        v[:10],
        False,
        192 ** (-0.5)
    )
    torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
    print("test past")

View file

@@ -93,11 +93,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
                  gguf_loader: GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module = None,
-                 generate_device: str = "cuda",
                  prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  **kwargs,
                  ):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         self.generate_device = generate_device
         self.prefill_device = prefill_device

View file

@@ -383,7 +383,7 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
                  prefill_op: str| None = "KLinearTorch",
                  **kwargs,
                  ):
-        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         # build all the linear operators
         if prefill_op is not None:

View file

@@ -109,6 +109,7 @@ GGML_TYPES = {
     "Q5_K": 13,
     "Q6_K": 14,
     "IQ4_XS": 23,
+    "BF16": 30,
 }
 
 GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
@@ -116,6 +117,7 @@ GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
 GGML_BLOCK_SIZES = {
     "F32": 4,
     "F16": 2,
+    "BF16": 2,
     "Q4_0": 2 + 16,
     "Q5_0": 2 + 4 + 16,
     "Q8_0": 2 + 32,
@@ -130,6 +132,7 @@ GGML_BLOCK_SIZES = {
 GGML_ELEMENTS_PER_BLOCK = {
     "F32": 1,
     "F16": 1,
+    "BF16": 1,
     "Q4_0": 32,
     "Q5_0": 32,
     "Q8_0": 32,
@@ -333,6 +336,8 @@ class GGUFLoader:
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
+            if ggml_name == "BF16":
+                values = values.view(torch.bfloat16)
         values = values.view(shape[::-1])
         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
@@ -764,6 +769,7 @@ def dequantize_f16_gpu(data, device):
 GGML_DEQUANTIZE = {
     "F32": dequantize_f32,
     "F16": dequantize_f16,
+    "BF16": dequantize_f16,
     "Q4_0": dequantize_q4_0,
     "Q5_0": dequantize_q5_0,
     "Q8_0": dequantize_q8_0,
@@ -778,6 +784,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
+    "BF16": dequantize_f16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,

View file

@@ -17,6 +17,7 @@ from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.util.textstream import TextStreamer
+from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
 
 warm_uped = False
@@ -87,7 +88,8 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()
 
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False):
+                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     torch._dynamo.config.suppress_errors = True
@@ -137,7 +139,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         )
     else:
         past_key_values = None
-    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.long)
+    cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
     generated_ids = torch.zeros(
         batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
     )
@@ -182,7 +184,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     generated_ids[:, seq_length] = next_token
     tokens.append(int(next_token))
     inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
-    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.long)
+    cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)
     position_ids = cache_position.unsqueeze(0)
     seq_length += 1
@@ -195,7 +197,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
+            if i > 1 and use_flashinfer_mla:
+                MLAWrapperSingleton.plan_all(None, None, None, position_ids.squeeze(1)+1,
+                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
+                                             q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
             next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()