Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-07 13:09:50 +00:00)
fix: use flash_attn for faster prefill
parent bb0ccc7b1a
commit 5ac266085e

1 changed file with 26 additions and 14 deletions
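The core of the change is that multi-token prefill now runs through flash_attn_func on full-length key/value tensors rebuilt from the compressed KV cache. A minimal sketch of that call pattern, assuming the standard flash-attn 2.x API; the softmax scale, the causal flag, and the final slice are assumptions, since the argument list is cut off in the diff below:

import torch
from flash_attn import flash_attn_func  # assumes flash-attn 2.x is installed

def prefill_attention_sketch(query_states, key_states, value_states, softmax_scale):
    # query_states: [bsz, q_len, num_heads, q_head_dim]
    # key_states:   [bsz, kv_seq_len, num_heads, q_head_dim]
    # value_states: [bsz, kv_seq_len, num_heads, v_head_dim], with v_head_dim < q_head_dim
    pad = query_states.shape[-1] - value_states.shape[-1]
    # Zero-pad the value heads up to the query/key head dim, mirroring
    # value_states_padded in the diff below...
    value_states_padded = torch.nn.functional.pad(value_states, [0, pad], value=0)
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states_padded,
        softmax_scale=softmax_scale,
        causal=True,
    )
    # ...and drop the padded channels again to recover the true value width.
    return attn_output[..., : value_states.shape[-1]]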
@@ -125,8 +125,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -216,13 +216,23 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

         # decode
-        if q_len == 1:
+        if self.use_triton and q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
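The block added above tracks the total number of key/value positions rather than just the new tokens: when a cache is already populated, kv_seq_len grows by whatever past_key_value reports as usable. A toy illustration of that bookkeeping, with made-up numbers:

q_len = 512           # tokens in the current prefill chunk (hypothetical)
cached_tokens = 1024  # what past_key_value.get_usable_length(...) would report (hypothetical)

kv_seq_len = q_len            # mirrors: kv_seq_len = q_len
kv_seq_len += cached_tokens   # mirrors: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
assert kv_seq_len == 1536     # keys/values are rebuilt at this length; queries keep q_len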
@@ -287,26 +297,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 k_pe.squeeze(0)
                 compressed_kv.squeeze(0)
-                past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-                k_pe.unsqueeze(0)
-                compressed_kv.unsqueeze(0)
-
-                k_pe = k_pe[:, :q_len]
-                compressed_kv = compressed_kv[:, :q_len]
+                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+                compressed_kv, k_pe = torch.split(
+                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+                )
+                k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+                k_pe = k_pe[:, :kv_seq_len]
+                compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+                compressed_kv = compressed_kv[:, :kv_seq_len]
             kv = (
                 self.kv_b_proj(compressed_kv)
-                .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
             )
             k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
             query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
             query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

-            key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
             key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

-            value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

             attn_output = flash_attn_func(
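To make the shape bookkeeping in the hunk above easier to follow, here is a self-contained walkthrough of how the full keys and values are rebuilt from the compressed cache. The per-head dimensions follow DeepSeek-V2 (kv_lora_rank 512, nope/rope head dims 128/64, v_head_dim 128); the head count and sequence length are scaled down, and a plain nn.Linear stands in for the injected kv_b_proj:

import torch

bsz, kv_seq_len = 1, 24
num_heads = 8                     # scaled down from the real model
kv_lora_rank = 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
q_head_dim = qk_nope_head_dim + qk_rope_head_dim

# What past_key_value.update() hands back: compressed KV with the rope part appended.
compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)
compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.view(bsz, -1, kv_lora_rank)[:, :kv_seq_len]
k_pe = k_pe.view(bsz, -1, qk_rope_head_dim)[:, :kv_seq_len]

# kv_b_proj re-expands the compressed cache into per-head no-rope keys and values.
kv_b_proj = torch.nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)
kv = kv_b_proj(compressed_kv).view(bsz, kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

# Assemble full keys: per-head no-rope part plus the shared rope part broadcast across heads.
key_states = k_pe.new_empty(bsz, kv_seq_len, num_heads, q_head_dim)
key_states[:, :, :, :qk_nope_head_dim] = k_nope
key_states[:, :, :, qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

print(key_states.shape, value_states.shape)  # [1, 24, 8, 192] and [1, 24, 8, 128]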
@@ -403,7 +415,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if not self.use_triton: # os.name == 'nt'
+        if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
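With the dispatch change above, the fallback path is taken on Windows or when a single token is being decoded, while longer inputs take the flash_attn prefill path. A condensed sketch of that routing; forward_windows is taken from the hunk, while forward_linux is a placeholder name for the non-Windows implementation:

import os
import torch

def forward_dispatch_sketch(self, hidden_states: torch.Tensor, *args, **kwargs):
    # Decode (one new token) and Windows keep the original path; prefill uses flash_attn.
    if os.name == 'nt' or hidden_states.shape[1] == 1:
        return self.forward_windows(hidden_states, *args, **kwargs)
    return self.forward_linux(hidden_states, *args, **kwargs)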