From 88f688e2c81f17a00128bd55ff4fc9536bbbcde9 Mon Sep 17 00:00:00 2001 From: Creeper-MZ Date: Wed, 16 Apr 2025 14:55:30 -0400 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9token=E6=B3=A8=E5=85=A5?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E5=87=8F=E5=B0=91token=E6=B3=A8?= =?UTF-8?q?=E5=85=A5=E9=87=8F=EF=BC=8C=E9=98=B2=E6=AD=A2=E9=81=97=E5=BF=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update chat.py Update chat.py Update chat.py --- .../server/api/openai/endpoints/chat.py | 201 ++++++++---------- 1 file changed, 92 insertions(+), 109 deletions(-) diff --git a/ktransformers/server/api/openai/endpoints/chat.py b/ktransformers/server/api/openai/endpoints/chat.py index 2092e97..9455f5c 100644 --- a/ktransformers/server/api/openai/endpoints/chat.py +++ b/ktransformers/server/api/openai/endpoints/chat.py @@ -71,38 +71,37 @@ def getTools(buffer): tool_calls_end_marker = "<|tool▁calls▁end|>" extracted_tools = [] working_buffer = buffer - - + # Iterate over all function calls while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer: # Find a complete function call start_index = working_buffer.find(tool_call_begin_marker) end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker) - + if start_index == -1 or end_index == -1 or start_index > end_index: logger.warning("Not a function") break - + # Extract the full function call full_tool_call = working_buffer[start_index:end_index] - + # Remove this function call from the working buffer to prevent duplicate processing working_buffer = working_buffer.replace(full_tool_call, "", 1) - + # Extract the function name function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker) function_name_end = full_tool_call.find("\n", function_name_start) function_name = full_tool_call[function_name_start:function_name_end].strip() - + # Extract JSON parameters json_pattern = r'```json\s*(.*?)\s*```' json_match = re.search(json_pattern, full_tool_call, re.DOTALL) - + if json_match: arguments_str = json_match.group(1).strip() # Generate tool call IDs tool_call_id = f"call_{uuid4().hex[:24]}" - + # Add to tool call list extracted_tools.append({ "id": tool_call_id, @@ -112,45 +111,65 @@ def getTools(buffer): "arguments": arguments_str } }) - + logger.info(f"Get Function: {function_name}") else: - logger.warning(f"Unable to get function,function_name: {function_name}") - + logger.warning(f"Unable to get function, function_name: {function_name}") + logger.info(f"Total {len(extracted_tools)} Functions") return extracted_tools +def get_tool_instructions(): + """Return concise tool calling instructions in English""" + return """When you need real-time information or specialized operations, use function calls with this format: + +<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>function_name +```json +{"param1": "value1", "param2": "value2"} +```<|tool▁call▁end|><|tool▁calls▁end|> + +Only use functions when needed. Ensure proper JSON formatting with appropriate parameters.""" + @router.post('/chat/completions', tags=['openai']) async def chat_completion(request: Request, create: ChatCompletionCreate): id = str(uuid4().hex) - - # 1. Use system prompts to let models know how to use tools + + # Process messages with tool functionality if needed enhanced_messages = list(create.messages) - - # If there is a tool and the first message is system, add instructions on how to use the tool in the system tip - if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user): - tool_instructions = "你可以使用function_call,函数调用功能,目前,你可以使用以下工具\n\n" + + # Check if tools are present + has_tools = create.tools and len(create.tools) > 0 + + if has_tools: + # Find the most recent user message to append tool information + latest_user_msg_idx = -1 + for i in range(len(enhanced_messages) - 1, -1, -1): + if enhanced_messages[i].role == Role.user: + latest_user_msg_idx = i + break + + # Build the tool descriptions + tools_description = "" for tool in create.tools: - tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n" - - # Modify tool usage guidelines to encourage JSON output - tool_instructions += "name为函数名称,description为函数功能的描述,parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n" - tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n" - tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n" - tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n' - tool_instructions += '示例: \n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_functnion_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n' - tool_instructions += "这样可以调用名为\"the_functnion_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n" - tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。" - - enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions - - # Requests processed + tools_description += f"Function: {tool.function.name}\nDescription: {tool.function.description}\nParameters: {tool.function.parameters}\n\n" + + # If first message is system, add concise tool instructions + if enhanced_messages[0].role == Role.system: + if "function calls" not in enhanced_messages[0].content.lower(): + enhanced_messages[0].content += "\n\n" + get_tool_instructions() + + # For the latest user message, append tool information + if latest_user_msg_idx >= 0: + # Add tool descriptions to the latest user message + enhanced_messages[latest_user_msg_idx].content += f"\n\nAvailable tools:\n{tools_description}" + + # Process request interface: BackendInterfaceBase = get_interface() input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages] - + if Config().api_key != '': assert request.headers.get('Authorization', '').split()[-1] == Config().api_key - + if create.stream: async def inner(): chunk = ChatCompletionChunk( @@ -161,20 +180,21 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): model=Config().model_name, system_fingerprint=f"fp_{uuid4().hex[:12]}", ) - - # Collect the full output of the model, but specialize in processing tool calls + + # Collect the full output of the model full_content = "" buffer = "" # Used to temporarily store the current block of text tool_call_mode = False # Mark if a tool call is being processed tool_calls = [] # Store all detected tool calls - - # Customize model special tokens + + # Tool call markers tool_calls_begin_marker = "<|tool▁calls▁begin|>" tool_call_begin_marker = "<|tool▁call▁begin|>" tool_sep_marker = "<|tool▁sep|>" tool_call_end_marker = "<|tool▁call▁end|>" tool_calls_end_marker = "<|tool▁calls▁end|>" - + + # Use check_client_connected for early stopping async for res in interface.inference(input_message, id, create.temperature, create.top_p): if isinstance(res, RawUsage): # Final return on utilization @@ -188,11 +208,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): yield chunk elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res - + # Detecting model-specific formatting tool call starts if not tool_call_mode and tool_calls_begin_marker in buffer + token: tool_call_mode = True - + # Adjust full_content to remove tool call section if buffer.endswith(tool_calls_begin_marker): full_content = full_content[:-len(tool_calls_begin_marker)] @@ -200,7 +220,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): idx = (buffer + token).find(tool_calls_begin_marker) full_content = full_content[:-(len(buffer) - idx)] buffer = "" - + # Send the current cumulative text content (if any) if full_content: chunk.choices = [{ @@ -210,7 +230,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): }] yield chunk full_content = "" - + # Accumulation of content in non-tool call mode if not tool_call_mode: full_content += token @@ -221,18 +241,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): else: # In tool call mode, continue to collect tool call related text buffer += token - + # If the tool call end marker is found if tool_calls_end_marker in buffer: try: - # Parsing Calling Text Extraction Tool Calling Information - + # Parse and extract tool calling information tool_calls = getTools(buffer) if len(tool_calls): # reset state tool_call_mode = False buffer = "" - + # Send tool call events for idx, tool_call in enumerate(tool_calls): # First tool call message @@ -254,7 +273,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): "finish_reason": None }] yield chunk - + # Sending Parameters chunk.choices = [{ "index": 0, @@ -267,7 +286,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): "finish_reason": None }] yield chunk - + # Send Completion Message chunk.choices = [{ "index": 0, @@ -275,7 +294,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): "finish_reason": "tool_calls" }] yield chunk - + # No further processing after return return else: @@ -287,7 +306,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): logger.error(f"Error processing tool call: {e}") tool_call_mode = False buffer = "" - + # Normal text output (only in non-tool call mode) if not tool_call_mode and token: if finish_reason is not None: @@ -307,17 +326,17 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): "finish_reason": None }] yield chunk - + # If gotten this far without returning, it means that the full tool call was not detected # Send Routine Completion Message if not tool_call_mode: chunk.choices = [{ - "index": 0, - "delta": {}, + "index": 0, + "delta": {}, "finish_reason": "stop" }] yield chunk - + return chat_stream_response(request, inner()) else: # non streaming response processing @@ -326,14 +345,14 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): tool_calls = [] buffer = "" tool_call_mode = False - + # Custom model special markers tool_calls_begin_marker = "<|tool▁calls▁begin|>" tool_call_begin_marker = "<|tool▁call▁begin|>" tool_sep_marker = "<|tool▁sep|>" tool_call_end_marker = "<|tool▁call▁end|>" tool_calls_end_marker = "<|tool▁calls▁end|>" - + async for res in interface.inference(input_message, id, create.temperature, create.top_p): if isinstance(res, RawUsage): raw_usage = res @@ -344,11 +363,11 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): ) elif isinstance(res, tuple) and len(res) == 2: token, finish_reason = res - + # Detecting the start of model-specific formatting tool calls if not tool_call_mode and tool_calls_begin_marker in buffer + token: tool_call_mode = True - + # Adjust full_content to remove tool call section if buffer.endswith(tool_calls_begin_marker): full_content = full_content[:-len(tool_calls_begin_marker)] @@ -356,7 +375,7 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): idx = (buffer + token).find(tool_calls_begin_marker) full_content = full_content[:-(len(buffer) - idx)] buffer = "" - + # Accumulation of content in non-tool call mode if not tool_call_mode: full_content += token @@ -367,54 +386,18 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): else: # In tool call mode, continue to collect tool call related text buffer += token - + # If the tool call end marker is found if tool_calls_end_marker in buffer: - try: - # Parsing Calling Text Extraction Tool Calling Information - full_tool_call = buffer - - # Extract function name - function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker) - function_name_end = full_tool_call.find("\n", function_name_start) - function_name = full_tool_call[function_name_start:function_name_end].strip() - - # Extract JSON Parameters - Extracts the content between ```json and ```. - json_pattern = r'```json\s*(.*?)\s*```' - json_match = re.search(json_pattern, full_tool_call, re.DOTALL) - - if json_match: - arguments_str = json_match.group(1).strip() - # Generate tool call IDs - tool_call_id = f"call_{uuid4().hex[:24]}" - - # Add to tool call list - tool_calls.append({ - "id": tool_call_id, - "index": 0, - "type": "function", - "function": { - "name": function_name, - "arguments": arguments_str - } - }) - - # If the tool call is successfully parsed, set the reason for completion - finish_reason = "tool_calls" - - # reset state - tool_call_mode = False - buffer = "" - else: - # JSON extraction failed, probably incomplete formatting - logger.warning("Failed to extract JSON from tool call") - tool_call_mode = False - buffer = "" - except Exception as e: - logger.error(f"Error processing tool call: {e}") - tool_call_mode = False - buffer = "" - + # Extract tool calls + tool_calls = getTools(buffer) + if tool_calls: + finish_reason = "tool_calls" + + # Reset state + tool_call_mode = False + buffer = "" + # Build Response response = { "id": id, @@ -430,8 +413,8 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): }, "finish_reason": finish_reason or "stop" }], - "usage": usage.__dict__, + "usage": usage.__dict__ if 'usage' in locals() else None, "system_fingerprint": f"fp_{uuid4().hex[:12]}" } - - return response + + return response \ No newline at end of file