mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 13:55:27 +00:00

update function_call

This commit is contained in:
parent 038db30ec9
commit a7e8d7c1af

4 changed files with 554 additions and 89 deletions
@@ -1,4 +1,7 @@
 from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
 from transformers import (
     LlamaTokenizer,
     AutoTokenizer,
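The three new imports serve the tool-calling path added below: re extracts function names from generated text, uuid mints tool-call IDs, and json is not actually exercised in the hunks shown here. One small note on the ID scheme, as a minimal runnable sketch:

import uuid

# uuid4().hex is already 32 lowercase hex digits with no hyphens, so the
# .replace('-', '') applied to it in the last hunk below is a harmless no-op.
tool_call_id = f"call_{uuid.uuid4().hex}"
print(tool_call_id)  # e.g. call_6f1d0c0e9b2a4b7f8a3d5e6c7b8a9d0e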
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")

+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
+
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
             #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
             #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
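The diff never shows what the new tools argument contains; the hasattr(tools[0], 'function') check in the last hunk suggests object-style entries (for example, pydantic models mirroring the OpenAI tool schema), with a regex fallback when that shape is absent. A hypothetical caller-side sketch under that assumption; the get_weather tool and its fields are illustrative only:

from types import SimpleNamespace

# Object-style tool definition: tools[0].function.name must resolve for the
# fast path in the last hunk; a plain dict would fall through to the regex fallback.
get_weather = SimpleNamespace(
    type="function",
    function=SimpleNamespace(
        name="get_weather",                      # hypothetical tool name
        description="Look up current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    ),
)

tools = [get_weather]
has_tools = tools is not None and len(tools) > 0  # mirrors the check above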
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
         )

         self.profiler.pause_timer("tokenize")
-
         self.profiler.create_and_start_timer("prefill")

         if Config().user_force_think:
@@ -403,17 +407,118 @@
             yield think, None

         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")

         self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="",flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Start collecting tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check if we're starting a tool call
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract function name if we're collecting tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                    # Check if we've completed the arguments JSON
+                                    if brackets_count == 0:
+                                        # Send argument chunk
+                                        yield {
+                                            'tool_call': {
+                                                'id': tool_call_id,
+                                                'type': 'function',
+                                                'function': {
+                                                    'name': function_name,
+                                                    'arguments': collected_arguments
+                                                }
+                                            },
+                                            'argument_chunk': collected_arguments,
+                                            'last_chunk': True,
+                                            'prompt_tokens': 176,
+                                            'completion_tokens': 20
+                                        }
+                                        # Reset for next potential tool call
+                                        collected_tokens = ""
+                                        is_collecting_tool_call = False
+                                        is_function_name_collected = False
+                                        function_name = ""
+                                        collected_arguments = ""
+                                        brackets_count = 0
+                                        break
+
+                # Handle finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="",flush=True)
+                    yield t, finish_reason
+            print("")

         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
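Two things are worth noting in the argument-collection logic above: the inner for i in range(json_start, len(collected_tokens)) scan restarts from json_start on every new token, so collected_arguments can accumulate the same characters repeatedly until the braces balance, and the hardcoded 'prompt_tokens': 176 / 'completion_tokens': 20 read like placeholder values rather than real usage counts. A minimal sketch of an incremental brace-counting extractor that resumes where it left off; the ArgExtractor name and feed API are illustrative, not part of ktransformers:

class ArgExtractor:
    """Incrementally extract the JSON object that follows '"arguments"'.

    feed() accepts streamed text chunk by chunk and returns the complete
    arguments string once the braces balance, or None while still collecting.
    Like the loop above, it counts braces naively and would be confused by
    braces inside string literals.
    """

    def __init__(self):
        self.buffer = ""       # all streamed text seen so far
        self.scan_pos = None   # index of the opening '{', once located
        self.depth = 0         # current brace nesting depth
        self.args = ""         # characters of the arguments object

    def feed(self, chunk):
        self.buffer += chunk
        if self.scan_pos is None:
            key = self.buffer.find('"arguments"')
            if key == -1:
                return None
            start = self.buffer.find('{', key)
            if start == -1:
                return None
            self.scan_pos = start
        # Resume from scan_pos instead of rescanning from json_start each time.
        while self.scan_pos < len(self.buffer):
            char = self.buffer[self.scan_pos]
            self.scan_pos += 1
            self.args += char
            if char == '{':
                self.depth += 1
            elif char == '}':
                self.depth -= 1
                if self.depth == 0:
                    return self.args
        return None

extractor = ArgExtractor()
for piece in ['{"name": "get_weather", "arg', 'uments": {"city": "Be', 'ijing"}}']:
    result = extractor.feed(piece)
print(result)  # {"city": "Beijing"}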
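Because the tools branch yields plain dicts (flagged with first_chunk / last_chunk) while ordinary generation yields (token, finish_reason) tuples, any consumer of inference has to dispatch on chunk shape. A hypothetical consumer sketch; the consume wrapper is illustrative, not the actual ktransformers chat endpoint:

async def consume(interface, messages, thread_id, tools=None):
    # Drain the mixed stream produced by the inference() method above.
    text_parts, tool_calls = [], []
    async for chunk in interface.inference(messages, thread_id, tools=tools):
        if isinstance(chunk, dict) and 'tool_call' in chunk:
            # Tool-call chunks carry id/type/function plus streaming flags.
            tool_calls.append(chunk['tool_call'])
        else:
            token, finish_reason = chunk
            if token:
                text_parts.append(token)
    return "".join(text_parts), tool_calls

The wrapper is async because inference is an async generator; a real endpoint would presumably re-encode these chunks as OpenAI-style streaming deltas.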