Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00.
[update] support openai chat completion api

commit 299c4dca64
parent 63b1c8525b

8 changed files with 166 additions and 83 deletions
@@ -333,7 +333,7 @@ class TransformersInterface(BackendInterfaceBase):
         logger.info(f"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}")
         if(self.max_new_tokens <= 0):
             logger.warning("max_new_tokens is less than 0")
-            yield self.streamer.end()
+            yield self.streamer.end(), "length"
             return
         logger.info(f"max_new_tokens: {self.max_new_tokens}")
         self.profiler.set_counter("decode", 0)
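With this commit, every value yielded along the streaming path becomes a (text, finish_reason) pair: finish_reason stays None while tokens stream, and the terminal pair carries "stop" (EOS or <|im_end|> seen) or "length" (token budget exhausted), matching the finish_reason values of the OpenAI chat completion API. A minimal consumer sketch of that protocol; the collect helper is hypothetical and not part of this commit:

# Hypothetical helper, not from this commit: fold the (text, finish_reason)
# pairs yielded by generate()/chat() into one reply plus the final reason.
def collect(chunks):
    reply, reason = [], None
    for text, finish_reason in chunks:
        if text:                       # skip empty sentinel chunks
            reply.append(text)
        if finish_reason is not None:  # "stop" on EOS, "length" on budget
            reason = finish_reason
    return "".join(reply), reason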
@@ -348,10 +348,17 @@ class TransformersInterface(BackendInterfaceBase):
             next_token = self.decode_one_tokens()
             self.profiler.inc("decode")
             if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
+                yield self.streamer.end(), None
+                yield "", "stop"
                 assert self.args.batch_size == 1
                 break
-            yield self.append_new_tokens(next_token)
-        yield self.streamer.end()
+            yield self.append_new_tokens(next_token), None
+
+        else: # for's else, if output get max new tokens
+            yield self.streamer.end(), None
+            yield "", "length"
+
+
 
     def check_is_new(self, thread_id: str):
         if not self.use_static_cache:
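The new else branch above is Python's for/else: the else suite runs only when the loop finishes without break, i.e. when max_new_tokens is reached before any EOS token, which is exactly the "length" finish reason. A tiny standalone illustration:

# for/else in isolation: the else suite runs only if the loop never breaks.
for token_index in range(3):
    if token_index == 99:  # stands in for "EOS seen"; never true here
        break
else:
    print("budget exhausted; finish_reason would be 'length'")  # this runs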
@@ -391,20 +398,20 @@ class TransformersInterface(BackendInterfaceBase):
         if Config().user_force_think:
             think = '<think>\n'
             print(think, end="",flush=True)
-            yield think
+            yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t in self.generate():
+        for t, finish_reason in self.generate():
             if t is not None:
                 print(t, end="",flush=True)
-                yield t
+                yield t, finish_reason
         print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
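Downstream, a server endpoint can turn these pairs into OpenAI-style streaming responses. A hedged sketch of that mapping; the to_openai_chunk helper, the placeholder id, and the default model name are illustrative assumptions, not code from this commit:

import json
import time

# Hypothetical mapping, not from this commit: wrap one (token, finish_reason)
# pair yielded by chat() as an OpenAI chat.completion.chunk payload.
def to_openai_chunk(token: str, finish_reason, model: str = "ktransformers"):
    return {
        "id": "chatcmpl-0",  # placeholder id
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {"content": token} if token else {},
            "finish_reason": finish_reason,  # None, "stop", or "length"
        }],
    }

if __name__ == "__main__":
    # One content chunk while streaming, then a terminal chunk with the reason.
    print(json.dumps(to_openai_chunk("Hello", None)))
    print(json.dumps(to_openai_chunk("", "stop")))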