diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 2971cc7..53dac8b 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -125,8 +125,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
 
         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj
 
         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -216,13 +216,23 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
 
         # decode
-        if q_len == 1:
+        if self.use_triton and q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -287,26 +297,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 k_pe.squeeze(0)
                 compressed_kv.squeeze(0)
-                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-                k_pe.unsqueeze(0)
-                compressed_kv.unsqueeze(0)
-
-            k_pe = k_pe[:, :q_len]
-            compressed_kv = compressed_kv[:, :q_len]
+                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv, k_pe = torch.split(
+                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+                )
+            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+            k_pe = k_pe[:, :kv_seq_len]
+            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+            compressed_kv = compressed_kv[:, :kv_seq_len]
             kv = (
                 self.kv_b_proj(compressed_kv)
-                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
 
-            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
             attn_output = flash_attn_func(
@@ -403,7 +415,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if not self.use_triton: # os.name == 'nt'
+        if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
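For readers unfamiliar with the packed MLA cache layout this patch leans on, here is a minimal standalone sketch of the bookkeeping the prefill hunk introduces: `kv_seq_len` counts past plus current tokens (mirroring `past_key_value.get_usable_length`), and the cache stores `compressed_kv` and `k_pe` concatenated along the last dimension, so the prefill path splits them back apart and sizes the key/value side by `kv_seq_len` while queries stay at `q_len`. This is a toy illustration, not the ktransformers implementation; the batch size, sequence lengths, and MLA dimensions below are assumed values.

```python
# Toy sketch of the cache-split bookkeeping added in this patch
# (illustrative sizes only; not the actual ktransformers code).
import torch

bsz, q_len, past_len = 1, 4, 12           # assumed batch size and lengths
kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed MLA dimensions
kv_seq_len = q_len + past_len             # analogue of q_len + get_usable_length(...)

# Pretend the cache handed back the packed tensor for the whole sequence so far:
# compressed_kv and k_pe concatenated along the last dimension.
compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)

compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
# Keys/values are rebuilt over the full kv_seq_len; queries keep length q_len.
k_pe = k_pe.view(bsz, -1, qk_rope_head_dim)[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, kv_lora_rank)[:, :kv_seq_len]

print(k_pe.shape)           # torch.Size([1, 16, 64])
print(compressed_kv.shape)  # torch.Size([1, 16, 512])
```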