mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 06:14:58 +00:00
Refactor the chat interface to support tool calling and parameter processing
Define new data structures in chat.py, replacing OpenAI's original implementation, to support tool calling. Implement logic for extracting and processing tool calls, enabling dynamic function invocation during conversations.

- chat.py: new request/response data structures with tool-calling support; extend the message model to carry tool-call information.
- transformers.py / ktransformers.py: accept and pass through tool parameters; detect and stream tool calls during generation.
- balance_serve.py: add methods to retrieve sampling parameters, handling default values and edge cases.
- config.py: change the default top_p to 1.0 to increase generation diversity.

These changes make the system more flexible, enabling more complex interaction patterns.
This commit is contained in:
parent 038db30ec9
commit 8699109129

6 changed files with 574 additions and 99 deletions
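The new chat.py data structures themselves are not part of this excerpt. As a rough sketch only, assuming OpenAI-compatible pydantic models (all field names here are illustrative assumptions, not taken from the diff), the extended message model the commit describes would look something like:

    from typing import List, Optional
    from pydantic import BaseModel

    class FunctionCall(BaseModel):
        name: str
        arguments: str  # JSON-encoded argument object, as in the OpenAI API

    class ToolCall(BaseModel):
        id: str
        type: str = "function"
        function: FunctionCall

    class Message(BaseModel):
        role: str
        content: Optional[str] = None
        tool_calls: Optional[List[ToolCall]] = None  # assistant messages that invoke tools
        tool_call_id: Optional[str] = None           # "tool" role messages returning a result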
@@ -1,4 +1,7 @@
 from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
 from transformers import (
     LlamaTokenizer,
     AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
+
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0

         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
             #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
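The balance_serve.py sampling-parameter methods are likewise outside this excerpt. A minimal sketch of the default/edge-case handling the commit message describes, assuming missing request values fall back to config defaults (the helper name and the temperature default are hypothetical; the top_p default of 1.0 matches the config.py change in this commit):

    from typing import Optional, Tuple

    def get_sampling_params(temperature: Optional[float],
                            top_p: Optional[float]) -> Tuple[float, float]:
        # Fall back to config defaults when the request omits a value.
        if temperature is None:
            temperature = 0.6  # hypothetical config default
        if top_p is None:
            top_p = 1.0        # new config.py default: no nucleus truncation
        # Edge cases: keep top_p in (0, 1] and temperature strictly positive
        # so softmax temperature scaling never divides by zero.
        top_p = min(max(top_p, 1e-5), 1.0)
        if temperature <= 0:
            temperature = 1e-5  # effectively greedy decoding
        return temperature, top_p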
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
         )
-
         self.profiler.pause_timer("tokenize")

         self.profiler.create_and_start_timer("prefill")

         if Config().user_force_think:
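The hunk below detects a tool call in the streamed text and recovers the arguments object by counting braces until the JSON is balanced. As a standalone illustration of that core idea (a hypothetical helper, not code from the commit):

    import re

    def extract_tool_call(text: str):
        # Find the function name the model emitted, e.g. '"name": "get_weather"'.
        name_match = re.search(r'"name":\s*"([^"]+)"', text)
        if not name_match:
            return None
        args_position = text.find('"arguments"')
        if args_position == -1:
            return None
        json_start = text.find('{', args_position)
        if json_start == -1:
            return None
        depth = 0
        for i in range(json_start, len(text)):
            if text[i] == '{':
                depth += 1
            elif text[i] == '}':
                depth -= 1
                if depth == 0:
                    # Braces balanced: the arguments object is complete.
                    return name_match.group(1), text[json_start:i + 1]
        return None  # arguments object still streaming; try again later

    # extract_tool_call('{"name": "get_weather", "arguments": {"city": "Beijing"}}')
    # -> ('get_weather', '{"city": "Beijing"}')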
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
             yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="", flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="", flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Start collecting tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check if we're starting a tool call
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract function name if we're collecting tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                    # Check if we've completed the arguments JSON
+                                    if brackets_count == 0:
+                                        # Send argument chunk
+                                        yield {
+                                            'tool_call': {
+                                                'id': tool_call_id,
+                                                'type': 'function',
+                                                'function': {
+                                                    'name': function_name,
+                                                    'arguments': collected_arguments
+                                                }
+                                            },
+                                            'argument_chunk': collected_arguments,
+                                            'last_chunk': True,
+                                            'prompt_tokens': 176,
+                                            'completion_tokens': 20
+                                        }
+                                        # Reset for next potential tool call
+                                        collected_tokens = ""
+                                        is_collecting_tool_call = False
+                                        is_function_name_collected = False
+                                        function_name = ""
+                                        collected_arguments = ""
+                                        brackets_count = 0
+                                        break
+
+                # Handle finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    yield t, finish_reason
+            print("")
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
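Note that inference now yields two shapes: (token, finish_reason) tuples for plain text and dicts carrying a tool_call payload. A caller therefore has to branch on the chunk type; a minimal consumption sketch (the handler functions are hypothetical stand-ins for the chat.py endpoint logic, which is part of this commit but not shown here):

    async def consume(interface, messages, thread_id, tools):
        async for chunk in interface.inference(messages, thread_id, tools=tools):
            if isinstance(chunk, dict) and 'tool_call' in chunk:
                # Streamed tool-call delta: forward id/name/arguments.
                handle_tool_call_delta(chunk['tool_call'])  # hypothetical
            else:
                token, finish_reason = chunk
                if token:
                    emit_text_delta(token)                  # hypothetical
                if finish_reason is not None:
                    emit_finish(finish_reason)              # hypothetical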