Mirror of https://github.com/kvcache-ai/ktransformers.git
Synced 2025-09-09 13:55:27 +00:00
roll back ktransformers backend, add max_tokens, max_completion_tokens param
This commit is contained in:
parent a1162eea01
commit 03a65d6bea
10 changed files with 144 additions and 161 deletions
@@ -80,7 +80,8 @@ def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_
         query_updates[i].generated_token = generated_tokens[i].item()
         if not query_manager.query_map[query_updates[i].id].is_prefill:
             pos = query_updates[i].active_position
-            query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
+            if pos < query_manager.query_map[query_updates[i].id].max_length:
+                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]

 def report_last_time_performance(profiler: Profiler):
     try:
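The hunk above adds a bounds check so fill_generated_tokens never writes a generated token past the end of a query's preallocated token buffer; positions at or beyond max_length are now dropped instead of written. A minimal standalone sketch of that guard, using a hypothetical stand-in object rather than the real scheduler types:

class _Query:  # hypothetical stand-in; only the fields the guard touches
    def __init__(self, max_length: int):
        self.max_length = max_length
        self.query_tokens = [0] * max_length

def write_token(query: _Query, pos: int, token: int) -> None:
    # Same idea as the guarded write added in this commit: store the token
    # only while the active position is still inside the buffer.
    if pos < query.max_length:
        query.query_tokens[pos] = token

q = _Query(max_length=4)
write_token(q, 3, 42)   # stored in the last valid slot
write_token(q, 4, 99)   # dropped instead of raising IndexError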
@@ -314,19 +315,26 @@ class BalanceServeInterface(BackendInterfaceBase):

         start_event.wait()

-    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
         """Get sampling parameters and handle default values and edge cases"""
+        if max_tokens is not None:
+            max_completion_tokens = max_tokens
+        if max_completion_tokens is None:
+            max_completion_tokens = self.args.max_new_tokens
+        else:
+            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)
         if temperature is None:
-            temperature = Config().temperature
+            temperature = self.args.temperature
         if top_p is None:
-            top_p = Config().top_p
+            top_p = self.args.top_p

         if temperature == 0:
             temperature = 0.0001
         if top_p == 0:
             top_p = 0.0001

-        return temperature, top_p
+        return temperature, top_p, max_completion_tokens

     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
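In words, the new get_params resolves the completion cap as follows: an explicit max_tokens overrides max_completion_tokens; if neither is supplied, the server default args.max_new_tokens applies; otherwise the request value is clamped to args.max_new_tokens. The two parameter names match the OpenAI-style request fields of the same name. A minimal sketch of that precedence outside the class (the server limit is an illustrative assumption, not a value from the diff):

from typing import Optional

SERVER_MAX_NEW_TOKENS = 2048  # stand-in for self.args.max_new_tokens

def resolve_cap(max_tokens: Optional[int], max_completion_tokens: Optional[int]) -> int:
    # max_tokens wins over max_completion_tokens when both are given.
    if max_tokens is not None:
        max_completion_tokens = max_tokens
    # Fall back to the server default, otherwise clamp to it.
    if max_completion_tokens is None:
        return SERVER_MAX_NEW_TOKENS
    return min(SERVER_MAX_NEW_TOKENS, max_completion_tokens)

print(resolve_cap(None, None))    # 2048 (server default)
print(resolve_cap(None, 512))     # 512
print(resolve_cap(128, 512))      # 128  (max_tokens wins)
print(resolve_cap(999999, None))  # 2048 (clamped to the server limit)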
@@ -380,7 +388,8 @@ class BalanceServeInterface(BackendInterfaceBase):
         logger.debug(f"get input ids of shape {input_ids.shape}")
         return input_ids

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None,
+                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):
         profiler = Profiler()
         profiler.create_and_start_timer("tokenize")

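With the widened signature, callers of the inference async generator can pass a per-request completion cap alongside the sampling knobs. A hypothetical call site (the interface instance, messages, and thread id are placeholders; only the keyword names come from the diff):

async def consume(interface, messages):
    # The yields visible in this diff are (text, finish_reason) pairs;
    # finish_reason stays None until the stream ends with "stop" or "length".
    async for text, finish_reason in interface.inference(
            messages, thread_id="demo-thread",
            temperature=0.7, top_p=0.9, max_tokens=256):
        if finish_reason is None:
            print(text, end="", flush=True)
        else:
            print(f"\n[finish_reason: {finish_reason}]")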
@@ -409,17 +418,17 @@ class BalanceServeInterface(BackendInterfaceBase):
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria

-        temperature, top_p = self.get_sampling_params(temperature, top_p)
+        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)

         query_add.sample_options.temperature = temperature
         query_add.sample_options.top_p = top_p
-        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
+        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)

         if query_add.estimated_length < query_add.query_length:
             raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')

         query_id = self.sched_client.add_query(query_add)
-        queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
+        queue = asyncio.Queue(maxsize=max_new_tokens)
         self.queue_map[query_id] = queue
         self.thread_map[thread_id] = query_id
         is_first_token = True
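At the call site, scheduler bookkeeping now follows the per-request cap returned by get_params rather than the global args.max_new_tokens: both the estimated KV-cache length and the per-query output queue are sized from it. A small worked example with illustrative numbers (cache_lens and the prompt length are assumptions, not values from the diff):

import asyncio

cache_lens = 8192        # stand-in for self.args.cache_lens
query_length = 1500      # tokenized prompt length for this request
max_new_tokens = 256     # per-request cap returned by get_params

# Reserve room for prompt + completion, bounded by the total cache size.
estimated_length = min(cache_lens, query_length + max_new_tokens)
print(estimated_length)  # 1756

# The output queue is likewise sized to this request's own cap.
queue = asyncio.Queue(maxsize=max_new_tokens)
print(queue.maxsize)     # 256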
@@ -439,7 +448,7 @@ class BalanceServeInterface(BackendInterfaceBase):
                 profiler.pause_timer("decode")
                 report_last_time_performance(profiler)
                 yield self.streamer.end(), None
-                if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
+                if profiler.get_counter('decode') >= max_new_tokens - 1:
                     yield "", "length"
                 else:
                     yield "", "stop"