Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 22:05:30 +00:00
support chunk prefill, support 139K context for 24G VRAM
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions
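At a high level, chunked prefill feeds the prompt through the model in fixed-size slices that share one growing KV cache, instead of a single pass whose transient activations scale with the full prompt length; bounding that per-step footprint is what makes a roughly 139K-token context workable in 24 GB of VRAM. As a sketch of the loop only, assuming a generic Hugging Face-style causal LM interface (input_ids / past_key_values / use_cache) rather than anything specific to this repository:

# Illustrative sketch, not KTransformers code: prefill a long prompt in
# fixed-size chunks against a Hugging Face-style causal LM whose forward
# accepts input_ids / past_key_values / use_cache and returns logits plus
# the updated cache.
import torch

@torch.no_grad()
def chunked_prefill(model, input_ids: torch.Tensor, chunk_size: int = 8192):
    past, logits = None, None
    for start in range(0, input_ids.shape[1], chunk_size):
        chunk = input_ids[:, start:start + chunk_size]
        out = model(input_ids=chunk, past_key_values=past, use_cache=True)
        past = out.past_key_values   # the KV cache keeps growing across chunks
        logits = out.logits          # only the latest chunk's logits are kept
    # logits[:, -1] is the next-token distribution after the full prompt
    return logits, past

The hunks below are the part of KDeepseekV2Attention where the repository's own chunk dispatch (self.chunck_size, self.absorb_for_prefill, forward_linux_flashinfer_chunk) shows up.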
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def forward_linux_flashinfer_chunk(
    def forward_linux_flashinfer(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
@@ -512,35 +512,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
        attn_output = self.o_proj(attn_output)
        return attn_output, None, past_key_value

    def forward_linux_flashinfer(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        if q_len <= self.chunck_size or not self.absorb_for_prefill:
            return self.forward_linux_flashinfer_chunk(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                **kwargs,
            )

        assert False

    def forward_windows(
        self,
        hidden_states: torch.Tensor,
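The second hunk revolves around the branch that checks q_len against self.chunck_size and self.absorb_for_prefill and hands the work to forward_linux_flashinfer_chunk. The point of chunking the prefill inside attention is that scores are computed for one slice of query positions at a time, so the attention matrix peaks at chunk_size x kv_len instead of q_len x kv_len. A toy pure-PyTorch version of that idea with ordinary causal attention (my own illustration, not the MLA/flashinfer code touched here; chunked_causal_attention and its arguments are made-up names):

# Toy query-chunked causal attention, not the KDeepseekV2Attention code:
# shows why the per-step score matrix peaks at chunk_size x kv_len.
import torch
import torch.nn.functional as F

def chunked_causal_attention(q, k, v, chunk_size):
    # q, k, v: [batch, heads, seq, head_dim]
    seq, scale = q.shape[2], q.shape[-1] ** -0.5
    outs = []
    for start in range(0, seq, chunk_size):
        end = min(start + chunk_size, seq)
        q_chunk = q[:, :, start:end]                                  # [b, h, c, d]
        scores = (q_chunk @ k[:, :, :end].transpose(-1, -2)) * scale  # [b, h, c, end]
        # causal mask: a query at absolute position p only sees keys at <= p
        pos_q = torch.arange(start, end, device=q.device)[:, None]
        pos_k = torch.arange(end, device=q.device)[None, :]
        scores = scores.masked_fill(pos_k > pos_q, float("-inf"))
        outs.append(torch.softmax(scores, dim=-1) @ v[:, :, :end])
    return torch.cat(outs, dim=2)

# sanity check against full (unchunked) causal attention
b, h, s, d = 1, 4, 64, 32
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(chunked_causal_attention(q, k, v, chunk_size=16), ref, atol=1e-5)

With a chunk size in the low thousands of tokens, the transient buffers stay roughly constant while only the KV cache grows with context length, which is presumably what the 139K-context-in-24G figure in the commit title relies on.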