Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-14 17:19:42 +00:00)
rollback-triton-prefill
This commit is contained in:
parent bda9cf15e7
commit 3934b9dfc1
1 changed file with 8 additions and 17 deletions
@@ -325,27 +325,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
 
-        # for bsz = 1
-        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
-        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
-        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
-
-        max_input_len = q_len
-        context_attention_fwd(
-            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
-            o=attn_output,
-            b_start_loc=b_start_loc,
-            b_seq_len=b_seq_len,
-            max_input_len=max_input_len,
-            is_causal=True
-        )
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states_padded,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+        )
 
         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
 
         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim
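For context on what is being rolled back: context_attention_fwd consumes all sequences flattened along the token axis, with b_start_loc and b_seq_len describing where each sequence starts and how long it is. A minimal sketch of that bookkeeping (illustrative values, not repo code), which also shows why the deleted code carried the "# for bsz = 1" caveat:

import torch

# Sketch of the varlen metadata the deleted Triton path built. Tokens are
# flattened to (bsz * q_len, num_heads, head_dim); b_seq_len[i] is sequence
# i's length and b_start_loc[i] its offset into that flattened token axis.
bsz, q_len = 1, 512
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64)  # every sequence has length q_len
b_start_loc = torch.zeros(bsz, dtype=torch.int64)         # all-zero offsets are correct only for bsz == 1
# For bsz > 1 with uniform lengths the offsets would instead be
# torch.arange(bsz, dtype=torch.int64) * q_len, hence the caveat.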
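The restored path relies on a pad-then-slice trick: flash_attn_func requires Q, K, and V to share a single head dim, while here q_head_dim exceeds v_head_dim, so V is zero-padded up to the query head dim (value_states_padded) and the output is cut back afterwards. Because softmax(QK^T) @ V_padded leaves the zero-padded channels exactly zero, the slice is lossless; it also has to be 4-D now, since flash_attn_func returns (bsz, q_len, num_heads, head_dim) rather than the Triton kernel's flattened 3-D output, which is the second fix in this commit. A self-contained sketch with stand-in shapes, using plain scaled_dot_product_attention in place of flash_attn_func (same math, different kernel):

import torch
import torch.nn.functional as F

bsz, q_len, num_heads = 1, 4, 2
q_head_dim, v_head_dim = 6, 4  # stand-ins for the real head dims

q = torch.randn(bsz, q_len, num_heads, q_head_dim)
k = torch.randn(bsz, q_len, num_heads, q_head_dim)
v = torch.randn(bsz, q_len, num_heads, v_head_dim)

# Zero-pad V's head dim up to q_head_dim, as value_states_padded does.
v_padded = F.pad(v, [0, q_head_dim - v_head_dim], value=0)

# SDPA stands in for flash_attn_func and wants (bsz, heads, len, dim).
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v_padded.transpose(1, 2), is_causal=True
).transpose(1, 2)

assert torch.all(out[..., v_head_dim:] == 0)  # padded channels stay exactly zero
out = out[:, :, :, :v_head_dim]               # the 4-D slice restored above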