diff --git a/ktransformers/server/backend/interfaces/transformers.py b/ktransformers/server/backend/interfaces/transformers.py
index d5d2857..7200ef4 100644
--- a/ktransformers/server/backend/interfaces/transformers.py
+++ b/ktransformers/server/backend/interfaces/transformers.py
@@ -329,15 +329,16 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def generate(self):
-        self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
-        if(self.args.max_new_tokens <= 0):
+        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
+        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+        if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end()
             return
-        logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
+        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
-        for i in range(1, self.args.max_new_tokens):
+        for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
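
The arithmetic behind the new clamp can be illustrated in isolation. The sketch below is hypothetical (the helper name and sample numbers are not part of kTransformers); it only mirrors the expression `min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1` used in the diff above.

```python
# Minimal sketch of the token-budget clamp this diff introduces; the helper
# name and sample numbers are illustrative, not taken from kTransformers.
def remaining_token_budget(requested_max_new_tokens: int,
                           cache_lens: int,
                           seq_length: int) -> int:
    # The generation budget is bounded both by the caller's request and by
    # the free slots left in the KV cache; one slot is reserved (the -1).
    return min(requested_max_new_tokens, cache_lens - seq_length) - 1

# Example: an 8192-slot cache already holding an 8000-token prompt can serve
# at most 191 new tokens even if the caller asked for 500.
assert remaining_token_budget(500, 8192, 8000) == 191

# A full cache yields a non-positive budget, which the patched generate()
# treats as "end the stream immediately".
assert remaining_token_budget(500, 8192, 8192) <= 0
```

Keeping the clamped value on `self.max_new_tokens` instead of writing it back into `self.args.max_new_tokens` also (presumably) avoids mutating the shared args object between requests, so the logged `args.max_new_tokens` keeps reflecting the caller's original request.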