Merge branch 'main' into feat-chunk-prefill-flashinfer

This commit is contained in:
Atream 2025-03-01 11:35:09 +00:00
commit fa03ea48dd
3 changed files with 10 additions and 146 deletions

View file

@ -129,8 +129,11 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device