diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 0849d55..8a11f1f 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -53,8 +53,9 @@ class StaticCache(transformers.StaticCache):
         # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
         self.page_size = 64
         self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size
-        key_shape = (self.max_pages, self.page_size, 1, config.qk_rope_head_dim)
-        value_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank)
+        latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_rope_head_dim = config.qk_rope_head_dim
         # TODO: support real page table
         self.page_table_map = dict()
         self.page_table_list = []
@@ -88,10 +89,17 @@ class StaticCache(transformers.StaticCache):
                 target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
             else:
                 target_device = device
-            new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
-            new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+
+            if self.is_MLA:
+                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = None
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+            else:
+                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
+                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
             self.past_tokens.append(0)
@@ -129,11 +137,12 @@ class StaticCache(transformers.StaticCache):
         if self.is_MLA:
             page_idx = cache_position // self.page_size
             page_offset = cache_position % self.page_size
+            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            #print("page_idx", page_idx)
            #print("page_offset", page_offset)
-            k_out[page_idx, page_offset, ...] = key_states
-            v_out[page_idx, page_offset, ...] = value_states
-            return k_out, v_out, self.page_table_list[layer_idx]
+            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
+            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
+            return k_out, self.page_table_list[layer_idx]
         else:
             k_out[:, :, cache_position] = key_states
             v_out[:, :, cache_position] = value_states
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index f487773..9f73d48 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -13,8 +13,6 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
 from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
-from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
@@ -23,8 +21,15 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
 from flash_attn import flash_attn_with_kvcache, flash_attn_func
 from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+import os
 logger = logging.getLogger("attention")

+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)

 # V3 MLA is same to V2
 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
@@ -80,6 +85,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
+        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, k_pe = torch.split(
@@ -103,16 +110,37 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):

         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-            compressed_kv = compressed_kv.unsqueeze(1)
-            k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
-            compressed_kv = compressed_kv.squeeze(1)
-            #if cache_position is not None:
-            #    compressed_kv = compressed_kv[:,: cache_position[-1] + 1,:]
-            #    k_pe = k_pe[:,:,: cache_position[-1] + 1,:]
+
+            # compressed_kv [bsz, q_len, self.kv_lora_rank]
+            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
+            k_pe = k_pe.transpose(1,2)
+            compressed_kv = compressed_kv.unsqueeze(2)
+            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+            compressed_kv, k_pe = torch.split(
+                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
+            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
+        if hasattr(self.orig_module, 'kv_b_proj'):
+            del self.orig_module.kv_b_proj
+
+        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
+        k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]
+        compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]
+        # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]
+        # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank]
         q_nope = torch.matmul(q_nope, q_absorb)
-        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
+        #print(q_pe.shape)
+        #print(k_pe.shape)
+        #print(q_nope.shape)
+        #print(compressed_kv.shape)
+
+        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
+        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
+        compressed_kv = compressed_kv.squeeze(1)
         """
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -156,25 +184,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):

         return attn_output, None, past_key_value

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
+    def forward_linux(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
         bsz, q_len, _ = hidden_states.size()

         if self.q_lora_rank is None:
             q = self.q_proj(hidden_states)
         else:
             q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
         q_nope, q_pe = torch.split(
             q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
@@ -184,38 +212,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )
         compressed_kv = self.kv_a_layernorm(compressed_kv)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

         cos, sin = self.rotary_emb(q_pe, position_ids)
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
-        k_pe = k_pe.transpose(1, 2) # [bsz, q_len, 1, self.qk_rope_head_dim]
-
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
+        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
+        # decode
         if q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-                k_pe, compressed_kv, page_table = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
-
-            # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
+                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed
+                # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]
+                # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
+
+            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
             # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
             q_absorb, out_absorb = self.get_absorbed()
+            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
-            # q_nope [bsz, self.num_heads, q_len, self.kv_lora_rank]
-            # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
-            query_states = torch.cat([q_nope, q_pe], dim=-1)
-            # k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
-            # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]
-            key_states = torch.cat([compressed_kv, k_pe], dim=-1)
+            q_nope = q_nope.transpose(1, 2)
+            assert q_nope.is_contiguous()

-            query_states = query_states.squeeze(2)
-            attn_output = torch.zeros_like(q_nope)
+            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
+            query_states = torch.cat([q_nope, q_pe], dim=-1)
+
+            query_states = query_states.squeeze(1)
+            attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]

             attn_logits = torch.empty(
                 (
                     bsz,
                     self.num_heads,
-                    1, #num_kv_splits # follow vLLM, fix it TODO
+                    4, #num_kv_splits # follow vLLM, fix it TODO
                     self.kv_lora_rank + 1,
                 ),
                 dtype=torch.float32,
@@ -224,22 +256,25 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):

             """
             print("query_states", torch.isnan(query_states).any())
-            print("key_states", torch.isnan(key_states[:,:,0,:]).any())
+            print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())
             print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
             print("position_ids", torch.isnan(position_ids).any())
             """
-            # flash attn doesn't support head_dim bigger than 256
+            # flash attn doesn't support head_dim bigger than 256
             # use vLLM triton attention kernel for MQA
-            decode_attention_fwd_grouped(query_states, key_states, compressed_kv, attn_output,
+            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                          page_table,
                                          position_ids.squeeze(0).to(torch.int32), attn_logits,
-                                         1, #num_kv_splits # follow vLLM, fix it TODO
+                                         4, #num_kv_splits # follow vLLM, fix it TODO
                                          self.softmax_scale,
                                          past_key_value.page_size)

-            attn_output = torch.matmul(attn_output, out_absorb.mT)
-            attn_output = attn_output.transpose(1, 2).contiguous()
+            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
+            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
+            attn_output = attn_output.transpose(1, 2)
+            attn_output = torch.matmul(attn_output, out_absorb.mT)
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

             attn_output = self.o_proj(attn_output)
@@ -250,7 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 k_pe.squeeze(0)
                 compressed_kv.squeeze(0)
-                past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
+                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                 k_pe.unsqueeze(0)
                 compressed_kv.unsqueeze(0)
@@ -261,7 +296,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

@@ -269,7 +304,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

-            query_states = query_states.transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
@@ -289,12 +323,106 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             ).contiguous()
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
+
+    def forward_windows(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+        bsz, q_len, _ = hidden_states.size()

-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
+        if q_len <= self.chunck_size:
+            return self.forward_chunck(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs
+            )
+
+        assert output_attentions == False, "output_attentions is not supported when using chunked attention"
+        attn_output = None
+        cur_idx = 0
+        while cur_idx < q_len:
+            if attention_mask is not None:
+                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
+            else:
+                # generate chunk_mask automatically.
+                self.attn_mask = \
+                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
+                    if self.attn_mask is None \
+                    else self.attn_mask
+                self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
+                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
+                    [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
+                self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
+                self.attn_mask[:, :, :, :cur_idx] = 0
+                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
+
+            cur_output, _, _ = self.forward_chunck(
+                hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
+                chunk_mask,
+                position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
+                **kwargs
+            )
+            cur_idx += self.chunck_size
+            if attn_output is None:
+                attn_output = cur_output
+            else:
+                attn_output = torch.cat((attn_output, cur_output), dim=-2)
+
+        return attn_output, None, past_key_value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if os.name == 'nt':
+            return self.forward_windows(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+        else:
+            return self.forward_linux(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )

 class KLlamaAttention(BaseInjectedModule):
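
Note for reviewers (not part of the patch): the `custom_cache.py` change merges the old rope-key and latent-value buffers into a single paged tensor of width `kv_lora_rank + qk_rope_head_dim`, written at `[page_idx, page_offset]`. A minimal sketch of that layout with toy sizes; `latent_cache` is a hypothetical stand-in for `StaticCache.key_cache[layer_idx]`:

```python
import torch

# Toy sizes; the real patch uses page_size=64 and the model's kv_lora_rank / qk_rope_head_dim.
page_size, max_pages = 4, 8
kv_lora_rank, qk_rope_head_dim = 16, 8

# Single paged buffer replacing the old key/value pair (stand-in for key_cache[layer_idx]).
latent_cache = torch.zeros(max_pages, page_size, 1, kv_lora_rank + qk_rope_head_dim)

cache_position = torch.tensor([5])               # global slot of the token being written
compressed_kv = torch.randn(1, 1, kv_lora_rank)  # latent KV for that token (shared by all heads)
k_pe = torch.randn(1, 1, qk_rope_head_dim)       # RoPE key part for that token

page_idx = cache_position // page_size           # which page the slot falls on
page_offset = cache_position % page_size         # offset inside that page
latent_cache[page_idx, page_offset, :, :kv_lora_rank] = compressed_kv
latent_cache[page_idx, page_offset, :, kv_lora_rank:] = k_pe

# Reading a slot back recovers both halves with one split, as forward_chunck does.
slot = latent_cache[page_idx, page_offset]
restored_kv, restored_pe = torch.split(slot, [kv_lora_rank, qk_rope_head_dim], dim=-1)
assert torch.equal(restored_kv, compressed_kv) and torch.equal(restored_pe, k_pe)
```

Because both halves live in one buffer, the MLA branch of `update()` only needs to return `k_out` plus the page table, which is what the new return statement does.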
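The decode path relies on the usual MLA absorption identity: scoring up-projected keys against `q_nope` is the same as scoring the latent `compressed_kv` against `q_nope @ q_absorb`. A rough numerical check of that identity under assumed toy shapes (`W_uk` is a hypothetical stand-in for the key half of `kv_b_proj`, i.e. `q_absorb`, not code from the repo):

```python
import torch

torch.manual_seed(0)
bsz, num_heads, q_len = 1, 4, 1
qk_nope_head_dim, kv_lora_rank, kv_len = 8, 16, 5

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(bsz, 1, kv_len, kv_lora_rank)      # latent cache, shared across heads
W_uk = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # per-head up-projection (q_absorb)

# Path A: decompress per-head keys, then take the dot product with the query.
k_nope = torch.einsum("hdr,bkr->bhkd", W_uk, compressed_kv.squeeze(1))
scores_decompressed = torch.matmul(q_nope, k_nope.mT)           # [bsz, heads, q_len, kv_len]

# Path B (what the diff does): fold W_uk into the query and score against the latent directly.
q_absorbed = torch.matmul(q_nope, W_uk)                         # [bsz, heads, q_len, kv_lora_rank]
scores_absorbed = torch.matmul(q_absorbed, compressed_kv.mT)    # broadcasts over heads

print(torch.allclose(scores_decompressed, scores_absorbed, atol=1e-5))  # True
```

The output side mirrors this: attention is accumulated in `kv_lora_rank` space and only mapped to `v_head_dim` via `out_absorb.mT` before `o_proj`, which is why `kv_b_proj` can be deleted once `get_absorbed()` has been called.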
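On Windows the new `forward()` dispatches to `forward_windows`, which prefills long prompts in `chunck_size` pieces through `forward_chunck`. A simplified sketch of the mask bookkeeping in that loop, with assumed toy sizes and without the actual attention call:

```python
import torch

chunk_size, q_len, max_cache_len = 4, 10, 16   # toy numbers; real values come from the module and cache
neg_inf = -1e38

attn_mask = torch.zeros(1, 1, chunk_size, max_cache_len)
cur_idx = 0
while cur_idx < q_len:
    n = min(chunk_size, q_len - cur_idx)       # tokens actually present in this chunk
    width = min(chunk_size, max_cache_len - cur_idx)
    # Chunk tokens attend causally to each other...
    attn_mask[:, :, :, cur_idx:cur_idx + width] = neg_inf * torch.triu(
        torch.ones(chunk_size, chunk_size), diagonal=1
    )[:, :width]
    # ...see nothing beyond the chunk, and everything already cached before it.
    attn_mask[:, :, :, cur_idx + chunk_size:] = neg_inf
    attn_mask[:, :, :, :cur_idx] = 0
    chunk_mask = attn_mask[:, :, :n, :]        # same effect as torch.narrow(attn_mask, 2, 0, n)
    # forward_chunck(hidden_states[:, cur_idx:cur_idx + n], chunk_mask, ...) would run here
    print(cur_idx, chunk_mask.shape)
    cur_idx += chunk_size
```

Concatenating the per-chunk outputs along the sequence dimension then reproduces the full prefill output, which is what `torch.cat((attn_output, cur_output), dim=-2)` at the end of the loop does.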