diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index 5d74b9f..ce9cb71 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -130,8 +130,11 @@ class KTransformersInterface(TransformersInterface):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
+        if input_ids_length >= self.args.cache_lens:
+            logger.warning(f"input_ids_length {input_ids_length} >= cache_lens {self.args.cache_lens}")
+            self.seq_length = input_ids_length
+            return
         logger.debug(f"input_ids: {input_ids.shape}")
-
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         device = "cuda:0" if device == "cuda" else device
 
diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index a588fb3..e6d444e 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -329,6 +329,12 @@ class TransformersInterface(BackendInterfaceBase):
 
     @torch.no_grad
     def generate(self):
+        self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
+        if self.args.max_new_tokens <= 0:
+            logger.warning("max_new_tokens is less than or equal to 0")
+            yield self.streamer.end()
+            return
+        logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
         for i in range(1, self.args.max_new_tokens):
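
Taken together, the two hunks implement one bounds check: `prefill` refuses prompts that would not fit in the KV cache, and `generate` clamps its decode budget to the cache slots left after the prompt. The sketch below pulls that arithmetic into a standalone function for illustration; the names `cache_lens`, `input_ids_length`, and `max_new_tokens` mirror the diff, while `plan_generation` and the sample numbers are hypothetical.

```python
# Minimal sketch of the cache-bounds logic from the diff, assuming a fixed-size
# KV cache of `cache_lens` slots. `plan_generation` is a hypothetical helper,
# not part of the ktransformers codebase.

def plan_generation(cache_lens: int, input_ids_length: int, max_new_tokens: int) -> int:
    """Return how many tokens may be decoded, or 0 if generation must stop."""
    # prefill(): a prompt that already fills or overflows the cache is rejected.
    if input_ids_length >= cache_lens:
        print(f"warning: input_ids_length {input_ids_length} >= cache_lens {cache_lens}")
        return 0
    # generate(): clamp the decode budget to the cache slots that remain.
    budget = min(max_new_tokens, cache_lens - input_ids_length)
    if budget <= 0:
        print("warning: max_new_tokens is less than or equal to 0")
        return 0
    return budget

if __name__ == "__main__":
    assert plan_generation(cache_lens=4096, input_ids_length=1000, max_new_tokens=512) == 512
    assert plan_generation(cache_lens=4096, input_ids_length=4000, max_new_tokens=512) == 96  # clamped
    assert plan_generation(cache_lens=4096, input_ids_length=5000, max_new_tokens=512) == 0   # rejected
```

Note the design choice in `prefill`: on rejection it still records `self.seq_length = input_ids_length` before returning, so the subsequent clamp in `generate` computes a non-positive budget and ends the stream cleanly instead of decoding into a full cache.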