merge main; Add torch q8 linear

Azure-Tang 2025-03-14 05:52:07 -04:00
parent 6c4ed59175
commit ed8437413b
27 changed files with 1561 additions and 114 deletions


@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
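The hunk above makes flash-attn an optional dependency (the bare try/except keeps the module importable when it is missing) and pulls in the Triton prefill kernel plus the GPU vendor helpers. Below is a minimal sketch of that optional-import pattern, assuming nothing beyond the flash_attn package itself; the `_HAS_FLASH_ATTN` flag and helper are hypothetical names, not something this commit defines:

```python
# Sketch of an optional flash-attn import (illustrative only; the flag name
# _HAS_FLASH_ATTN is hypothetical and not part of ktransformers).
try:
    from flash_attn import flash_attn_func
    _HAS_FLASH_ATTN = True
except ImportError:  # flash-attn not installed on this machine
    flash_attn_func = None
    _HAS_FLASH_ATTN = False

def flash_attn_available() -> bool:
    """Report whether the flash-attn fast path can be used at runtime."""
    return _HAS_FLASH_ATTN
```

Note that the commit itself uses a bare `except:`, which also swallows non-ImportError failures (for example a broken CUDA build of flash-attn); catching `ImportError` as above is the narrower, more common variant.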
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
-attn_output = flash_attn_func(
-    query_states,
-    key_states,
-    value_states_padded,
-    softmax_scale=self.softmax_scale,
-    causal=True,
+# for bsz = 1
+attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+max_input_len = q_len
+context_attention_fwd(
+    q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+    k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+    v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+    o=attn_output,
+    b_start_loc=b_start_loc,
+    b_seq_len=b_seq_len,
+    max_input_len=max_input_len,
+    is_causal=True
)
if self.q_head_dim != self.v_head_dim:
-    attn_output = attn_output[:, :, :, : self.v_head_dim]
+    attn_output = attn_output[:, :, : self.v_head_dim]
attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.v_head_dim
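The replacement above drives the Triton prefill kernel with a packed token layout instead of flash-attn's `[bsz, seq, heads, dim]` call: queries, keys, and values are flattened to `[bsz * q_len, heads, head_dim]`, and `b_start_loc` / `b_seq_len` tell the kernel where each sequence begins and how long it is (hence the `# for bsz = 1` comment, since `torch.zeros(bsz)` only yields correct start offsets for a single sequence). The sketch below builds that metadata with illustrative shapes and no kernel call, so it runs anywhere; everything not mirrored from the diff is an assumption:

```python
import torch

# Illustrative sizes only (not taken from the commit).
bsz, q_len, num_heads = 1, 16, 8
q_head_dim, v_head_dim = 192, 128

query_states = torch.randn(bsz, q_len, num_heads, q_head_dim)

# Flatten batch and sequence into one token axis: [bsz * q_len, heads, dim].
q_packed = query_states.view(-1, num_heads, q_head_dim)

# Output buffer sized for the value head dimension; the kernel writes into it.
attn_output = torch.zeros(bsz * q_len, num_heads, v_head_dim)

# Per-sequence bookkeeping: start offset of each sequence in the packed buffer
# and its length. For bsz == 1 this reduces to [0] and [q_len], matching the
# torch.zeros(bsz) shortcut used in the diff.
b_start_loc = torch.arange(bsz, dtype=torch.int64) * q_len
b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64)
max_input_len = int(b_seq_len.max())
```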
@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-if os.name == 'nt' or get_compute_capability()<8:
-    print("for Windows or GPU before ampere, use forward_windows")
+if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
return self.forward_windows(
hidden_states,
attention_mask,
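The last hunk widens the compatibility fallback: previously only Windows or a pre-Ampere GPU (compute capability below 8) routed to `forward_windows`; now any non-NVIDIA device does as well, using the newly imported `device_manager` and `GPUVendor`. A standalone restatement of that rule as a sketch; the helper name and the plain-string vendor argument are my own, while the real code reads `get_compute_capability()` and `device_manager.gpu_vendor`:

```python
import os

def should_use_fallback(compute_capability: int, gpu_vendor: str) -> bool:
    """Mirror of the dispatch condition in the diff (hypothetical helper):
    fall back to forward_windows on Windows, on NVIDIA GPUs older than
    Ampere (compute capability < 8), or on any non-NVIDIA device."""
    return (
        os.name == "nt"
        or compute_capability < 8
        or gpu_vendor != "NVIDIA"
    )

# Example: a pre-Ampere NVIDIA GPU (compute capability 7) always falls back.
assert should_use_fallback(7, "NVIDIA") is True
```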