Refactor the chat interface to support tool calling and parameter processing

Defined new data structures in chat.py to replace OpenAI's original implementation, adding support for tool calling.

Implemented logic for extracting and processing tool calls, enabling dynamic function invocation during conversations.

Added methods in balance_serve.py to retrieve sampling parameters, handling default values and edge cases.

Updated ktransformers.py and transformers.py to support the passing of tool parameters.

Modified the default value of top_p in config.py to 1.0 to increase generation diversity.

Extended the message model in chat.py to support the transmission of tool call information.

These changes enhance the system's flexibility and functionality, enabling more complex interaction patterns.
This commit is contained in:
sean.su 2025-04-14 15:23:37 +08:00
parent 038db30ec9
commit 8699109129
6 changed files with 574 additions and 99 deletions

View file

@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.queue_map:dict[int,asyncio.Queue] = {}
@ -282,7 +283,21 @@ class BalanceServeInterface(BackendInterfaceBase):
p.start()
processes.append(p)
start_event.wait()
def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
"""Get sampling parameters and handle default values and edge cases"""
if temperature is None:
temperature = Config().temperature
if top_p is None:
top_p = Config().top_p
if temperature == 0:
temperature = 0.0001
if top_p == 0:
top_p = 0.0001
return temperature, top_p
def run_queue_proxy(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError("local_messages should be List or str")
@ -352,12 +366,9 @@ class BalanceServeInterface(BackendInterfaceBase):
[input_ids, token_thinks], dim=1
)
profiler.pause_timer("tokenize")
profiler.create_and_start_timer("prefill")
query_add = sched_ext.QueryAdd()
query_add.query_token = input_ids[0].tolist()
@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):
#@TODO add server
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
query_add.stop_criteria = stop_criteria
if temperature == 0:
temperature = 0.0001
temperature, top_p = self.get_sampling_params(temperature, top_p)
query_add.sample_options.temperature = temperature
if top_p == 0:
top_p = 0.0001
query_add.sample_options.top_p = top_p
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)