mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00
Merge branch 'main' into feat-chunk-prefill-flashinfer
This commit is contained in:
commit
fa03ea48dd
3 changed files with 10 additions and 146 deletions
|
@ -129,8 +129,11 @@ class KTransformersInterface(TransformersInterface):
|
|||
@torch.no_grad
|
||||
def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float], top_p: Optional[float]):
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
if(input_ids_length >= self.args.cache_lens):
|
||||
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
|
||||
self.seq_length = input_ids_length
|
||||
return
|
||||
logger.debug(f"input_ids: {input_ids.shape}")
|
||||
|
||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||
device = "cuda:0" if device == "cuda" else device
|
||||
|
||||
|
|
|
@ -328,6 +328,12 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
|
||||
@torch.no_grad
|
||||
def generate(self):
|
||||
self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
|
||||
if(self.args.max_new_tokens <= 0):
|
||||
logger.warning("max_new_tokens is less than 0")
|
||||
yield self.streamer.end()
|
||||
return
|
||||
logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
|
||||
self.profiler.set_counter("decode", 0)
|
||||
for i in range(1, self.args.max_new_tokens):
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue