mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Merge branch 'fix_precision_MLA' of https://github.com/kvcache-ai/ktransformers into server-prefix-cache
This commit is contained in:
commit
bb1cadfff3
11 changed files with 479 additions and 46 deletions
|
@ -19,9 +19,13 @@ from ktransformers.util.custom_gguf import GGUFLoader
|
|||
import logging
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.cache_utils import Cache
|
||||
from flash_attn import flash_attn_with_kvcache, flash_attn_func
|
||||
from flash_attn import flash_attn_func
|
||||
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
|
||||
import os
|
||||
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
|
||||
if flashinfer_enabled:
|
||||
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref
|
||||
|
||||
logger = logging.getLogger("attention")
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
|
@ -41,15 +45,15 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
use_triton: bool = False,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
self.use_triton = use_triton
|
||||
self.mla_wrapper = None
|
||||
|
||||
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
|
||||
|
@ -141,6 +145,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
#print(compressed_kv.shape)
|
||||
|
||||
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
|
||||
|
||||
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
|
||||
compressed_kv = compressed_kv.squeeze(1)
|
||||
"""
|
||||
|
@ -168,8 +173,9 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.attention_dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
|
||||
|
||||
|
||||
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
|
||||
|
@ -179,14 +185,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_linux(
|
||||
def forward_linux_triton(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
|
@ -232,7 +238,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||
|
||||
# decode
|
||||
if self.use_triton and q_len == 1:
|
||||
if q_len == 1:
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
|
@ -277,7 +283,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
# use triton attention kernel adapted from vLLM and SGLang for MQA
|
||||
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
|
||||
page_table,
|
||||
position_ids.squeeze(0).to(torch.int32), attn_logits,
|
||||
position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
|
||||
4, #num_kv_splits # follow vLLM, fix it TODO
|
||||
self.softmax_scale,
|
||||
past_key_value.page_size)
|
||||
|
@ -337,6 +343,154 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_linux_flashinfer(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(
|
||||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
|
||||
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
|
||||
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
|
||||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
|
||||
|
||||
# decode
|
||||
if q_len == 1:
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
|
||||
k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
|
||||
# k_pe [max_pages, page_size, self.qk_rope_head_dim]
|
||||
# compressed_kv [max_pages, page_size, self.kv_lora_rank]
|
||||
|
||||
# q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
|
||||
# q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
|
||||
q_absorb, out_absorb = self.get_absorbed()
|
||||
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
|
||||
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
|
||||
q_nope = q_nope.transpose(1, 2)
|
||||
assert q_nope.is_contiguous()
|
||||
|
||||
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
|
||||
q_nope.squeeze_(1)
|
||||
q_pe.squeeze_(1)
|
||||
|
||||
# flash attn doesn't support head_dim bigger than 256, use flashinfer
|
||||
if self.mla_wrapper is None:
|
||||
self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
|
||||
if self.mla_wrapper.need_plan:
|
||||
self.mla_wrapper.need_plan = False
|
||||
self.mla_wrapper.plan(None,None,None,
|
||||
position_ids.squeeze(1)+1,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
past_key_value.page_size,
|
||||
self.softmax_scale,
|
||||
q_nope.dtype,
|
||||
compressed_kv.dtype)
|
||||
attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
"""
|
||||
k = (
|
||||
torch.cat([compressed_kv, k_pe], dim=-1)
|
||||
.view(-1, 1, 512 + 64)
|
||||
.repeat_interleave(self.num_heads, dim=1)
|
||||
)
|
||||
v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
|
||||
lens = position_ids.item() + 1
|
||||
#print("lens", lens)
|
||||
attn_ref, lse_ref = attention_ref(
|
||||
1,
|
||||
torch.cat([q_nope, q_pe], dim=-1),
|
||||
k[:lens],
|
||||
v[:lens],
|
||||
False,
|
||||
self.softmax_scale
|
||||
)
|
||||
attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
|
||||
"""
|
||||
|
||||
# mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
|
||||
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
|
||||
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = torch.matmul(attn_output, out_absorb.mT)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
else:
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
k_pe.squeeze(0)
|
||||
compressed_kv.squeeze(0)
|
||||
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
|
||||
k_pe.unsqueeze(0)
|
||||
compressed_kv.unsqueeze(0)
|
||||
|
||||
k_pe = k_pe[:, :q_len]
|
||||
compressed_kv = compressed_kv[:, :q_len]
|
||||
kv = (
|
||||
self.kv_b_proj(compressed_kv)
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
)
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
|
||||
key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
|
||||
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
|
||||
|
||||
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
|
||||
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
|
||||
|
||||
attn_output = flash_attn_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states_padded,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if self.q_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
bsz, q_len, self.num_heads * self.v_head_dim
|
||||
).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
def forward_windows(
|
||||
self,
|
||||
|
@ -415,7 +569,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
|
||||
if os.name == 'nt':
|
||||
return self.forward_windows(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -427,16 +581,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.forward_linux(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
if flashinfer_enabled:
|
||||
return self.forward_linux_flashinfer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self.forward_linux_triton(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class KLlamaAttention(BaseInjectedModule):
|
||||
|
@ -447,9 +613,10 @@ class KLlamaAttention(BaseInjectedModule):
|
|||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, device, **kwargs)
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue