fix: adapt prefix cache in forward_linux_flashinfer

This commit is contained in:
ceerrep 2025-02-18 11:40:15 +08:00
parent c70b6f4d5b
commit 2ffb43f974

View file

@ -374,6 +374,16 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)
kv_seq_len = q_len
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
@ -453,26 +463,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe.squeeze(0)
compressed_kv.squeeze(0)
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
k_pe.unsqueeze(0)
compressed_kv.unsqueeze(0)
k_pe = k_pe[:, :q_len]
compressed_kv = compressed_kv[:, :q_len]
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv, k_pe = torch.split(
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
k_pe = k_pe[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
compressed_kv = compressed_kv[:, :kv_seq_len]
kv = (
self.kv_b_proj(compressed_kv)
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
attn_output = flash_attn_func(