rollback-triton-prefill

Atream 2025-03-15 14:21:21 +00:00
parent bda9cf15e7
commit 3934b9dfc1


@@ -325,27 +325,18 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
         value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
         value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
-        # for bsz = 1
-        attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
-        b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
-        b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
-        max_input_len = q_len
-        context_attention_fwd(
-            q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
-            v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
-            o=attn_output,
-            b_start_loc=b_start_loc,
-            b_seq_len=b_seq_len,
-            max_input_len=max_input_len,
-            is_causal=True
-        )
+        attn_output = flash_attn_func(
+            query_states,
+            key_states,
+            value_states_padded,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+        )
         if self.q_head_dim != self.v_head_dim:
-            attn_output = attn_output[:, :, : self.v_head_dim]
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
         attn_output = attn_output.reshape(
             bsz, q_len, self.num_heads * self.v_head_dim
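
The value padding in this hunk exists because flash_attn_func expects q, k, and v to share a single head dimension, while DeepseekV2's value head dim is smaller than its query/key head dim, so v is zero-padded up to query_states.shape[-1] and the output is sliced back down. The trick is exact: the output is softmax(QKᵀ)·V, so zero columns appended to V can only produce zero output columns. (The slice also gains a dimension relative to the removed Triton path, since flash_attn_func returns (bsz, q_len, num_heads, head_dim) rather than the kernel's flattened (bsz * q_len, num_heads, head_dim).) A minimal sketch of the identity, using PyTorch's scaled_dot_product_attention as a stand-in for flash_attn_func; the dimensions are illustrative, not the model's actual config:

# Sketch only: checks that zero-padding V and slicing the output
# matches attention computed with the narrow V directly.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len = 1, 4, 16
q_head_dim, v_head_dim = 192, 128   # illustrative: q/k head dim > v head dim

# Note: SDPA takes (bsz, heads, seq, dim); flash_attn_func takes (bsz, seq, heads, dim).
q = torch.randn(bsz, num_heads, q_len, q_head_dim)
k = torch.randn(bsz, num_heads, q_len, q_head_dim)
v = torch.randn(bsz, num_heads, q_len, v_head_dim)

# Zero-pad the last dim of V up to the q/k head dim, as the commit does
# with torch.nn.functional.pad before calling flash_attn_func.
v_padded = F.pad(v, [0, q_head_dim - v_head_dim], value=0)

out_padded = F.scaled_dot_product_attention(q, k, v_padded, is_causal=True)
out_direct = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The first v_head_dim output channels match; the padded tail is all zeros.
assert torch.allclose(out_padded[..., :v_head_dim], out_direct, atol=1e-5)
assert out_padded[..., v_head_dim:].abs().max() == 0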