更改token注入逻辑,减少token注入量,防止遗忘 (Rework the token-injection logic: inject fewer prompt tokens so the model does not forget earlier context)

Update chat.py

Update chat.py

Update chat.py
This commit is contained in:
Creeper-MZ 2025-04-16 14:55:30 -04:00
parent a7e8d7c1af
commit 88f688e2c8

View file

@ -71,38 +71,37 @@ def getTools(buffer):
tool_calls_end_marker = "<tool▁calls▁end>"
extracted_tools = []
working_buffer = buffer
# Iterate over all function calls
while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
# Find a complete function call
start_index = working_buffer.find(tool_call_begin_marker)
end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
if start_index == -1 or end_index == -1 or start_index > end_index:
logger.warning("Not a function")
break
# Extract the full function call
full_tool_call = working_buffer[start_index:end_index]
# Remove this function call from the working buffer to prevent duplicate processing
working_buffer = working_buffer.replace(full_tool_call, "", 1)
# Extract the function name
function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
function_name_end = full_tool_call.find("\n", function_name_start)
function_name = full_tool_call[function_name_start:function_name_end].strip()
# Extract JSON parameters
json_pattern = r'```json\s*(.*?)\s*```'
json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
if json_match:
arguments_str = json_match.group(1).strip()
# Generate tool call IDs
tool_call_id = f"call_{uuid4().hex[:24]}"
# Add to tool call list
extracted_tools.append({
"id": tool_call_id,
@ -112,45 +111,65 @@ def getTools(buffer):
"arguments": arguments_str
}
})
logger.info(f"Get Function: {function_name}")
else:
logger.warning(f"Unable to get functionfunction_name: {function_name}")
logger.warning(f"Unable to get function, function_name: {function_name}")
logger.info(f"Total {len(extracted_tools)} Functions")
return extracted_tools
def get_tool_instructions():
    """Return concise English instructions teaching the model the tool-call format.

    Returns:
        str: A prompt fragment describing when and how to emit a function call.

    Note:
        The marker strings here MUST match the ones the response parser scans
        for (``<tool▁calls▁begin>``, ``<tool▁call▁begin>``, ``<tool▁sep>``,
        ``<tool▁call▁end>``, ``<tool▁calls▁end>`` — with the U+2581 separator
        characters). If the instructions show markers without the separators,
        the model will emit calls that the parser never detects.
    """
    return """When you need real-time information or specialized operations, use function calls with this format:
<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>function_name
```json
{"param1": "value1", "param2": "value2"}
```<tool▁call▁end><tool▁calls▁end>
Only use functions when needed. Ensure proper JSON formatting with appropriate parameters."""
@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
id = str(uuid4().hex)
# 1. Use system prompts to let models know how to use tools
# Process messages with tool functionality if needed
enhanced_messages = list(create.messages)
# If there is a tool and the first message is system, add instructions on how to use the tool in the system tip
if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
tool_instructions = "你可以使用function_call函数调用功能目前你可以使用以下工具\n\n"
# Check if tools are present
has_tools = create.tools and len(create.tools) > 0
if has_tools:
# Find the most recent user message to append tool information
latest_user_msg_idx = -1
for i in range(len(enhanced_messages) - 1, -1, -1):
if enhanced_messages[i].role == Role.user:
latest_user_msg_idx = i
break
# Build the tool descriptions
tools_description = ""
for tool in create.tools:
tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
# Modify tool usage guidelines to encourage JSON output
tool_instructions += "name为函数名称description为函数功能的描述parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n"
tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n"
tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n"
tool_instructions += '<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<tool▁call▁end><tool▁calls▁end>\n'
tool_instructions += '示例: \n<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>the_functnion_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<tool▁call▁end><tool▁calls▁end>\n'
tool_instructions += "这样可以调用名为\"the_functnion_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n"
tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。"
enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
# Requests processed
tools_description += f"Function: {tool.function.name}\nDescription: {tool.function.description}\nParameters: {tool.function.parameters}\n\n"
# If first message is system, add concise tool instructions
if enhanced_messages[0].role == Role.system:
if "function calls" not in enhanced_messages[0].content.lower():
enhanced_messages[0].content += "\n\n" + get_tool_instructions()
# For the latest user message, append tool information
if latest_user_msg_idx >= 0:
# Add tool descriptions to the latest user message
enhanced_messages[latest_user_msg_idx].content += f"\n\nAvailable tools:\n{tools_description}"
# Process request
interface: BackendInterfaceBase = get_interface()
input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]
if Config().api_key != '':
assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
if create.stream:
async def inner():
chunk = ChatCompletionChunk(
@ -161,20 +180,21 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
model=Config().model_name,
system_fingerprint=f"fp_{uuid4().hex[:12]}",
)
# Collect the full output of the model, but specialize in processing tool calls
# Collect the full output of the model
full_content = ""
buffer = "" # Used to temporarily store the current block of text
tool_call_mode = False # Mark if a tool call is being processed
tool_calls = [] # Store all detected tool calls
# Customize model special tokens
# Tool call markers
tool_calls_begin_marker = "<tool▁calls▁begin>"
tool_call_begin_marker = "<tool▁call▁begin>"
tool_sep_marker = "<tool▁sep>"
tool_call_end_marker = "<tool▁call▁end>"
tool_calls_end_marker = "<tool▁calls▁end>"
# Use check_client_connected for early stopping
async for res in interface.inference(input_message, id, create.temperature, create.top_p):
if isinstance(res, RawUsage):
# Final return on utilization
@ -188,11 +208,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
yield chunk
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
# Detecting model-specific formatting tool call starts
if not tool_call_mode and tool_calls_begin_marker in buffer + token:
tool_call_mode = True
# Adjust full_content to remove tool call section
if buffer.endswith(tool_calls_begin_marker):
full_content = full_content[:-len(tool_calls_begin_marker)]
@ -200,7 +220,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
idx = (buffer + token).find(tool_calls_begin_marker)
full_content = full_content[:-(len(buffer) - idx)]
buffer = ""
# Send the current cumulative text content (if any)
if full_content:
chunk.choices = [{
@ -210,7 +230,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
}]
yield chunk
full_content = ""
# Accumulation of content in non-tool call mode
if not tool_call_mode:
full_content += token
@ -221,18 +241,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
else:
# In tool call mode, continue to collect tool call related text
buffer += token
# If the tool call end marker is found
if tool_calls_end_marker in buffer:
try:
# Parsing Calling Text Extraction Tool Calling Information
# Parse and extract tool calling information
tool_calls = getTools(buffer)
if len(tool_calls):
# reset state
tool_call_mode = False
buffer = ""
# Send tool call events
for idx, tool_call in enumerate(tool_calls):
# First tool call message
@ -254,7 +273,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
"finish_reason": None
}]
yield chunk
# Sending Parameters
chunk.choices = [{
"index": 0,
@ -267,7 +286,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
"finish_reason": None
}]
yield chunk
# Send Completion Message
chunk.choices = [{
"index": 0,
@ -275,7 +294,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
"finish_reason": "tool_calls"
}]
yield chunk
# No further processing after return
return
else:
@ -287,7 +306,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
logger.error(f"Error processing tool call: {e}")
tool_call_mode = False
buffer = ""
# Normal text output (only in non-tool call mode)
if not tool_call_mode and token:
if finish_reason is not None:
@ -307,17 +326,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
"finish_reason": None
}]
yield chunk
# If gotten this far without returning, it means that the full tool call was not detected
# Send Routine Completion Message
if not tool_call_mode:
chunk.choices = [{
"index": 0,
"delta": {},
"index": 0,
"delta": {},
"finish_reason": "stop"
}]
yield chunk
return chat_stream_response(request, inner())
else:
# non streaming response processing
@ -326,14 +345,14 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
tool_calls = []
buffer = ""
tool_call_mode = False
# Custom model special markers
tool_calls_begin_marker = "<tool▁calls▁begin>"
tool_call_begin_marker = "<tool▁call▁begin>"
tool_sep_marker = "<tool▁sep>"
tool_call_end_marker = "<tool▁call▁end>"
tool_calls_end_marker = "<tool▁calls▁end>"
async for res in interface.inference(input_message, id, create.temperature, create.top_p):
if isinstance(res, RawUsage):
raw_usage = res
@ -344,11 +363,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
)
elif isinstance(res, tuple) and len(res) == 2:
token, finish_reason = res
# Detecting the start of model-specific formatting tool calls
if not tool_call_mode and tool_calls_begin_marker in buffer + token:
tool_call_mode = True
# Adjust full_content to remove tool call section
if buffer.endswith(tool_calls_begin_marker):
full_content = full_content[:-len(tool_calls_begin_marker)]
@ -356,7 +375,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
idx = (buffer + token).find(tool_calls_begin_marker)
full_content = full_content[:-(len(buffer) - idx)]
buffer = ""
# Accumulation of content in non-tool call mode
if not tool_call_mode:
full_content += token
@ -367,54 +386,18 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
else:
# In tool call mode, continue to collect tool call related text
buffer += token
# If the tool call end marker is found
if tool_calls_end_marker in buffer:
try:
# Parsing Calling Text Extraction Tool Calling Information
full_tool_call = buffer
# Extract function name
function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
function_name_end = full_tool_call.find("\n", function_name_start)
function_name = full_tool_call[function_name_start:function_name_end].strip()
# Extract JSON Parameters - Extracts the content between ```json and ```.
json_pattern = r'```json\s*(.*?)\s*```'
json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
if json_match:
arguments_str = json_match.group(1).strip()
# Generate tool call IDs
tool_call_id = f"call_{uuid4().hex[:24]}"
# Add to tool call list
tool_calls.append({
"id": tool_call_id,
"index": 0,
"type": "function",
"function": {
"name": function_name,
"arguments": arguments_str
}
})
# If the tool call is successfully parsed, set the reason for completion
finish_reason = "tool_calls"
# reset state
tool_call_mode = False
buffer = ""
else:
# JSON extraction failed, probably incomplete formatting
logger.warning("Failed to extract JSON from tool call")
tool_call_mode = False
buffer = ""
except Exception as e:
logger.error(f"Error processing tool call: {e}")
tool_call_mode = False
buffer = ""
# Extract tool calls
tool_calls = getTools(buffer)
if tool_calls:
finish_reason = "tool_calls"
# Reset state
tool_call_mode = False
buffer = ""
# Build Response
response = {
"id": id,
@ -430,8 +413,8 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
},
"finish_reason": finish_reason or "stop"
}],
"usage": usage.__dict__,
"usage": usage.__dict__ if 'usage' in locals() else None,
"system_fingerprint": f"fp_{uuid4().hex[:12]}"
}
return response
return response