Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-13 00:29:59 +00:00
merge main; Add torch q8 linear
parent 6c4ed59175
commit ed8437413b
27 changed files with 1561 additions and 114 deletions
@@ -20,8 +20,14 @@ from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
-from flash_attn import flash_attn_func
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor
+
+try:
+    from flash_attn import flash_attn_func
+except:
+    pass
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention_prefill import context_attention_fwd
 import os
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
 if flashinfer_enabled:
@@ -319,18 +325,27 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
 
             value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
             value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
-
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states_padded,
-                softmax_scale=self.softmax_scale,
-                causal=True,
+            # for bsz = 1
+            attn_output = torch.zeros(bsz * q_len, self.num_heads, self.v_head_dim, device=hidden_states.device)
+            b_start_loc = torch.zeros(bsz, dtype=torch.int64, device=hidden_states.device)
+            b_seq_len = torch.full((bsz,), q_len, dtype=torch.int64, device=hidden_states.device)
+
+            max_input_len = q_len
+
+            context_attention_fwd(
+                q=query_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                k=key_states.squeeze(0).view(-1, self.num_heads, self.q_head_dim),
+                v=value_states.squeeze(0).view(-1, self.num_heads, self.v_head_dim),
+                o=attn_output,
+                b_start_loc=b_start_loc,
+                b_seq_len=b_seq_len,
+                max_input_len=max_input_len,
+                is_causal=True
             )
 
             if self.q_head_dim != self.v_head_dim:
-                attn_output = attn_output[:, :, :, : self.v_head_dim]
+                attn_output = attn_output[:, :, : self.v_head_dim]
 
             attn_output = attn_output.reshape(
                 bsz, q_len, self.num_heads * self.v_head_dim
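Reviewer note: the Triton prefill call flattens the batch into one token dimension and describes each sequence with b_start_loc / b_seq_len; this hunk hard-codes the single-sequence case (hence the "# for bsz = 1" comment). Below is a hedged sketch of how that metadata would generalise to a batch of variable-length sequences; build_batch_metadata is an illustrative helper, not a function in ktransformers or in this commit.

    import torch

    def build_batch_metadata(seq_lens: list[int], device: torch.device):
        """Return (b_start_loc, b_seq_len, max_input_len) for sequences packed
        back-to-back along the token dimension, as context_attention_fwd expects."""
        b_seq_len = torch.tensor(seq_lens, dtype=torch.int64, device=device)
        b_start_loc = torch.zeros(len(seq_lens), dtype=torch.int64, device=device)
        b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]  # exclusive prefix sum of lengths
        return b_start_loc, b_seq_len, int(b_seq_len.max().item())

For bsz = 1 this reduces exactly to the tensors built in the hunk above: b_start_loc == [0], b_seq_len == [q_len], max_input_len == q_len.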
@@ -589,8 +604,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability()<8:
-            print("for Windows or GPU before ampere, use forward_windows")
+        if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
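Reviewer note: the widened guard now sends Windows, pre-Ampere GPUs, and non-NVIDIA devices to forward_windows, so only Linux with an NVIDIA GPU of compute capability >= 8 reaches the flash-attn / Triton path. A hedged restatement of that gate as a standalone predicate follows; use_fast_attention_path is an illustrative name, not a function in the repo, while the imports mirror those already present in the diff.

    import os
    from ktransformers.util.utils import get_compute_capability
    from ktransformers.util.vendors import device_manager, GPUVendor

    def use_fast_attention_path() -> bool:
        """Mirror the condition above: fall back unless running on Linux with an
        NVIDIA GPU of compute capability >= 8 (Ampere or newer)."""
        if os.name == 'nt':                                   # Windows
            return False
        if get_compute_capability() < 8:                      # pre-Ampere GPU
            return False
        if device_manager.gpu_vendor != GPUVendor.NVIDIA:     # non-NVIDIA vendor
            return False
        return True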