use generation config from json file in official repo

This commit is contained in:
Atream 2025-02-27 11:48:34 +00:00
parent 9660b2cc1e
commit e645d84794
4 changed files with 57 additions and 21 deletions

View file

@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward_linux_flashinfer(
def forward_linux_flashinfer_chunk(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward_linux_flashinfer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Dispatch a flashinfer attention call to the chunk implementation.

    The call is forwarded unchanged to ``forward_linux_flashinfer_chunk``
    when the query length (second dimension of ``hidden_states``) is at
    most ``self.chunck_size``, or when ``self.absorb_for_prefill`` is
    falsy. Any other combination has no implementation here.

    Args:
        hidden_states: Input whose ``size()`` unpacks into exactly three
            dimensions; the middle one is used as the query length.
        attention_mask, position_ids, past_key_value, output_attentions,
        use_cache, cache_position, **kwargs: Passed through untouched to
            ``forward_linux_flashinfer_chunk``.

    Returns:
        Whatever ``forward_linux_flashinfer_chunk`` returns.

    Raises:
        AssertionError: If the query length exceeds ``self.chunck_size``
            while ``self.absorb_for_prefill`` is truthy.
    """
    # The 3-way unpack is kept so that inputs that are not 3-dimensional
    # still fail here, exactly as before; only the middle value is needed.
    _, q_len, _ = hidden_states.size()

    if q_len <= self.chunck_size or not self.absorb_for_prefill:
        return self.forward_linux_flashinfer_chunk(
            hidden_states,
            attention_mask,
            position_ids,
            past_key_value,
            output_attentions,
            use_cache,
            cache_position,
            **kwargs,
        )

    # Raised explicitly rather than via ``assert False``: assert statements
    # are removed under ``python -O``, which would make this method fall
    # through and return None instead of failing loudly.
    raise AssertionError(
        f"forward_linux_flashinfer: unsupported call with q_len={q_len} > "
        f"chunck_size={self.chunck_size} while absorb_for_prefill is enabled"
    )
def forward_windows(
self,
hidden_states: torch.Tensor,