support chunked prefill; support 139K context on 24 GB of VRAM

Atream 2025-03-01 11:28:25 +00:00
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions
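
The idea behind the title: chunked prefill splits a long prompt into fixed-size pieces and runs the prefill pass one piece at a time, so peak activation and attention-workspace memory scale with the chunk size rather than the full prompt length, which is what makes a 139K-token context feasible on a 24 GB GPU. A minimal sketch of that driver loop, assuming a hypothetical `model` callable that returns `(logits, cache)` and accepts the cache built by earlier chunks, HF-style — this illustrates the general technique, not the ktransformers code path:

```python
import torch

def chunked_prefill(model, input_ids: torch.Tensor, chunk_size: int = 8192):
    """Prefill a long prompt in fixed-size chunks to bound peak VRAM.

    `model` is a hypothetical callable returning (logits, cache) and
    accepting the cache produced by earlier chunks; sketch only.
    """
    past_key_value = None
    logits = None
    for start in range(0, input_ids.size(1), chunk_size):
        chunk = input_ids[:, start : start + chunk_size]
        # Only one chunk's activations are alive at a time; the KV cache is
        # the only state that grows with total prompt length.
        logits, past_key_value = model(chunk, past_key_value=past_key_value)
    return logits, past_key_value
```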


@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
-    def forward_linux_flashinfer_chunk(
+    def forward_linux_flashinfer(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,35 +512,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
-    def forward_linux_flashinfer(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        if q_len <= self.chunck_size or not self.absorb_for_prefill:
-            return self.forward_linux_flashinfer_chunk(
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                use_cache,
-                cache_position,
-                **kwargs,
-            )
-        assert False
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
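
The wrapper removed in the second hunk dispatched on prompt length: prompts no longer than `self.chunck_size` (the field's spelling in the source), or runs without `absorb_for_prefill`, took the single-shot chunk path, while the long-prompt branch was still an `assert False` stub. A minimal sketch of what a completed dispatch of that shape could look like — every name besides `chunck_size` and `absorb_for_prefill` is hypothetical, and this is not the ktransformers implementation:

```python
import torch

class ChunkedAttentionDispatch:
    """Illustrative stand-in for the dispatch pattern in the removed wrapper.

    `chunck_size` keeps the source's spelling; everything else here is
    hypothetical and only sketches the control flow.
    """

    def __init__(self, chunck_size: int = 8192, absorb_for_prefill: bool = True):
        self.chunck_size = chunck_size
        self.absorb_for_prefill = absorb_for_prefill

    def forward(self, hidden_states, position_ids, past_key_value=None):
        bsz, q_len, _ = hidden_states.size()
        if q_len <= self.chunck_size or not self.absorb_for_prefill:
            # Short prompt (or absorption disabled): single-shot path.
            return self.forward_chunk(hidden_states, position_ids, past_key_value)
        # Long prompt: slice the sequence so the attention workspace is
        # O(chunck_size) instead of O(q_len); the KV cache carries state
        # from one chunk to the next.
        outputs = []
        for s in range(0, q_len, self.chunck_size):
            e = s + self.chunck_size
            out, past_key_value = self.forward_chunk(
                hidden_states[:, s:e], position_ids[:, s:e], past_key_value
            )
            outputs.append(out)
        return torch.cat(outputs, dim=1), past_key_value

    def forward_chunk(self, hidden_states, position_ids, past_key_value):
        # Placeholder kernel: identity over the chunk, cache unchanged.
        # (Stands in for forward_linux_flashinfer_chunk.)
        return hidden_states, past_key_value


# Usage: a 20000-token prompt runs as three chunks (8192 + 8192 + 3616).
x = torch.randn(1, 20000, 64)
pos = torch.arange(20000).unsqueeze(0)
y, cache = ChunkedAttentionDispatch(chunck_size=8192).forward(x, pos)
```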