mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 05:54:06 +00:00
更改token注入逻辑,减少token注入量,防止遗忘
Update chat.py Update chat.py Update chat.py
This commit is contained in:
parent
a7e8d7c1af
commit
88f688e2c8
1 changed files with 92 additions and 109 deletions
|
@ -72,7 +72,6 @@ def getTools(buffer):
|
||||||
extracted_tools = []
|
extracted_tools = []
|
||||||
working_buffer = buffer
|
working_buffer = buffer
|
||||||
|
|
||||||
|
|
||||||
# Iterate over all function calls
|
# Iterate over all function calls
|
||||||
while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
|
while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
|
||||||
# Find a complete function call
|
# Find a complete function call
|
||||||
|
@ -115,36 +114,56 @@ def getTools(buffer):
|
||||||
|
|
||||||
logger.info(f"Get Function: {function_name}")
|
logger.info(f"Get Function: {function_name}")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unable to get function,function_name: {function_name}")
|
logger.warning(f"Unable to get function, function_name: {function_name}")
|
||||||
|
|
||||||
logger.info(f"Total {len(extracted_tools)} Functions")
|
logger.info(f"Total {len(extracted_tools)} Functions")
|
||||||
return extracted_tools
|
return extracted_tools
|
||||||
|
|
||||||
|
def get_tool_instructions():
|
||||||
|
"""Return concise tool calling instructions in English"""
|
||||||
|
return """When you need real-time information or specialized operations, use function calls with this format:
|
||||||
|
|
||||||
|
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>function_name
|
||||||
|
```json
|
||||||
|
{"param1": "value1", "param2": "value2"}
|
||||||
|
```<|tool▁call▁end|><|tool▁calls▁end|>
|
||||||
|
|
||||||
|
Only use functions when needed. Ensure proper JSON formatting with appropriate parameters."""
|
||||||
|
|
||||||
@router.post('/chat/completions', tags=['openai'])
|
@router.post('/chat/completions', tags=['openai'])
|
||||||
async def chat_completion(request: Request, create: ChatCompletionCreate):
|
async def chat_completion(request: Request, create: ChatCompletionCreate):
|
||||||
id = str(uuid4().hex)
|
id = str(uuid4().hex)
|
||||||
|
|
||||||
# 1. Use system prompts to let models know how to use tools
|
# Process messages with tool functionality if needed
|
||||||
enhanced_messages = list(create.messages)
|
enhanced_messages = list(create.messages)
|
||||||
|
|
||||||
# If there is a tool and the first message is system, add instructions on how to use the tool in the system tip
|
# Check if tools are present
|
||||||
if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
|
has_tools = create.tools and len(create.tools) > 0
|
||||||
tool_instructions = "你可以使用function_call,函数调用功能,目前,你可以使用以下工具\n\n"
|
|
||||||
|
if has_tools:
|
||||||
|
# Find the most recent user message to append tool information
|
||||||
|
latest_user_msg_idx = -1
|
||||||
|
for i in range(len(enhanced_messages) - 1, -1, -1):
|
||||||
|
if enhanced_messages[i].role == Role.user:
|
||||||
|
latest_user_msg_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
# Build the tool descriptions
|
||||||
|
tools_description = ""
|
||||||
for tool in create.tools:
|
for tool in create.tools:
|
||||||
tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
|
tools_description += f"Function: {tool.function.name}\nDescription: {tool.function.description}\nParameters: {tool.function.parameters}\n\n"
|
||||||
|
|
||||||
# Modify tool usage guidelines to encourage JSON output
|
# If first message is system, add concise tool instructions
|
||||||
tool_instructions += "name为函数名称,description为函数功能的描述,parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n"
|
if enhanced_messages[0].role == Role.system:
|
||||||
tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n"
|
if "function calls" not in enhanced_messages[0].content.lower():
|
||||||
tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n"
|
enhanced_messages[0].content += "\n\n" + get_tool_instructions()
|
||||||
tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
|
|
||||||
tool_instructions += '示例: \n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_functnion_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
|
|
||||||
tool_instructions += "这样可以调用名为\"the_functnion_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n"
|
|
||||||
tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。"
|
|
||||||
|
|
||||||
enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
|
# For the latest user message, append tool information
|
||||||
|
if latest_user_msg_idx >= 0:
|
||||||
|
# Add tool descriptions to the latest user message
|
||||||
|
enhanced_messages[latest_user_msg_idx].content += f"\n\nAvailable tools:\n{tools_description}"
|
||||||
|
|
||||||
# Requests processed
|
# Process request
|
||||||
interface: BackendInterfaceBase = get_interface()
|
interface: BackendInterfaceBase = get_interface()
|
||||||
input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]
|
input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]
|
||||||
|
|
||||||
|
@ -162,19 +181,20 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
|
||||||
system_fingerprint=f"fp_{uuid4().hex[:12]}",
|
system_fingerprint=f"fp_{uuid4().hex[:12]}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collect the full output of the model, but specialize in processing tool calls
|
# Collect the full output of the model
|
||||||
full_content = ""
|
full_content = ""
|
||||||
buffer = "" # Used to temporarily store the current block of text
|
buffer = "" # Used to temporarily store the current block of text
|
||||||
tool_call_mode = False # Mark if a tool call is being processed
|
tool_call_mode = False # Mark if a tool call is being processed
|
||||||
tool_calls = [] # Store all detected tool calls
|
tool_calls = [] # Store all detected tool calls
|
||||||
|
|
||||||
# Customize model special tokens
|
# Tool call markers
|
||||||
tool_calls_begin_marker = "<|tool▁calls▁begin|>"
|
tool_calls_begin_marker = "<|tool▁calls▁begin|>"
|
||||||
tool_call_begin_marker = "<|tool▁call▁begin|>"
|
tool_call_begin_marker = "<|tool▁call▁begin|>"
|
||||||
tool_sep_marker = "<|tool▁sep|>"
|
tool_sep_marker = "<|tool▁sep|>"
|
||||||
tool_call_end_marker = "<|tool▁call▁end|>"
|
tool_call_end_marker = "<|tool▁call▁end|>"
|
||||||
tool_calls_end_marker = "<|tool▁calls▁end|>"
|
tool_calls_end_marker = "<|tool▁calls▁end|>"
|
||||||
|
|
||||||
|
# Use check_client_connected for early stopping
|
||||||
async for res in interface.inference(input_message, id, create.temperature, create.top_p):
|
async for res in interface.inference(input_message, id, create.temperature, create.top_p):
|
||||||
if isinstance(res, RawUsage):
|
if isinstance(res, RawUsage):
|
||||||
# Final return on utilization
|
# Final return on utilization
|
||||||
|
@ -225,8 +245,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
|
||||||
# If the tool call end marker is found
|
# If the tool call end marker is found
|
||||||
if tool_calls_end_marker in buffer:
|
if tool_calls_end_marker in buffer:
|
||||||
try:
|
try:
|
||||||
# Parsing Calling Text Extraction Tool Calling Information
|
# Parse and extract tool calling information
|
||||||
|
|
||||||
tool_calls = getTools(buffer)
|
tool_calls = getTools(buffer)
|
||||||
if len(tool_calls):
|
if len(tool_calls):
|
||||||
# reset state
|
# reset state
|
||||||
|
@ -370,48 +389,12 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
|
||||||
|
|
||||||
# If the tool call end marker is found
|
# If the tool call end marker is found
|
||||||
if tool_calls_end_marker in buffer:
|
if tool_calls_end_marker in buffer:
|
||||||
try:
|
# Extract tool calls
|
||||||
# Parsing Calling Text Extraction Tool Calling Information
|
tool_calls = getTools(buffer)
|
||||||
full_tool_call = buffer
|
if tool_calls:
|
||||||
|
|
||||||
# Extract function name
|
|
||||||
function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
|
|
||||||
function_name_end = full_tool_call.find("\n", function_name_start)
|
|
||||||
function_name = full_tool_call[function_name_start:function_name_end].strip()
|
|
||||||
|
|
||||||
# Extract JSON Parameters - Extracts the content between ```json and ```.
|
|
||||||
json_pattern = r'```json\s*(.*?)\s*```'
|
|
||||||
json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
|
|
||||||
|
|
||||||
if json_match:
|
|
||||||
arguments_str = json_match.group(1).strip()
|
|
||||||
# Generate tool call IDs
|
|
||||||
tool_call_id = f"call_{uuid4().hex[:24]}"
|
|
||||||
|
|
||||||
# Add to tool call list
|
|
||||||
tool_calls.append({
|
|
||||||
"id": tool_call_id,
|
|
||||||
"index": 0,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": function_name,
|
|
||||||
"arguments": arguments_str
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# If the tool call is successfully parsed, set the reason for completion
|
|
||||||
finish_reason = "tool_calls"
|
finish_reason = "tool_calls"
|
||||||
|
|
||||||
# reset state
|
# Reset state
|
||||||
tool_call_mode = False
|
|
||||||
buffer = ""
|
|
||||||
else:
|
|
||||||
# JSON extraction failed, probably incomplete formatting
|
|
||||||
logger.warning("Failed to extract JSON from tool call")
|
|
||||||
tool_call_mode = False
|
|
||||||
buffer = ""
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error processing tool call: {e}")
|
|
||||||
tool_call_mode = False
|
tool_call_mode = False
|
||||||
buffer = ""
|
buffer = ""
|
||||||
|
|
||||||
|
@ -430,7 +413,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
|
||||||
},
|
},
|
||||||
"finish_reason": finish_reason or "stop"
|
"finish_reason": finish_reason or "stop"
|
||||||
}],
|
}],
|
||||||
"usage": usage.__dict__,
|
"usage": usage.__dict__ if 'usage' in locals() else None,
|
||||||
"system_fingerprint": f"fp_{uuid4().hex[:12]}"
|
"system_fingerprint": f"fp_{uuid4().hex[:12]}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue