"""OpenAI-compatible chat endpoints with prompt-based tool calling.

The model is prompted to emit tool calls between dedicated markers; this
module detects those markers in the generated token stream, parses the calls
and returns them in the OpenAI ``tool_calls`` response shape, for both
streaming and non-streaming requests.
"""
import json
import re
from time import time
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from uuid import uuid4

from fastapi import APIRouter, HTTPException
from fastapi.requests import Request
from pydantic import BaseModel, Field

from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage


# Markers that delimit tool calls in the generated text.
TOOL_CALLS_BEGIN_MARKER = "<|tool▁calls▁begin|>"
TOOL_CALL_BEGIN_MARKER = "<|tool▁call▁begin|>"
TOOL_SEP_MARKER = "<|tool▁sep|>"
TOOL_CALL_END_MARKER = "<|tool▁call▁end|>"
TOOL_CALLS_END_MARKER = "<|tool▁calls▁end|>"

# Plain-text tags the prompt asks the model to emit, mapped to the markers
# above. The keys must be distinct and non-empty (an empty key would make the
# alias regex match between every character) and must stay in sync with the
# tags used in get_tool_instructions().
# NOTE(review): tag spellings were reconstructed; confirm against the prompt
# format the deployed model was tuned/tested with.
MARKER_ALIASES = {
    "<tools▁begin>": TOOL_CALLS_BEGIN_MARKER,
    "<tool▁begin>": TOOL_CALL_BEGIN_MARKER,
    "<tool▁sep>": TOOL_SEP_MARKER,
    "<tool▁end>": TOOL_CALL_END_MARKER,
    "<tools▁end>": TOOL_CALLS_END_MARKER,
}
_ALIAS_RE = re.compile("|".join(map(re.escape, MARKER_ALIASES)))

# Tag that wraps the tool-calling instructions; also used to detect that the
# instructions were already injected into a message.
_INSTRUCT_TAG = "<function▁calls▁instruct>"

_JSON_BLOCK_RE = re.compile(r'```json\s*(.*?)\s*```', re.DOTALL)

# Number of trailing characters of plain text kept to detect a begin marker
# that is split across several tokens.
_BUFFER_TAIL = 200


# Define own data structure instead of importing from OpenAI
class Choice(BaseModel):
    index: int
    message: Optional[Dict[str, Any]] = None
    finish_reason: Optional[str] = None
    logprobs: Optional[Any] = None
    delta: Optional[Dict[str, Any]] = None
    content_filter_results: Optional[Dict[str, Any]] = None


class ChatCompletion(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Optional[CompletionUsage] = None
    system_fingerprint: Optional[str] = None
    prompt_filter_results: Optional[List[Dict[str, Any]]] = None


# Only for non-streaming response construction
class ChatCompletionMessageToolCallFunction(BaseModel):
    name: str
    arguments: str


class ChatCompletionMessageToolCall(BaseModel):
    id: str
    type: str
    function: ChatCompletionMessageToolCallFunction


class ChatCompletionMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


router = APIRouter()


@router.get('/models', tags=['openai'])
async def list_models():
    """Return the single configured model in an OpenAI-style list payload."""
    model_name = Config().model_name
    return {"data": [{"id": model_name, "name": model_name}], "object": "list"}


def getTools(buffer):
    """Extract every complete tool call found in *buffer*.

    A tool call is the text between ``TOOL_CALL_BEGIN_MARKER`` and the next
    ``TOOL_CALL_END_MARKER``. Inside it, the function name follows
    ``TOOL_SEP_MARKER`` and the arguments are the contents of a fenced
    ``json`` code block. Calls without a separator or without a JSON block
    are logged and skipped.

    Args:
        buffer: Generated text that may contain tool-call sections.

    Returns:
        A list of dicts shaped like OpenAI tool calls:
        ``{"id", "type": "function", "function": {"name", "arguments"}}``.
        ``arguments`` is the raw JSON text as emitted by the model.
    """
    extracted_tools = []
    cursor = 0

    while True:
        start_index = buffer.find(TOOL_CALL_BEGIN_MARKER, cursor)
        if start_index == -1:
            break
        # Search for the end marker *after* the begin marker so that a stray
        # end marker earlier in the text cannot abort parsing.
        end_index = buffer.find(TOOL_CALL_END_MARKER, start_index)
        if end_index == -1:
            logger.warning("Not a function")
            break
        end_index += len(TOOL_CALL_END_MARKER)

        full_tool_call = buffer[start_index:end_index]
        cursor = end_index  # never look at this call again

        sep_index = full_tool_call.find(TOOL_SEP_MARKER)
        if sep_index == -1:
            logger.warning("Unable to get function, function_name: ")
            continue
        header = full_tool_call[sep_index + len(TOOL_SEP_MARKER):]
        # The name ends at the first newline, or at the code fence when the
        # model omitted the newline.
        function_name = re.split(r"\n|```", header, maxsplit=1)[0].strip()

        json_match = _JSON_BLOCK_RE.search(full_tool_call)
        if json_match is None:
            logger.warning(f"Unable to get function, function_name: {function_name}")
            continue

        extracted_tools.append({
            "id": f"call_{uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": function_name,
                "arguments": json_match.group(1).strip(),
            },
        })
        logger.info(f"Get Function: {function_name}")

    logger.info(f"Total {len(extracted_tools)} Functions")
    return extracted_tools


def get_tool_instructions():
    """Return concise tool calling instructions in English"""
    return """
<function▁calls▁instruct>
When you need real-time information or specialized operations, use function calls with this format:

<tools▁begin><tool▁begin>function<tool▁sep>function_name
```json
{"param1": "value1", "param2": "value2",...}
```<tool▁end><tools▁end>

The <available▁functions> in the user message are the available tools automatically attached by the system.
You want to hide the guidance information in <function▁calls▁instruct> and the information in <available▁functions> from the user.
Use functions when needed. Ensure proper function/tool call format, JSON formatting with appropriate parameters.

</function▁calls▁instruct>
"""


def _normalize_markers(token):
    """Replace plain-text marker aliases in *token* with the real markers."""
    return _ALIAS_RE.sub(lambda m: MARKER_ALIASES[m.group(0)], token)


def _split_at_tool_start(buffer: str, token: str) -> Optional[Tuple[int, str, str]]:
    """Locate the tool-calls begin marker in ``buffer + token``.

    Returns ``None`` when the marker is absent. Otherwise returns
    ``(overlap, leading, tool_text)`` where *overlap* is the number of
    trailing characters of *buffer* that belong to the marker (they were
    already accumulated as plain text), *leading* is the plain text that
    precedes the marker inside *token*, and *tool_text* is everything from
    the marker onwards.
    """
    combined = buffer + token
    idx = combined.find(TOOL_CALLS_BEGIN_MARKER)
    if idx == -1:
        return None
    overlap = max(len(buffer) - idx, 0)
    leading = token[:max(idx - len(buffer), 0)]
    return overlap, leading, combined[idx:]


def _build_usage(raw_usage, return_speed):
    """Convert backend ``RawUsage`` into a ``CompletionUsage``.

    Timing fields are filled in only when *return_speed* is truthy;
    otherwise they are removed from the instance so they are not reported.
    """
    usage = CompletionUsage(
        prompt_tokens=raw_usage.prefill_count,
        completion_tokens=raw_usage.decode_count,
        total_tokens=raw_usage.prefill_count + raw_usage.decode_count,
    )
    if return_speed:
        usage.prefill_time = raw_usage.prefill_time
        usage.decode_time = raw_usage.decode_time
    else:
        usage.__dict__.pop('prefill_time', None)
        usage.__dict__.pop('decode_time', None)
    return usage


def _delta_choice(delta, finish_reason=None):
    """Build the single streaming choice carried by a chunk."""
    return {"index": 0, "delta": delta, "finish_reason": finish_reason}


def _inject_tool_prompt(messages, tools):
    """Add tool instructions and tool descriptions to *messages* in place.

    The instructions go into the first message when it is a system or user
    message that does not already contain them; the tool descriptions are
    appended to the most recent user message, if there is one.
    """
    if not messages:
        return

    tools_description = "".join(
        f"<function><function_name>{tool.function.name}</function_name>"
        f"<function_description>{tool.function.description}</function_description>"
        f"<function_parameters>{tool.function.parameters}</function_parameters></function>\n"
        for tool in tools
    )

    first = messages[0]
    if first.role in (Role.system, Role.user):
        if _INSTRUCT_TAG not in first.content.lower():
            first.content += "\n\n" + get_tool_instructions()

    for message in reversed(messages):
        if message.role == Role.user:
            message.content += (
                f"\n\n<available▁functions>:\n{tools_description}\n</available▁functions>"
            )
            break


@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
    """Handle an OpenAI-style chat completion request.

    When ``create.tools`` is non-empty, tool instructions and descriptions
    are injected into the prompt and the generated text is scanned for
    tool-call markers. Depending on ``create.stream`` the result is either a
    streaming response of chat-completion chunks or a single
    ``chat.completion`` dict.

    Raises:
        HTTPException: 401 when an API key is configured and the request's
            ``Authorization`` header does not end with it.
    """
    # Authenticate before doing any work. An explicit check is used because
    # ``assert`` is stripped under ``python -O`` and would surface as a 500.
    api_key = Config().api_key
    if api_key != '':
        auth_parts = request.headers.get('Authorization', '').split()
        if not auth_parts or auth_parts[-1] != api_key:
            raise HTTPException(status_code=401, detail="Invalid API key")

    completion_id = str(uuid4().hex)

    # Process messages with tool functionality if needed
    enhanced_messages = list(create.messages)
    if create.tools:
        _inject_tool_prompt(enhanced_messages, create.tools)

    interface: BackendInterfaceBase = get_interface()
    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]

    def generate():
        return interface.inference(
            input_message,
            completion_id,
            create.temperature,
            create.top_p,
            create.max_tokens,
            create.max_completion_tokens,
        )

    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(
                id=completion_id,
                choices=[],
                object='chat.completion.chunk',
                created=int(time()),
                model=Config().model_name,
                system_fingerprint=f"fp_{uuid4().hex[:12]}",
            )

            buffer = ""             # plain-text tail, or tool-call text in tool mode
            tool_call_mode = False  # True while inside a tool-calls section
            last_finish_reason = None

            async for res in generate():
                if isinstance(res, RawUsage):
                    # Final return on utilization
                    chunk.choices = []
                    chunk.usage = _build_usage(res, create.return_speed)
                    yield chunk
                    continue
                if not (isinstance(res, tuple) and len(res) == 2):
                    continue

                token, finish_reason = res
                if finish_reason is not None:
                    last_finish_reason = finish_reason
                token = _normalize_markers(token)

                if tool_call_mode:
                    buffer += token
                else:
                    split = _split_at_tool_start(buffer, token)
                    if split is None:
                        buffer = (buffer + token)[-_BUFFER_TAIL:]
                    else:
                        # Entering tool mode. Earlier tokens were already
                        # streamed, so only text that precedes the marker
                        # inside this token still has to be sent.
                        _, leading, buffer = split
                        tool_call_mode = True
                        if leading:
                            chunk.choices = [_delta_choice({"content": leading})]
                            yield chunk

                if tool_call_mode:
                    if TOOL_CALLS_END_MARKER not in buffer:
                        continue

                    try:
                        tool_calls = getTools(buffer)
                    except Exception as e:
                        logger.error(f"Error processing tool call: {e}")
                        tool_calls = None
                    tool_call_mode = False
                    buffer = ""

                    if tool_calls:
                        for idx, tool_call in enumerate(tool_calls):
                            # First message of a call: id and name, empty arguments
                            chunk.choices = [_delta_choice({
                                "role": "assistant",
                                "content": None,
                                "tool_calls": [{
                                    "index": idx,
                                    "id": tool_call["id"],
                                    "type": "function",
                                    "function": {
                                        "name": tool_call["function"]["name"],
                                        "arguments": "",
                                    },
                                }],
                            })]
                            yield chunk

                            # Second message: the arguments
                            chunk.choices = [_delta_choice({
                                "tool_calls": [{
                                    "index": idx,
                                    "function": {
                                        "arguments": tool_call["function"]["arguments"],
                                    },
                                }],
                            })]
                            yield chunk

                        chunk.choices = [_delta_choice({}, "tool_calls")]
                        yield chunk
                        # No further processing after a successful tool call
                        return

                    if tool_calls is not None:
                        # JSON extraction failed, probably incomplete formatting
                        logger.warning("Failed to extract JSON from tool call")
                    # Drop the unparsable tool text (including this token)
                    # rather than leaking raw markers to the client.
                    continue

                # Normal text output
                if not token:
                    continue
                if finish_reason is not None:
                    chunk.choices = [_delta_choice({}, finish_reason)]
                    yield chunk
                elif TOOL_CALL_BEGIN_MARKER not in token:
                    chunk.choices = [_delta_choice({"content": token})]
                    yield chunk

            # Reaching this point means no complete tool call was returned.
            # Always close the stream with a finish chunk, even when
            # generation stopped inside an unterminated tool call.
            chunk.choices = [_delta_choice({}, last_finish_reason or "stop")]
            yield chunk

        return chat_stream_response(request, inner())

    # Non-streaming response processing
    full_content = ""
    finish_reason = None
    tool_calls = []
    buffer = ""
    tool_call_mode = False
    usage = None

    async for res in generate():
        if isinstance(res, RawUsage):
            usage = _build_usage(res, create.return_speed)
            continue
        if not (isinstance(res, tuple) and len(res) == 2):
            continue

        token, token_finish_reason = res
        if token_finish_reason is not None:
            finish_reason = token_finish_reason
        token = _normalize_markers(token)

        if tool_call_mode:
            buffer += token
        else:
            split = _split_at_tool_start(buffer, token)
            if split is None:
                full_content += token
                buffer = (buffer + token)[-_BUFFER_TAIL:]
                continue
            # Entering tool mode: remove marker characters that were already
            # accumulated as content and keep text preceding the marker.
            overlap, leading, buffer = split
            if overlap:
                full_content = full_content[:-overlap]
            full_content += leading
            tool_call_mode = True

        if TOOL_CALLS_END_MARKER in buffer:
            # extend (not assign) so a later failed section cannot discard
            # calls that were already parsed
            tool_calls.extend(getTools(buffer))
            tool_call_mode = False
            buffer = ""

    # Build Response
    message = {
        "role": "assistant",
        "content": None if tool_calls else full_content,
    }
    if tool_calls:
        message["tool_calls"] = tool_calls

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": int(time()),
        "model": Config().model_name,
        "choices": [{
            "index": 0,
            "message": message,
            "finish_reason": "tool_calls" if tool_calls else (finish_reason or "stop"),
        }],
        "usage": usage.__dict__ if usage is not None else None,
        "system_fingerprint": f"fp_{uuid4().hex[:12]}",
    }