Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
use generation config from json file in official repo
This commit is contained in:
parent 9660b2cc1e
commit e645d84794
4 changed files with 57 additions and 21 deletions
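The commit title describes reading the generation parameters from the generation_config.json published in the official model repo instead of hard-coding them; the diff for that part of the change is not reproduced on this page. Below is a minimal sketch of that idea using the public transformers.GenerationConfig API; model_path and the fallback values are illustrative assumptions, not taken from this commit.

# Sketch only: load generation settings from the model directory's
# generation_config.json (the file shipped in the official repo), with an
# explicit fallback when the file is missing.  The fallback values below
# are illustrative, not the values from this commit.
from transformers import GenerationConfig

def load_generation_config(model_path: str) -> GenerationConfig:
    try:
        # Reads <model_path>/generation_config.json when it exists.
        return GenerationConfig.from_pretrained(model_path)
    except Exception:
        # No generation_config.json found; fall back to explicit defaults.
        return GenerationConfig(do_sample=True, temperature=0.6, top_p=0.95)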
@@ -338,7 +338,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
-    def forward_linux_flashinfer(
+    def forward_linux_flashinfer_chunk(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
@@ -512,6 +512,35 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
         attn_output = self.o_proj(attn_output)
         return attn_output, None, past_key_value
 
+    def forward_linux_flashinfer(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if q_len <= self.chunck_size or not self.absorb_for_prefill:
+            return self.forward_linux_flashinfer_chunk(
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_value,
+                output_attentions,
+                use_cache,
+                cache_position,
+                **kwargs,
+            )
+
+        assert False
+
+
     def forward_windows(
         self,
         hidden_states: torch.Tensor,
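For context, the hunk above turns forward_linux_flashinfer into a thin dispatcher: short prompts, and the non-absorbed path, fall through to the renamed forward_linux_flashinfer_chunk, while the long-prefill branch is cut off at assert False in this view. The sketch below illustrates the general chunked-prefill dispatch pattern; dispatch_prefill and forward_chunk are hypothetical names, and the loop over chunks is an assumption about how the missing branch could be filled in, not code from this commit.

# Sketch of the chunked-prefill dispatch pattern; all names here are
# hypothetical and only illustrate the branch structure seen in the hunk.
from typing import Callable, Optional, Tuple

import torch

ChunkForward = Callable[[torch.Tensor], Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]]


def dispatch_prefill(
    hidden_states: torch.Tensor,   # [bsz, q_len, hidden_size]
    chunk_size: int,
    absorb_for_prefill: bool,
    forward_chunk: ChunkForward,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
    bsz, q_len, _ = hidden_states.size()

    # Mirrors `q_len <= self.chunck_size or not self.absorb_for_prefill`:
    # short prompts and the non-absorbed path use the per-chunk forward as-is.
    if q_len <= chunk_size or not absorb_for_prefill:
        return forward_chunk(hidden_states)

    # Assumed long-prefill branch: split the sequence along the token
    # dimension and run the chunks through the per-chunk forward one after
    # another.  In the real module the KV cache managed inside the per-chunk
    # forward would accumulate state across iterations.
    outputs = []
    for start in range(0, q_len, chunk_size):
        attn_output, _, _ = forward_chunk(hidden_states[:, start : start + chunk_size, :])
        outputs.append(attn_output)
    return torch.cat(outputs, dim=1), None, None

A caller would pass the per-chunk attention forward as forward_chunk, so the dispatch logic stays independent of how each chunk is actually attended.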