fix server cache lens

This commit is contained in:
liam 2025-03-01 00:09:57 +08:00
parent 71f4599dee
commit 8ddc990668
2 changed files with 10 additions and 1 deletions

View file

@ -130,8 +130,11 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1]
if(input_ids_length >= self.args.cache_lens):
logger.warning(f"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}")
self.seq_length = input_ids_length
return
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device

View file

@ -329,6 +329,12 @@ class TransformersInterface(BackendInterfaceBase):
@torch.no_grad
def generate(self):
self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
if(self.args.max_new_tokens <= 0):
logger.warning("max_new_tokens is less than 0")
yield self.streamer.end()
return
logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
self.profiler.set_counter("decode", 0)
for i in range(1, self.args.max_new_tokens):