Linux: support Triton MLA kernel

Atream 2025-02-14 11:38:55 +00:00
parent bb35dc5b0d
commit 1084d4e4b4
2 changed files with 198 additions and 61 deletions


@@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache):
# TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
self.page_size = 64
self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
self.kv_lora_rank = config.kv_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
# TODO: support real page table
self.page_table_map = dict()
self.page_table_list = []
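For the absorbed-MLA case, the cache now keeps one latent buffer per layer instead of separate key and value tensors: each page slot stores the kv_lora_rank compressed-KV channels and the qk_rope_head_dim rotary-key channels side by side in the last dimension. A minimal sketch of the resulting layout, using the DeepSeek-V2/V3 defaults kv_lora_rank=512 and qk_rope_head_dim=64 as illustrative values:

import torch

# Illustrative values; the real numbers come from the model config.
kv_lora_rank, qk_rope_head_dim = 512, 64
max_cache_len, page_size = 4096, 64
max_pages = (max_cache_len + page_size - 1) // page_size  # ceil division, as above

# One latent buffer per layer holds both halves of the MLA cache.
latent_shape = (max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)
latent_cache = torch.zeros(latent_shape, dtype=torch.bfloat16)

# The two logical views are just slices of the last dimension.
compressed_kv = latent_cache[..., :kv_lora_rank]   # (max_pages, page_size, 1, 512)
k_pe          = latent_cache[..., kv_lora_rank:]   # (max_pages, page_size, 1, 64)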
@@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache):
target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
else:
target_device = device
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
if self.is_MLA:
new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = None
torch._dynamo.mark_static_address(new_layer_key_cache)
else:
new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
self.past_tokens.append(0)
@@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache):
if self.is_MLA:
page_idx = cache_position // self.page_size
page_offset = cache_position % self.page_size
# key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
#print("page_idx", page_idx)
#print("page_offset", page_offset)
k_out[page_idx, page_offset, ...] = key_states
v_out[page_idx, page_offset, ...] = value_states
return k_out, v_out, self.page_table_list[layer_idx]
k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
return k_out, self.page_table_list[layer_idx]
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
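The MLA branch of update() above scatters each new token's latent vector into the paged buffer by splitting its absolute cache position into a page index and an in-page offset, writing the compressed KV and the rotary key into the two halves of the last dimension. A standalone sketch of that indexing with small, hypothetical sizes:

import torch

page_size, kv_lora_rank, qk_rope_head_dim = 64, 512, 64
latent_cache = torch.zeros(16, page_size, 1, kv_lora_rank + qk_rope_head_dim)

def mla_cache_update(cache_position, compressed_kv, k_pe):
    # cache_position: (q_len,) absolute token positions
    page_idx = cache_position // page_size    # which page each token lands in
    page_offset = cache_position % page_size  # slot inside that page
    latent_cache[page_idx, page_offset, :, :kv_lora_rank] = compressed_kv
    latent_cache[page_idx, page_offset, :, kv_lora_rank:] = k_pe
    return latent_cache

# Token at position 130 lands in page 2, offset 2.
mla_cache_update(torch.tensor([130]),
                 torch.randn(1, 1, kv_lora_rank),
                 torch.randn(1, 1, qk_rope_head_dim))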


@@ -13,8 +13,6 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
@@ -23,8 +21,15 @@ from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_with_kvcache, flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
import os
logger = logging.getLogger("attention")
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
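rotate_half is the standard rotary-embedding helper: the usual non-interleaved RoPE application multiplies by cos and adds the rotated half scaled by sin. For reference, a generic form of that application is sketched below; note that DeepSeek's own apply_rotary_pos_emb additionally reorders interleaved rope dimensions, so this is only the textbook variant, not a drop-in replacement:

def apply_rope_generic(x, cos, sin, unsqueeze_dim=1):
    # cos/sin: (bsz, seq_len, dim); unsqueeze_dim inserts the head axis so they broadcast
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (x * cos) + (rotate_half(x) * sin)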
# V3 MLA is the same as V2
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
@@ -80,6 +85,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(
@@ -103,16 +110,37 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
compressed_kv = compressed_kv.unsqueeze(1)
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1)
#if cache_position is not None:
# compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
# k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
# compressed_kv [bsz, q_len, self.kv_lora_rank]
# k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
k_pe = k_pe.transpose(1,2)
compressed_kv = compressed_kv.unsqueeze(2)
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv, k_pe = torch.split(
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
# k_pe [pages, page_size, 1, self.qk_rope_head_dim]
# compressed_kv [pages, page_size, 1, self.kv_lora_rank]
q_absorb, out_absorb = self.get_absorbed()
if hasattr(self.orig_module, 'kv_b_proj'):
del self.orig_module.kv_b_proj
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]
compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]
# k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]
# compressed_kv [bsz, 1, cache_len, self.kv_lora_rank]
q_nope = torch.matmul(q_nope, q_absorb)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
#print(q_pe.shape)
#print(k_pe.shape)
#print(q_nope.shape)
#print(compressed_kv.shape)
attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
#attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
compressed_kv = compressed_kv.squeeze(1)
"""
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
@@ -156,25 +184,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
return attn_output, None, past_key_value
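In forward_chunck the attention score is computed directly in the latent space: q_nope is folded through q_absorb so it can be dotted against the compressed KV, q_pe is dotted against the cached rotary keys, and the two partial scores are summed and scaled. A sketch of that score computation with hypothetical shapes (the scale here is a placeholder; the module uses its own softmax_scale):

import torch

bsz, heads, q_len, cache_len = 1, 128, 8, 96
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank = 128, 64, 512
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5  # placeholder value

q_nope = torch.randn(bsz, heads, q_len, qk_nope_head_dim)
q_pe   = torch.randn(bsz, heads, q_len, qk_rope_head_dim)
q_absorb = torch.randn(heads, qk_nope_head_dim, kv_lora_rank)   # K half of the absorbed kv_b_proj
compressed_kv = torch.randn(bsz, 1, cache_len, kv_lora_rank)
k_pe          = torch.randn(bsz, 1, cache_len, qk_rope_head_dim)

q_nope = torch.matmul(q_nope, q_absorb)                   # (bsz, heads, q_len, kv_lora_rank)
attn_weights = (torch.matmul(q_pe, k_pe.mT)               # rope queries against rope keys
                + torch.matmul(q_nope, compressed_kv.mT)  # latent queries against compressed KV
                ) * softmax_scale                         # (bsz, heads, q_len, cache_len)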
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
def forward_linux(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
@@ -184,38 +212,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim]
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# decode
if q_len == 1:
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
# q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank] # for speed
# compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]
# compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
# q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
# q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
# q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
# q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
query_states = torch.cat([q_nope, q_pe], dim=-1)
# k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
key_states = torch.cat([compressed_kv, k_pe], dim=-1)
q_nope = q_nope.transpose(1, 2)
assert q_nope.is_contiguous()
query_states = query_states.squeeze(2)
attn_output = torch.zeros_like(q_nope)
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
query_states = torch.cat([q_nope, q_pe], dim=-1)
query_states = query_states.squeeze(1)
attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
attn_logits = torch.empty(
(
bsz,
self.num_heads,
1, #num_kv_splits # follow vLLM, fix it TODO
4, #num_kv_splits # follow vLLM, fix it TODO
self.kv_lora_rank + 1,
),
dtype=torch.float32,
@@ -224,22 +256,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
"""
print("query_states", torch.isnan(query_states).any())
print("key_states", torch.isnan(key_states[:,:,0,:]).any())
print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())
print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
print("position_ids", torch.isnan(position_ids).any())
"""
# flash attn doesn't support head_dim bigger than 256
# use vLLM triton attention kernel for MQA
decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output,
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
page_table,
position_ids.squeeze(0).to(torch.int32), attn_logits,
1, #num_kv_splits # follow vLLM, fix it TODO
4, #num_kv_splits # follow vLLM, fix it TODO
self.softmax_scale,
past_key_value.page_size)
attn_output = torch.matmul(attn_output, out_absorb.mT)
attn_output = attn_output.transpose(1, 2).contiguous()
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2)
attn_output = torch.matmul(attn_output, out_absorb.mT)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
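For single-token decode the absorbed MLA turns into a grouped MQA problem: every head's query is the concatenation of q_nope·q_absorb (width kv_lora_rank) and q_pe (width qk_rope_head_dim), all heads share the one latent KV stream, and the Triton kernel attends over the paged cache before out_absorb maps the result back to v_head_dim. The reference sketch below shows what one decode step computes for a single sequence; it is a plain PyTorch equivalent under assumed sizes, not the kernel's actual API:

import torch

heads, kv_lora_rank, rope_dim, v_head_dim = 128, 512, 64, 128
cache_len = 300
softmax_scale = 0.1352  # placeholder; the module computes its own value

query = torch.randn(heads, kv_lora_rank + rope_dim)        # absorbed q_nope concatenated with q_pe, one token
latent = torch.randn(cache_len, kv_lora_rank + rope_dim)   # the sequence's latent cache, flattened out of its pages
out_absorb = torch.randn(heads, v_head_dim, kv_lora_rank)  # V half of the absorbed kv_b_proj

scores = (query @ latent.T) * softmax_scale                # (heads, cache_len): MQA, all heads share the same KV
probs = torch.softmax(scores, dim=-1)
ctx = probs @ latent[:, :kv_lora_rank]                     # (heads, kv_lora_rank): values are the compressed KV only
attn_out = torch.einsum('hk,hvk->hv', ctx, out_absorb)     # (heads, v_head_dim)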
@@ -250,7 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe.squeeze(0)
compressed_kv.squeeze(0)
past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
k_pe.unsqueeze(0)
compressed_kv.unsqueeze(0)
@@ -261,7 +296,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
@@ -269,7 +304,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
query_states = query_states.transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
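The prefill branch pads value_states from v_head_dim up to the query head_dim because flash-attn requires Q, K and V to share a head dimension; the zero padding is sliced off the output again afterwards. A small sketch of that pad-then-slice pattern, with plain scaled_dot_product_attention standing in for flash_attn_func:

import torch
import torch.nn.functional as F

bsz, heads, q_len = 1, 128, 16
q_head_dim, v_head_dim = 192, 128

q = torch.randn(bsz, heads, q_len, q_head_dim)
k = torch.randn(bsz, heads, q_len, q_head_dim)
v = torch.randn(bsz, heads, q_len, v_head_dim)

v_padded = F.pad(v, [0, q_head_dim - v_head_dim], value=0)        # zero-pad last dim: 128 -> 192
out = F.scaled_dot_product_attention(q, k, v_padded, is_causal=True)
out = out[..., :v_head_dim]                                       # drop the padded columns again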
@@ -289,12 +323,106 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward_windows(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)
bsz, q_len, _ = hidden_states.size()
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
if q_len <= self.chunck_size:
return self.forward_chunck(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs
)
assert output_attentions == False, "output_attentions is not supported when using chunked attention"
attn_output = None
cur_idx = 0
while cur_idx < q_len:
if attention_mask is not None:
chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
else:
# generate chunk_mask automatically.
self.attn_mask = \
torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
if self.attn_mask is None \
else self.attn_mask
self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
-1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
[:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
self.attn_mask[:, :, :, :cur_idx] = 0
chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
cur_output, _, _ = self.forward_chunck(
hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
chunk_mask,
position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
past_key_value,
output_attentions,
use_cache,
cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
**kwargs
)
cur_idx += self.chunck_size
if attn_output is None:
attn_output = cur_output
else:
attn_output = torch.cat((attn_output, cur_output), dim=-2)
return attn_output, None, past_key_value
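forward_windows splits a long prompt into fixed-size chunks and, when no mask is supplied, builds one per chunk: causal inside the chunk, fully open toward positions already written to the cache, and fully blocked for everything after the chunk. A minimal sketch of that mask layout with a hypothetical chunk size and cache length:

import torch

chunk_size, max_cache_len = 4, 12
cur_idx = 4  # the chunk being prefilled starts at this cache position

mask = torch.zeros(1, 1, chunk_size, max_cache_len)
# causal within the current chunk
mask[:, :, :, cur_idx:cur_idx + chunk_size] = -1e38 * torch.triu(
    torch.ones(chunk_size, chunk_size), diagonal=1)
# positions after the chunk are not visible yet
mask[:, :, :, cur_idx + chunk_size:] = -1e38
# everything already in the cache is fully visible
mask[:, :, :, :cur_idx] = 0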
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt':
return self.forward_windows(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs,
)
else:
return self.forward_linux(
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
**kwargs,
)
class KLlamaAttention(BaseInjectedModule):