Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-11 15:54:37 +00:00

Commit f4c198bd42 (parent e9b1216a9a): support absorb for prefill long context
8 changed files with 93 additions and 33 deletions
@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
 from typing import Optional, Tuple
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.utils import get_compute_capability
 import logging
 from transformers.configuration_utils import PretrainedConfig
 from transformers.cache_utils import Cache
@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                  prefill_device: str = "cuda",
                  generate_device: str = "cuda",
                  chunck_size: int = 1000,
+                 absorb_for_prefill: bool = False,
                  **kwargs):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
         self.orig_module.__init__(orig_module.config,
                                   orig_module.layer_idx)
         self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
         self.mla_wrapper = None
+        self.absorb_for_prefill = absorb_for_prefill

     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
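The new absorb_for_prefill flag lets the prefill path reuse the matrix-absorption trick that get_absorbed() provides for decode: the kv_b_proj up-projection is split into a query-side factor (q_absorb) and an output-side factor (out_absorb), so attention can run directly against the compressed latent KV cache. Below is a minimal, self-contained sketch of that idea; the shapes and names are illustrative assumptions, and the RoPE half of the query and the real softmax scale are omitted, so it is not the exact ktransformers implementation.

import torch

num_heads, kv_lora_rank = 16, 512
qk_nope_head_dim, v_head_dim = 128, 128
q_len, kv_len = 8, 1024

# Per-head slices of the kv_b_proj up-projection (illustrative random weights).
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)   # W^UK part
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)       # W^UV part

q_nope = torch.randn(num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(kv_len, kv_lora_rank)  # latent KV cache shared by all heads

# Absorbed path: fold W^UK into the query, so scores are computed directly
# against the compressed cache and no per-head K/V is ever materialized.
q_latent = torch.matmul(q_nope, q_absorb)                  # [heads, q_len, kv_lora_rank]
scores = q_latent @ compressed_kv.T * kv_lora_rank**-0.5   # [heads, q_len, kv_len] (toy scale)
ctx_latent = scores.softmax(dim=-1) @ compressed_kv        # [heads, q_len, kv_lora_rank]

# Fold W^UV back out only at the end, as the diff does with out_absorb.mT.
attn_output = torch.matmul(ctx_latent, out_absorb.mT)      # [heads, q_len, v_head_dim]
print(attn_output.shape)                                   # torch.Size([16, 8, 128])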
@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
         q_nope = torch.matmul(q_nope, q_absorb) # batched MM
         q_nope = q_nope.transpose(1, 2)
-        assert q_nope.is_contiguous()
+        #assert q_nope.is_contiguous()

         # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
         attn_output = attn_output.transpose(1, 2)
         attn_output = torch.matmul(attn_output, out_absorb.mT)
         attn_output = attn_output.transpose(1, 2)

         attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
         attn_output = self.o_proj(attn_output)
@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]

         # decode
-        if q_len == 1:
+        if q_len == 1 or self.absorb_for_prefill:
             if past_key_value is not None:
                 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
                 compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
@@ -395,27 +399,41 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
             q_nope = torch.matmul(q_nope, q_absorb) # batched MM
             q_nope = q_nope.transpose(1, 2)
-            assert q_nope.is_contiguous()
+            q_nope = q_nope.contiguous()
+            #assert q_nope.is_contiguous()

             # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
             # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
-            q_nope.squeeze_(1)
-            q_pe.squeeze_(1)
+            q_nope.squeeze_(0)
+            q_pe.squeeze_(0)

             # flash attn doesn't support head_dim bigger than 256, use flashinfer
             if self.mla_wrapper is None:
                 self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
-            if self.mla_wrapper.need_plan:
-                self.mla_wrapper.need_plan = False
-                self.mla_wrapper.plan(None,None,None,
-                                      position_ids.squeeze(1)+1,
-                                      self.num_heads,
-                                      self.kv_lora_rank,
-                                      self.qk_rope_head_dim,
-                                      past_key_value.page_size,
-                                      self.softmax_scale,
-                                      q_nope.dtype,
-                                      compressed_kv.dtype)
+            if self.mla_wrapper.need_plan:
+                self.mla_wrapper.need_plan = False
+                if q_len == 1:
+                    self.mla_wrapper.plan(None,None,None,
+                                          position_ids.squeeze(1)+1,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
+                else:
+                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
+                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
+                    self.mla_wrapper.plan(qo_indptr,None,None,
+                                          kv_len_arr,
+                                          self.num_heads,
+                                          self.kv_lora_rank,
+                                          self.qk_rope_head_dim,
+                                          past_key_value.page_size,
+                                          self.softmax_scale,
+                                          q_nope.dtype,
+                                          compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)

             """
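Because prefill can now reach this absorbed flashinfer path with q_len > 1, the batch dimension (dim 0, with bsz == 1) is squeezed instead of the query dimension, and planning splits into a decode case and a ragged prefill case. A small sketch of the planning metadata built in the new else branch; the values are illustrative and the sketch runs on CPU only to show the shapes:

import torch

device = "cpu"         # the real path builds these on self.device (a CUDA device)
q_len = 7              # number of prompt tokens being prefilled in this chunk
last_position_id = 42  # position_ids[0, -1] for the single sequence in the batch

# Ragged query layout: one sequence whose queries occupy rows [0, q_len).
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=device)
# Total KV length visible to that sequence (cached tokens plus the new ones).
kv_len_arr = torch.tensor([last_position_id + 1], dtype=torch.int32, device=device)

print(qo_indptr.tolist(), kv_len_arr.tolist())  # [0, 7] [43]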
@@ -443,6 +461,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
             attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
             attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
             attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]

             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
@@ -571,7 +590,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt':
+        if os.name == 'nt' or get_compute_capability()<8:
+            print("for Windows or GPU before ampere, use forward_windows")
             return self.forward_windows(
                 hidden_states,
                 attention_mask,
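The new guard skips the flashinfer path on Windows and on pre-Ampere GPUs. The imported helper ktransformers.util.utils.get_compute_capability is not shown in this diff; a minimal sketch of what such a helper could look like (an assumption, not the actual implementation):

import torch

def get_compute_capability(device=None) -> int:
    """Return the CUDA major compute capability (8 on Ampere), or 0 if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    major, _minor = torch.cuda.get_device_capability(device)
    return major

# Used as in the hunk above: fall back to forward_windows when the capability is
# below 8, i.e. on GPUs older than Ampere where the flashinfer MLA kernels are not used.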