roll back ktransformers backend, add max_tokens / max_completion_tokens params

qiyuxinlin 2025-04-21 12:55:37 +00:00
parent a1162eea01
commit 03a65d6bea
10 changed files with 144 additions and 161 deletions


@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
 
 def report_last_time_performance(profiler: Profiler):
     try:
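The guard in this first hunk stops a write one slot past the query's preallocated token buffer once decoding reaches max_length. A minimal standalone sketch of the same check (the function and parameter names here are stand-ins, not the actual sched_ext types):

def store_generated_token(query_tokens: list[int], max_length: int,
                          pos: int, token: int) -> None:
    # Write only while the position still fits in the preallocated buffer;
    # at or past max_length the token is dropped instead of raising IndexError.
    if pos < max_length:
        query_tokens[pos] = token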
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):
         start_event.wait()
 
-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float, int]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p
         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001
-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens
 
     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
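Reading the new get_params: max_tokens, when supplied, overrides max_completion_tokens (mirroring the OpenAI API, where max_completion_tokens is the newer spelling of max_tokens), and the result is always capped by the server-side max_new_tokens. A distilled sketch of that precedence, with resolve_max_new_tokens as a hypothetical helper name:

from typing import Optional

def resolve_max_new_tokens(server_cap: int,
                           max_tokens: Optional[int] = None,
                           max_completion_tokens: Optional[int] = None) -> int:
    # max_tokens, if present, takes priority over max_completion_tokens.
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    # Fall back to the server default, and never exceed the server cap.
    if max_completion_tokens is None:
        return server_cap
    return min(server_cap, max_completion_tokens)

assert resolve_max_new_tokens(1024) == 1024                              # neither set
assert resolve_max_new_tokens(1024, max_tokens=64) == 64                 # max_tokens wins
assert resolve_max_new_tokens(1024, max_completion_tokens=2048) == 1024  # capped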
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")
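With the widened signature, per-request limits flow straight from the request into inference. A hedged sketch of what a call site might look like (the message shape and thread id are illustrative, and I am assuming every yield is a (text, finish_reason) pair, as the end-of-stream yields in the hunks below suggest):

import asyncio

async def demo(interface):  # interface: a constructed BalanceServeInterface
    messages = [{"role": "user", "content": "Hello"}]
    async for text, finish_reason in interface.inference(
            messages, thread_id="thread-0",
            temperature=0.7, top_p=0.9, max_tokens=128):
        if finish_reason is None:
            print(text, end="", flush=True)

# asyncio.run(demo(interface))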
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
 
-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)
         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
 
-        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)
 
         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')
         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True
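This hunk makes the KV-cache reservation and the streaming queue track the per-request budget instead of the global default, so a request with a small max_tokens no longer reserves the full default budget. The sizing arithmetic, extracted into a sketch (plan_query is a hypothetical name; parameters follow the hunk above):

def plan_query(cache_lens: int, query_length: int, max_new_tokens: int) -> tuple[int, int]:
    # Reserve prompt + per-request decode budget, capped by total cache length.
    estimated_length = min(cache_lens, query_length + max_new_tokens)
    if estimated_length < query_length:
        # The prompt alone does not fit in the cache.
        raise ValueError(f"query too long: estimated_length={estimated_length} "
                         f"< query_length={query_length}")
    # The streaming queue never holds more than max_new_tokens items.
    return estimated_length, max_new_tokens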
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                 profiler.pause_timer("decode")
                 report_last_time_performance(profiler)
                 yield self.streamer.end(), None
-                if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                if profiler.get_counter('decode') >= max_new_tokens - 1:
                     yield "", "length"
                 else:
                     yield "", "stop"