Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-09 13:55:27 +00:00
Refactor the chat interface to support tool calling and parameter processing
Defined new data structures in chat.py to replace OpenAI's original implementation, adding support for tool calling, and implemented the logic for extracting and processing tool calls so that functions can be invoked dynamically during a conversation.

- chat.py: defined the new request/response data structures and extended the message model so that tool call information can be transmitted.
- balance_serve.py: added a method to retrieve sampling parameters, handling default values and edge cases.
- ktransformers.py and transformers.py: updated to support passing tool parameters through.
- config.py: changed the default value of top_p to 1.0 to increase generation diversity.

Together these changes make the system more flexible, enabling more complex interaction patterns.
parent 038db30ec9
commit 8699109129
6 changed files with 574 additions and 99 deletions
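The chat.py changes themselves are not part of the excerpt below, but the commit message describes an OpenAI-compatible message model extended with tool call information. A minimal sketch of what such structures can look like, assuming the usual OpenAI chat-completions field names (`tool_calls`, `function.name`, `function.arguments`); the actual class names and fields in chat.py may differ:

```
# Illustrative sketch only: field names follow the OpenAI chat-completions
# convention; the real definitions in chat.py may differ.
from typing import List, Optional
from pydantic import BaseModel


class FunctionCall(BaseModel):
    name: str
    arguments: str          # JSON-encoded argument string, as in the OpenAI API


class ToolCall(BaseModel):
    id: str
    type: str = "function"
    function: FunctionCall


class Message(BaseModel):
    role: str                                      # "system" | "user" | "assistant" | "tool"
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None    # set on assistant messages that invoke tools
    tool_call_id: Optional[str] = None             # set on "tool" messages carrying a tool result
```

Under this convention, an assistant turn that wants to call a function carries `tool_calls` instead of plain `content`, and the client sends the execution result back as a `role="tool"` message that references the matching `tool_call_id`.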
balance_serve.py:

@@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):

    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()

    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map:dict[int,asyncio.Queue] = {}
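`queue_map` keys an `asyncio.Queue` by query id, a common pattern for routing generated tokens from a background producer to each request's streaming consumer. A minimal sketch of that pattern, independent of ktransformers' actual plumbing (all names below are illustrative):

```
import asyncio

queue_map: dict[int, asyncio.Queue] = {}

async def producer(query_id: int, tokens: list) -> None:
    # Push generated tokens into the queue owned by this query.
    q = queue_map[query_id]
    for t in tokens:
        await q.put(t)
    await q.put(None)  # sentinel: generation finished

async def consumer(query_id: int) -> None:
    # Stream tokens back to the client as they arrive.
    q = queue_map[query_id]
    while (token := await q.get()) is not None:
        print(token, end="", flush=True)

async def main() -> None:
    queue_map[0] = asyncio.Queue()
    await asyncio.gather(producer(0, ["Hello", " ", "world"]), consumer(0))

asyncio.run(main())
```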
@@ -282,7 +283,21 @@ class BalanceServeInterface(BackendInterfaceBase):

        p.start()
        processes.append(p)
        start_event.wait()

+    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+        """Get sampling parameters and handle default values and edge cases"""
+        if temperature is None:
+            temperature = Config().temperature
+        if top_p is None:
+            top_p = Config().top_p
+
+        if temperature == 0:
+            temperature = 0.0001
+        if top_p == 0:
+            top_p = 0.0001
+
+        return temperature, top_p
+
    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
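The effect of the default and edge-case handling is easy to see in isolation. A standalone sketch mirroring the same logic, with hard-coded stand-ins for `Config().temperature` and `Config().top_p` (0.6 is assumed purely for illustration; 1.0 matches the new config.py default for top_p in this commit):

```
from typing import Optional

# Stand-in defaults; in balance_serve.py these come from Config().
DEFAULT_TEMPERATURE = 0.6   # assumed value, for illustration only
DEFAULT_TOP_P = 1.0         # matches the new config.py default in this commit

def get_sampling_params(temperature: Optional[float] = None,
                        top_p: Optional[float] = None) -> tuple:
    if temperature is None:
        temperature = DEFAULT_TEMPERATURE
    if top_p is None:
        top_p = DEFAULT_TOP_P
    # A value of exactly 0 can cause divide-by-zero or empty-nucleus edge
    # cases in samplers, so it is clamped to a tiny positive number.
    if temperature == 0:
        temperature = 0.0001
    if top_p == 0:
        top_p = 0.0001
    return temperature, top_p

print(get_sampling_params())           # (0.6, 1.0): both fall back to defaults
print(get_sampling_params(0, 0))       # (0.0001, 0.0001): zeros are clamped
print(get_sampling_params(0.8, None))  # (0.8, 1.0)
```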
@@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):

        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
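This branch accepts either a chat-style message list or a raw prompt string. Illustrative inputs for the two paths (the OpenAI-style message shape is assumed; a plain string skips the chat template entirely):

```
# Path 1: a list of chat messages goes through the chat template before tokenization.
local_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]

# Path 2: a plain string is tokenized directly as a prompt.
local_messages = "Q: What is the weather in Paris?\nA:"

# Anything else (e.g. a dict or None) raises:
#     ValueError("local_messages should be List or str")
```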
@@ -352,12 +366,9 @@ class BalanceServeInterface(BackendInterfaceBase):

                [input_ids, token_thinks], dim=1
            )

        profiler.pause_timer("tokenize")

        profiler.create_and_start_timer("prefill")

        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
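`profiler` here is ktransformers' own timing helper; its internals are not part of this diff. A minimal stand-in exposing the same two calls used above (`create_and_start_timer`, `pause_timer`), written only to illustrate the named-timer measurement pattern:

```
import time

class SimpleProfiler:
    """Illustrative stand-in, not the actual ktransformers profiler."""

    def __init__(self) -> None:
        self._start: dict = {}
        self._elapsed: dict = {}

    def create_and_start_timer(self, name: str) -> None:
        self._elapsed.setdefault(name, 0.0)
        self._start[name] = time.perf_counter()

    def pause_timer(self, name: str) -> None:
        self._elapsed[name] += time.perf_counter() - self._start.pop(name)

    def report(self) -> dict:
        return dict(self._elapsed)

profiler = SimpleProfiler()
profiler.create_and_start_timer("tokenize")
time.sleep(0.01)                      # stands in for tokenization work
profiler.pause_timer("tokenize")
profiler.create_and_start_timer("prefill")
time.sleep(0.02)                      # stands in for the prefill phase
profiler.pause_timer("prefill")
print(profiler.report())              # e.g. {'tokenize': 0.010..., 'prefill': 0.020...}
```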
@@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):

        #@TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
-        if temperature == 0:
-            temperature = 0.0001
+
+        temperature, top_p = self.get_sampling_params(temperature, top_p)
+
        query_add.sample_options.temperature = temperature
-        if top_p == 0:
-            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
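The stop criteria above are plain lists of token ids, one list per stop sequence, built with the model's own tokenizer. A sketch of the same construction with a Hugging Face tokenizer; the model id is a placeholder, chosen only because Qwen-style chat models use an `<|im_end|>` end-of-turn token:

```
from transformers import AutoTokenizer

# Placeholder model id; substitute whichever model is actually being served.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# Each stop criterion is the token-id sequence of one stop string.
stop_criteria = [
    tokenizer.encode(tokenizer.eos_token, add_special_tokens=False),
    tokenizer.encode("<|im_end|>"),
]
print(stop_criteria)  # a list of token-id lists, e.g. [[151645], [151645]]
```

In the same hunk, `query_add.estimated_length` is computed as `min(self.args.cache_lens, query_length + self.args.max_new_tokens)`, which caps the per-query length estimate at the configured cache length.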