small fix for max_new_tokens

Azure 2025-03-05 09:25:41 +00:00
parent 5aee6c0446
commit 662c1e4c14


@@ -329,15 +329,16 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def generate(self):
-        self.args.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length)
-        if(self.args.max_new_tokens <= 0):
+        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1
+        logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
+        if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
             yield self.streamer.end()
             return
-        logger.info(f"max_new_tokens: {self.args.max_new_tokens}")
+        logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
-        for i in range(1, self.args.max_new_tokens):
+        for i in range(1, self.max_new_tokens):
             with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                 if flashinfer_enabled:
                     MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,