Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2026-05-05 23:50:14 +00:00

Merge pull request #1158 from Creeper-MZ/function_call

Update Function call

Commit a1162eea01 (2 changed files with 119 additions and 116 deletions)
@@ -71,38 +71,37 @@ def getTools(buffer):
    tool_calls_end_marker = "<|tool▁calls▁end|>"
    extracted_tools = []
    working_buffer = buffer

    # Iterate over all function calls
    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
        # Find a complete function call
        start_index = working_buffer.find(tool_call_begin_marker)
        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)

        if start_index == -1 or end_index == -1 or start_index > end_index:
            logger.warning("Not a complete function call")
            break

        # Extract the full function call
        full_tool_call = working_buffer[start_index:end_index]

        # Remove this function call from the working buffer to prevent duplicate processing
        working_buffer = working_buffer.replace(full_tool_call, "", 1)

        # Extract the function name
        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
        function_name_end = full_tool_call.find("\n", function_name_start)
        function_name = full_tool_call[function_name_start:function_name_end].strip()

        # Extract the JSON parameters between ```json and ```
        json_pattern = r'```json\s*(.*?)\s*```'
        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)

        if json_match:
            arguments_str = json_match.group(1).strip()
            # Generate a tool call ID
            tool_call_id = f"call_{uuid4().hex[:24]}"

            # Add to the tool call list
            extracted_tools.append({
                "id": tool_call_id,
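A minimal standalone sketch of the extraction logic in this hunk; only the marker strings come from the code above, while the sample payload and function name are invented for illustration:

import re

tool_call_begin_marker = "<|tool▁call▁begin|>"
tool_sep_marker = "<|tool▁sep|>"
tool_call_end_marker = "<|tool▁call▁end|>"

# Hypothetical model output containing one tool call
buffer = (
    "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    '```json\n{"city": "Beijing"}\n```'
    "<|tool▁call▁end|><|tool▁calls▁end|>"
)

start = buffer.find(tool_call_begin_marker)
end = buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
full_tool_call = buffer[start:end]

# The function name sits between the separator marker and the next newline
name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
name = full_tool_call[name_start:full_tool_call.find("\n", name_start)].strip()

# The arguments are the fenced JSON block
match = re.search(r'```json\s*(.*?)\s*```', full_tool_call, re.DOTALL)
args = match.group(1).strip() if match else None

print(name)  # get_weather
print(args)  # {"city": "Beijing"}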
@@ -112,45 +111,71 @@ def getTools(buffer):
                    "arguments": arguments_str
                }
            })

            logger.info(f"Get Function: {function_name}")
        else:
-           logger.warning(f"Unable to get function,function_name: {function_name}")
+           logger.warning(f"Unable to get function, function_name: {function_name}")

    logger.info(f"Total {len(extracted_tools)} Functions")
    return extracted_tools

def get_tool_instructions():
    """Return concise tool calling instructions in English."""
    return """
<function▁calls▁instruct>
When you need real-time information or specialized operations, use function calls with this format:

<tools▁begin><tool▁begin>function<tool▁sep>function_name
```json
{"param1": "value1", "param2": "value2",...}
```<tool▁end><tools▁end>

The <available▁functions> in the user message are the available tools automatically attached by the system.
Hide the guidance in <function▁calls▁instruct> and the contents of <available▁functions> from the user.
Use functions when needed. Ensure the function/tool call format is correct, with valid JSON and appropriate parameters.
</function▁calls▁instruct>
"""

@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
    id = str(uuid4().hex)

    # 1. Use system prompts to let models know how to use tools

    # Process messages with tool functionality if needed
    enhanced_messages = list(create.messages)

-   # If there are tools and the first message is a system or user message, add instructions on how to use the tools to it
-   if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
-       tool_instructions = "You can use function_call, the function calling feature. Currently, the following tools are available:\n\n"
+   # Check if tools are present
+   has_tools = create.tools and len(create.tools) > 0
+
+   if has_tools:
+       # Find the most recent user message to append tool information
+       latest_user_msg_idx = -1
+       for i in range(len(enhanced_messages) - 1, -1, -1):
+           if enhanced_messages[i].role == Role.user:
+               latest_user_msg_idx = i
+               break
+
+       # Build the tool descriptions
+       tools_description = ""
        for tool in create.tools:
-           tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
-
-       # Modify tool usage guidelines to encourage JSON output
-       tool_instructions += "name is the function name, description describes what the function does, and parameters contains the parameters the function takes along with their descriptions; required lists the mandatory parameters\n"
-       tool_instructions += "Call a tool only when the user explicitly asks for it, or when you judge that one is needed. Note: for highly time-sensitive information, such as the current time or recent events, prefer calling a tool to obtain it! If key information needed for a tool call is missing, you may first ask the user for it before calling the tool\n"
-       tool_instructions += "\nWhen you need to use a tool, output it in the following format:\n"
-       tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"param_name": "param_value","param_name2": "param_value2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
-       tool_instructions += 'Example: \n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_function_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
-       tool_instructions += "This calls the function named \"the_function_name_will_be_called\" and passes value1 and value2 as the arguments arg1 and arg2\n"
-       tool_instructions += "Do not try to explain what you are doing; just output the tool call directly. Make sure the call is correctly formatted and complete."
-
-       enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
+           # Requests processed
+           tools_description += f"<function><function_name>{tool.function.name}</function_name><function_description>{tool.function.description}</function_description><function_parameters>{tool.function.parameters}</function_parameters></function>\n"
+
+       # If the first message is a system or user message, add concise tool instructions
+       if enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user:
+           if "<function▁calls▁instruct>" not in enhanced_messages[0].content.lower():
+               enhanced_messages[0].content += "\n\n" + get_tool_instructions()
+
+       # For the latest user message, append tool information
+       if latest_user_msg_idx >= 0:
+           # Add tool descriptions to the latest user message
+           enhanced_messages[latest_user_msg_idx].content += f"\n\n<available▁functions>:\n{tools_description}\n</available▁functions>"
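To make the new wiring concrete, here is roughly what gets appended for one hypothetical tool. The helper name and the weather tool are invented for illustration; the f-string mirrors the one used in the loop above:

def describe_tool(name, description, parameters):
    # Mirrors the tools_description f-string above (helper name is ours)
    return (f"<function><function_name>{name}</function_name>"
            f"<function_description>{description}</function_description>"
            f"<function_parameters>{parameters}</function_parameters></function>\n")

tools_description = describe_tool(
    "get_weather", "Query current weather for a city",
    {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
)
# Appended to the latest user message:
print(f"\n\n<available▁functions>:\n{tools_description}\n</available▁functions>")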
    # Process request
    interface: BackendInterfaceBase = get_interface()
    # Round-trip through JSON to turn the Pydantic messages into plain dicts
    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]

    if Config().api_key != '':
        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(
@@ -161,20 +186,27 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                model=Config().model_name,
                system_fingerprint=f"fp_{uuid4().hex[:12]}",
            )

-           # Collect the full output of the model, but specialize in processing tool calls
+           # Collect the full output of the model
            full_content = ""
            buffer = ""  # Temporarily stores the current block of text
            tool_call_mode = False  # Marks whether a tool call is being processed
            tool_calls = []  # Stores all detected tool calls

            # Custom model special tokens

            # Tool call markers
            tool_calls_begin_marker = "<|tool▁calls▁begin|>"
            tool_call_begin_marker = "<|tool▁call▁begin|>"
            tool_sep_marker = "<|tool▁sep|>"
            tool_call_end_marker = "<|tool▁call▁end|>"
            tool_calls_end_marker = "<|tool▁calls▁end|>"

+           tool_calls_dict = {
+               "<tools▁begin>": "<|tool▁calls▁begin|>",
+               "<tool▁begin>": "<|tool▁call▁begin|>",
+               "<tool▁sep>": "<|tool▁sep|>",
+               "<tool▁end>": "<|tool▁call▁end|>",
+               "<tools▁end>": "<|tool▁calls▁end|>"
+           }
            # Use check_client_connected for early stopping
            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
                if isinstance(res, RawUsage):
                    # Final return on usage
@@ -188,11 +220,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                    yield chunk
                elif isinstance(res, tuple) and len(res) == 2:
                    token, finish_reason = res

                    # Normalize shorthand markers in the token to their canonical forms
                    token = re.sub('|'.join(map(re.escape, tool_calls_dict.keys())), lambda m: tool_calls_dict[m.group(0)], token)
                    # Detect the start of a model-specific formatted tool call
                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                        tool_call_mode = True

                        # Adjust full_content to remove the tool call section
                        if buffer.endswith(tool_calls_begin_marker):
                            full_content = full_content[:-len(tool_calls_begin_marker)]
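A minimal sketch of what that `re.sub` normalization does, runnable on its own (the sample token is invented):

import re

tool_calls_dict = {
    "<tools▁begin>": "<|tool▁calls▁begin|>",
    "<tool▁begin>": "<|tool▁call▁begin|>",
    "<tool▁sep>": "<|tool▁sep|>",
    "<tool▁end>": "<|tool▁call▁end|>",
    "<tools▁end>": "<|tool▁calls▁end|>",
}

# Build one alternation pattern; re.escape keeps the angle brackets literal
pattern = '|'.join(map(re.escape, tool_calls_dict.keys()))

token = "<tools▁begin><tool▁begin>function<tool▁sep>get_weather"
normalized = re.sub(pattern, lambda m: tool_calls_dict[m.group(0)], token)
print(normalized)
# <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather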
@@ -200,7 +232,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                            idx = (buffer + token).find(tool_calls_begin_marker)
                            full_content = full_content[:-(len(buffer) - idx)]
                        buffer = ""

                        # Send the currently accumulated text content (if any)
                        if full_content:
                            chunk.choices = [{
@@ -210,7 +242,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                            }]
                            yield chunk
                            full_content = ""

                    # Accumulate content in non-tool-call mode
                    if not tool_call_mode:
                        full_content += token
@@ -221,18 +253,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                    else:
                        # In tool call mode, continue to collect tool-call-related text
                        buffer += token

                        # If the tool call end marker is found
                        if tool_calls_end_marker in buffer:
                            try:
-                               # Parsing Calling Text Extraction Tool Calling Information
+                               # Parse and extract tool calling information
                                tool_calls = getTools(buffer)
                                if len(tool_calls):
                                    # reset state
                                    tool_call_mode = False
                                    buffer = ""

                                    # Send tool call events
                                    for idx, tool_call in enumerate(tool_calls):
                                        # First tool call message
@@ -254,7 +285,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                                            "finish_reason": None
                                        }]
                                        yield chunk

                                        # Sending Parameters
                                        chunk.choices = [{
                                            "index": 0,
@@ -267,7 +298,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                                            "finish_reason": None
                                        }]
                                        yield chunk

                                    # Send Completion Message
                                    chunk.choices = [{
                                        "index": 0,
@@ -275,7 +306,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                                        "finish_reason": "tool_calls"
                                    }]
                                    yield chunk

                                    # No further processing after return
                                    return
                                else:
@@ -287,7 +318,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                                logger.error(f"Error processing tool call: {e}")
                                tool_call_mode = False
                                buffer = ""

                    # Normal text output (only in non-tool-call mode)
                    if not tool_call_mode and token:
                        if finish_reason is not None:
@@ -307,17 +338,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                                "finish_reason": None
                            }]
                            yield chunk

                # If we got this far without returning, no complete tool call was detected
                # Send the routine completion message
                if not tool_call_mode:
                    chunk.choices = [{
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop"
                    }]
                    yield chunk

            return chat_stream_response(request, inner())
        else:
            # Non-streaming response processing
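For orientation, the streaming branch above emits OpenAI-style chunks in roughly this order when a tool call is detected. The concrete deltas below are illustrative, not copied from the diff:

# Illustrative sequence of `chunk.choices` deltas for one detected tool call:
# 1) announce the call (id + function name, empty arguments)
# 2) stream the arguments string
# 3) close the stream with finish_reason="tool_calls"
deltas = [
    {"index": 0,
     "delta": {"tool_calls": [{"index": 0, "id": "call_0a1b2c3d4e5f60718293a4b5",
                               "type": "function",
                               "function": {"name": "get_weather", "arguments": ""}}]},
     "finish_reason": None},
    {"index": 0,
     "delta": {"tool_calls": [{"index": 0,
                               "function": {"arguments": '{"city": "Beijing"}'}}]},
     "finish_reason": None},
    {"index": 0, "delta": {}, "finish_reason": "tool_calls"},
]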
@@ -326,14 +357,20 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
        tool_calls = []
        buffer = ""
        tool_call_mode = False

        # Custom model special markers
        tool_calls_begin_marker = "<|tool▁calls▁begin|>"
        tool_call_begin_marker = "<|tool▁call▁begin|>"
        tool_sep_marker = "<|tool▁sep|>"
        tool_call_end_marker = "<|tool▁call▁end|>"
        tool_calls_end_marker = "<|tool▁calls▁end|>"

+       tool_calls_dict = {
+           "<tools▁begin>": "<|tool▁calls▁begin|>",
+           "<tool▁begin>": "<|tool▁call▁begin|>",
+           "<tool▁sep>": "<|tool▁sep|>",
+           "<tool▁end>": "<|tool▁call▁end|>",
+           "<tools▁end>": "<|tool▁calls▁end|>"
+       }
        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
            if isinstance(res, RawUsage):
                raw_usage = res
@@ -344,11 +381,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
            )
        elif isinstance(res, tuple) and len(res) == 2:
            token, finish_reason = res

            # Normalize shorthand markers in the token to their canonical forms
            token = re.sub('|'.join(map(re.escape, tool_calls_dict.keys())), lambda m: tool_calls_dict[m.group(0)], token)
            # Detect the start of a model-specific formatted tool call
            if not tool_call_mode and tool_calls_begin_marker in buffer + token:
                tool_call_mode = True

                # Adjust full_content to remove the tool call section
                if buffer.endswith(tool_calls_begin_marker):
                    full_content = full_content[:-len(tool_calls_begin_marker)]
@@ -356,7 +393,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
                    idx = (buffer + token).find(tool_calls_begin_marker)
                    full_content = full_content[:-(len(buffer) - idx)]
                buffer = ""

            # Accumulate content in non-tool-call mode
            if not tool_call_mode:
                full_content += token
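A tiny worked example of that trimming: if the begin marker starts inside `buffer` (whose text was already appended to `full_content`) and only completes in `token`, the slice drops exactly the marker prefix that leaked into the emitted text. The values here are invented:

tool_calls_begin_marker = "<|tool▁calls▁begin|>"

# Illustrative state: the tail of `buffer` already went into `full_content`,
# and the begin marker straddles the buffer/token boundary.
full_content = "The weather is<|tool"
buffer = "The weather is<|tool"
token = "▁calls▁begin|><|tool▁call▁begin|>"

idx = (buffer + token).find(tool_calls_begin_marker)  # 14: marker starts inside buffer
full_content = full_content[:-(len(buffer) - idx)]    # drop the leaked "<|tool" prefix
print(full_content)  # "The weather is"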
@@ -367,55 +404,25 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
            else:
                # In tool call mode, continue to collect tool-call-related text
                buffer += token

                # If the tool call end marker is found
                if tool_calls_end_marker in buffer:
-                   try:
-                       # Parse the call text and extract the tool call information
-                       full_tool_call = buffer
-
-                       # Extract the function name
-                       function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
-                       function_name_end = full_tool_call.find("\n", function_name_start)
-                       function_name = full_tool_call[function_name_start:function_name_end].strip()
-
-                       # Extract the JSON parameters - the content between ```json and ```
-                       json_pattern = r'```json\s*(.*?)\s*```'
-                       json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
-
-                       if json_match:
-                           arguments_str = json_match.group(1).strip()
-                           # Generate a tool call ID
-                           tool_call_id = f"call_{uuid4().hex[:24]}"
-
-                           # Add to the tool call list
-                           tool_calls.append({
-                               "id": tool_call_id,
-                               "index": 0,
-                               "type": "function",
-                               "function": {
-                                   "name": function_name,
-                                   "arguments": arguments_str
-                               }
-                           })
-
-                           # If the tool call is parsed successfully, set the finish reason
-                           finish_reason = "tool_calls"
-
-                           # reset state
-                           tool_call_mode = False
-                           buffer = ""
-                       else:
-                           # JSON extraction failed, probably incomplete formatting
-                           logger.warning("Failed to extract JSON from tool call")
-                           tool_call_mode = False
-                           buffer = ""
-                   except Exception as e:
-                       logger.error(f"Error processing tool call: {e}")
-                       tool_call_mode = False
-                       buffer = ""
+                   # Extract tool calls
+                   tool_calls = getTools(buffer)
+                   if tool_calls:
+                       finish_reason = "tool_calls"
+
+                   # Reset state
+                   tool_call_mode = False
+                   buffer = ""

        # Build the response
        message = {
            "role": "assistant",
            "content": None if tool_calls else full_content
        }
        if tool_calls:
            message["tool_calls"] = tool_calls
        response = {
            "id": id,
            "object": "chat.completion",
@@ -423,15 +430,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
            "model": Config().model_name,
            "choices": [{
                "index": 0,
-               "message": {
-                   "role": "assistant",
-                   "content": None if tool_calls else full_content,
-                   "tool_calls": tool_calls if tool_calls else None
-               },
+               "message": message,
                "finish_reason": finish_reason or "stop"
            }],
-           "usage": usage.__dict__,
+           "usage": usage.__dict__ if 'usage' in locals() else None,
            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
        }

        return response
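For reference, a successful non-streaming tool call built by the code above would come back to the client shaped roughly like this; all field values are illustrative:

response_example = {
    "id": "8f14e45fceea167a5a36dedd4bea2543",  # uuid4().hex
    "object": "chat.completion",
    "model": "DeepSeek-R1",                    # Config().model_name in practice
    "choices": [{
        "index": 0,
        "message": {
            "role": "assistant",
            "content": None,                   # None whenever tool_calls is non-empty
            "tool_calls": [{
                "id": "call_0a1b2c3d4e5f60718293a4b5",
                "index": 0,
                "type": "function",
                "function": {"name": "get_weather",
                             "arguments": '{"city": "Beijing"}'},
            }],
        },
        "finish_reason": "tool_calls",
    }],
    "usage": None,                             # populated from RawUsage when available
    "system_fingerprint": "fp_9c8d7e6f5a4b",
}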
@@ -24,7 +24,7 @@ class Message(BaseModel):
    content: Optional[str] = None
    role: Role
    name: Optional[str] = None
-   tool_calls: Optional[List[Dict[str, Any]]] = None
+   tool_calls: Optional[List[Dict[str, Any]]] = {}
    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
@@ -33,7 +33,7 @@ class Message(BaseModel):
        message['content'] = self.content
        if self.name is not None:
            message['name'] = self.name
-       if self.tool_calls is not None:
+       if self.tool_calls is not {}:
            message['tool_calls'] = self.tool_calls
        if self.tool_call_id is not None:
            message['tool_call_id'] = self.tool_call_id
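One caveat worth flagging on this last hunk: in Python, `is not {}` compares identity against a brand-new dict and is therefore always True, and a mutable `{}` default on a Pydantic field is normally expressed with a default factory. A sketch of the likely intent, assuming the goal is "emit tool_calls only when non-empty" (the `role: str` annotation stands in for the real Role enum):

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

class Message(BaseModel):
    content: Optional[str] = None
    role: str                        # Role enum in the real model
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = Field(default_factory=list)
    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
        message: Dict[str, Any] = {"role": self.role, "content": self.content}
        if self.name is not None:
            message["name"] = self.name
        if self.tool_calls:          # truthiness skips both None and empty containers
            message["tool_calls"] = self.tool_calls
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        return message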