Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-08 13:39:48 +00:00)

fix: use flash_attn for faster prefill

parent bb0ccc7b1a
commit 5ac266085e
1 changed file with 26 additions and 14 deletions
@@ -125,8 +125,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        if hasattr(self.orig_module, 'kv_b_proj'):
-            del self.orig_module.kv_b_proj
+        # if hasattr(self.orig_module, 'kv_b_proj'):
+        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
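Note on the first hunk: deleting kv_b_proj from the injected module is disabled because the flash_attn prefill path further down calls self.kv_b_proj(compressed_kv) to expand the cached latent into per-head keys and values. A minimal stand-alone sketch of that expansion, with toy dimensions (the names follow the diff; the numbers and the nn.Linear stand-in are assumptions):

import torch
import torch.nn as nn

bsz, kv_seq_len, num_heads = 1, 8, 4
kv_lora_rank, qk_nope_head_dim, v_head_dim = 16, 8, 8

# kv_b_proj maps the compressed latent back to per-head K (nope part) and V.
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)

compressed_kv = torch.randn(bsz, kv_seq_len, kv_lora_rank)
kv = kv_b_proj(compressed_kv).view(bsz, kv_seq_len, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
# The absorbed-weight decode path (get_absorbed) does not need this projection,
# which is why it used to be deleted; the flash_attn prefill path does.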
@@ -216,13 +216,23 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         compressed_kv = self.kv_a_layernorm(compressed_kv)
         k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
         compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

+        kv_seq_len = q_len
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(q_pe, position_ids)
         q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

         # decode
-        if q_len == 1:
+        if self.use_triton and q_len == 1:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
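The second hunk restores the standard kv_seq_len bookkeeping: the attention length is the new q_len tokens plus whatever the cache already holds for this layer. A small illustration of that arithmetic, using the stock transformers DynamicCache purely as a stand-in for ktransformers' paged cache (an assumption made for this sketch only):

import torch
from transformers.cache_utils import DynamicCache

layer_idx, q_len = 0, 5
past_key_value = DynamicCache()
# Pretend 3 tokens are already cached for this layer.
past_key_value.update(torch.zeros(1, 1, 3, 8), torch.zeros(1, 1, 3, 8), layer_idx)

kv_seq_len = q_len
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, layer_idx)
print(kv_seq_len)  # 8 -> prefill attends over the cached tokens plus the new ones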
@@ -287,26 +297,28 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            k_pe.squeeze(0)
            compressed_kv.squeeze(0)
-           past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
-           k_pe.unsqueeze(0)
-           compressed_kv.unsqueeze(0)
+           compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
+           compressed_kv, k_pe = torch.split(
+               compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+           )
-       k_pe = k_pe[:, :q_len]
-       compressed_kv = compressed_kv[:, :q_len]
+       k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
+       k_pe = k_pe[:, :kv_seq_len]
+       compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
+       compressed_kv = compressed_kv[:, :kv_seq_len]
        kv = (
            self.kv_b_proj(compressed_kv)
-           .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+           .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
        )
        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

-       key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
+       key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
        key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
-       key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
+       key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

-       value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
+       value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
        value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

        attn_output = flash_attn_func(
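The flash_attn_func arguments are cut off in this hunk. A hedged sketch of what such a call looks like with the tensors built above: flash-attn expects [batch, seqlen, num_heads, head_dim] inputs with a shared head dim, which is why value_states is zero-padded. The softmax_scale and causal values below are assumptions, not taken from this diff.

import torch
from flash_attn import flash_attn_func

bsz, q_len, kv_seq_len, num_heads = 1, 4, 16, 8
q_head_dim, v_head_dim = 192, 128  # toy values in the DeepSeek-V2 style

query_states = torch.randn(bsz, q_len, num_heads, q_head_dim, dtype=torch.bfloat16, device="cuda")
key_states = torch.randn(bsz, kv_seq_len, num_heads, q_head_dim, dtype=torch.bfloat16, device="cuda")
value_states = torch.randn(bsz, kv_seq_len, num_heads, v_head_dim, dtype=torch.bfloat16, device="cuda")
value_states_padded = torch.nn.functional.pad(value_states, [0, q_head_dim - v_head_dim], value=0)

attn_output = flash_attn_func(
    query_states, key_states, value_states_padded,
    softmax_scale=q_head_dim ** -0.5,  # assumed; the real module has its own scale
    causal=True,
)
# The padded V columns contribute only zeros, so slicing recovers the unpadded output.
attn_output = attn_output[..., :v_head_dim]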
@@ -403,7 +415,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if not self.use_triton: # os.name == 'nt'
+        if os.name == 'nt' or hidden_states.shape[1] == 1: # Use in decode
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
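Routing after the last hunk, boiled down: Windows and single-token decode keep going through forward_windows, so the flash_attn path only handles multi-token prefill on other platforms. A tiny sketch of that predicate (helper name invented for illustration):

import os

def takes_flash_attn_prefill(q_len: int) -> bool:
    # Mirrors: if os.name == 'nt' or hidden_states.shape[1] == 1: return self.forward_windows(...)
    return os.name != "nt" and q_len > 1

print(takes_flash_attn_prefill(512))  # True on Linux: a long prompt uses the flash_attn prefill path
print(takes_flash_attn_prefill(1))    # False: decode falls back to forward_windows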