Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-09 13:55:27 +00:00)
Refactor the chat interface to support tool calling and parameter processing
Defined new data structures in chat.py to replace the imported OpenAI types and add support for tool calling. Implemented logic for extracting and processing tool calls, enabling dynamic function invocation during conversations. Added a get_sampling_params method in balance_serve.py that resolves sampling parameters, handling default values and edge cases. Updated ktransformers.py and transformers.py to pass tool parameters through the inference interfaces. Changed the default top_p in config.py from 0.8 to 1.0 to increase generation diversity. Extended the message model in chat.py to carry tool call information. Together these changes make the server more flexible and enable more complex interaction patterns.
parent 038db30ec9
commit 8699109129
6 changed files with 574 additions and 99 deletions
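The refactored /chat/completions endpoint accepts an OpenAI-style tools array and returns tool_calls in the response message (or as streamed deltas). A minimal client-side sketch is shown below; the base URL, port, any /v1 prefix, the model name, and the get_weather schema are illustrative assumptions about a local deployment, not part of this commit.

import json
import requests  # any HTTP client works; requests is assumed here

BASE_URL = "http://localhost:10002"  # hypothetical ktransformers server address; adjust to your deployment

payload = {
    "model": "DeepSeek-V3",  # should match the model name the server reports via /models
    "stream": False,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Beijing right now?"},
    ],
    "tools": [{  # OpenAI-style function definition; this schema is an example only
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}

resp = requests.post(f"{BASE_URL}/chat/completions", json=payload, timeout=300)
choice = resp.json()["choices"][0]

# When the model decides to call a tool, content is None and tool_calls is populated.
for call in choice["message"].get("tool_calls") or []:
    print(call["function"]["name"], json.loads(call["function"]["arguments"]))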
chat.py (OpenAI chat completions endpoint):

@@ -1,19 +1,61 @@
import json
from time import time
from uuid import uuid4
+from typing import Dict, List, Optional, Any, Literal, Union
+from pydantic import BaseModel, Field
+import re
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
-from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config
+from ktransformers.server.config.log import logger
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
-from openai.types.chat import ChatCompletion
-from openai.types.completion_usage import CompletionUsage
+
+# Define own data structure instead of importing from OpenAI
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+
+class Choice(BaseModel):
+    index: int
+    message: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+    logprobs: Optional[Any] = None
+    delta: Optional[Dict[str, Any]] = None
+    content_filter_results: Optional[Dict[str, Any]] = None
+
+class ChatCompletion(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[Choice]
+    usage: Optional[CompletionUsage] = None
+    system_fingerprint: Optional[str] = None
+    prompt_filter_results: Optional[List[Dict[str, Any]]] = None
+
+# Only for non-streaming response construction
+class ChatCompletionMessageToolCallFunction(BaseModel):
+    name: str
+    arguments: str
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    type: str
+    function: ChatCompletionMessageToolCallFunction
+
+class ChatCompletionMessage(BaseModel):
+    role: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
+
router = APIRouter()
@@ -21,90 +63,375 @@ router = APIRouter()
async def list_models():
    return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}

+
+def getTools(buffer):
+    tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+    tool_call_begin_marker = "<|tool▁call▁begin|>"
+    tool_sep_marker = "<|tool▁sep|>"
+    tool_call_end_marker = "<|tool▁call▁end|>"
+    tool_calls_end_marker = "<|tool▁calls▁end|>"
+    extracted_tools = []
+    working_buffer = buffer
+
+    # Iterate over all function calls
+    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
+        # Find a complete function call
+        start_index = working_buffer.find(tool_call_begin_marker)
+        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
+
+        if start_index == -1 or end_index == -1 or start_index > end_index:
+            logger.warning("Not a function")
+            break
+
+        # Extract the full function call
+        full_tool_call = working_buffer[start_index:end_index]
+
+        # Remove this function call from the working buffer to prevent duplicate processing
+        working_buffer = working_buffer.replace(full_tool_call, "", 1)
+
+        # Extract the function name
+        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+        function_name_end = full_tool_call.find("\n", function_name_start)
+        function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+        # Extract JSON parameters
+        json_pattern = r'```json\s*(.*?)\s*```'
+        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+        if json_match:
+            arguments_str = json_match.group(1).strip()
+            # Generate tool call IDs
+            tool_call_id = f"call_{uuid4().hex[:24]}"
+
+            # Add to tool call list
+            extracted_tools.append({
+                "id": tool_call_id,
+                "type": "function",
+                "function": {
+                    "name": function_name,
+                    "arguments": arguments_str
+                }
+            })
+
+            logger.info(f"Get Function: {function_name}")
+        else:
+            logger.warning(f"Unable to get function,function_name: {function_name}")
+
+    logger.info(f"Total {len(extracted_tools)} Functions")
+    return extracted_tools
+
+
@router.post('/chat/completions', tags=['openai'])
-async def chat_completion(request:Request,create:ChatCompletionCreate):
-    id = str(uuid4())
+async def chat_completion(request: Request, create: ChatCompletionCreate):
+    id = str(uuid4().hex)
+
+    # 1. Use system prompts to let models know how to use tools
+    enhanced_messages = list(create.messages)
+
+    # If there is a tool and the first message is system, add instructions on how to use the tool in the system tip
+    if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
+        tool_instructions = "你可以使用function_call,函数调用功能,目前,你可以使用以下工具\n\n"
+        for tool in create.tools:
+            tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
+
+        # Modify tool usage guidelines to encourage JSON output
+        tool_instructions += "name为函数名称,description为函数功能的描述,parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n"
+        tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n"
+        tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n"
+        tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
+        tool_instructions += '示例: \n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_functnion_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
+        tool_instructions += "这样可以调用名为\"the_functnion_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n"
+        tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。"
+
+        enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
+
+    # Requests processed
    interface: BackendInterfaceBase = get_interface()
-    # input_ids = interface.format_and_tokenize_input_ids(id,messages=create.get_tokenizer_messages())
+    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]

-    input_message = [json.loads(m.model_dump_json()) for m in create.messages]

    if Config().api_key != '':
        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key

    if create.stream:
-        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
-
        async def inner():
            chunk = ChatCompletionChunk(
-                id = id,
-                choices = [],
-                object = 'chat.completion.chunk',
-                created = int(time()),
-                model = Config().model_name,
+                id=id,
+                choices=[],
+                object='chat.completion.chunk',
+                created=int(time()),
+                model=Config().model_name,
+                system_fingerprint=f"fp_{uuid4().hex[:12]}",
            )

-            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+            # Collect the full output of the model, but specialize in processing tool calls
+            full_content = ""
+            buffer = ""  # Used to temporarily store the current block of text
+            tool_call_mode = False  # Mark if a tool call is being processed
+            tool_calls = []  # Store all detected tool calls
+
+            # Customize model special tokens
+            tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+            tool_call_begin_marker = "<|tool▁call▁begin|>"
+            tool_sep_marker = "<|tool▁sep|>"
+            tool_call_end_marker = "<|tool▁call▁end|>"
+            tool_calls_end_marker = "<|tool▁calls▁end|>"
+
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
                if isinstance(res, RawUsage):
-                    # at the end of inference, interface.inference() will return the usage of inference
+                    # Final return on utilization
                    raw_usage = res
                    chunk.choices = []
                    chunk.usage = CompletionUsage(
-                        prompt_tokens = raw_usage.prefill_count,
-                        completion_tokens = raw_usage.decode_count,
-                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                        prompt_tokens=raw_usage.prefill_count,
+                        completion_tokens=raw_usage.decode_count,
+                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                    )

                    yield chunk
-                else:
+                elif isinstance(res, tuple) and len(res) == 2:
                    token, finish_reason = res
-                    choice = Choice(
-                        index = 0,
-                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
-                        finish_reason = finish_reason,
-                        logprobs = None,
-                    )
-                    chunk.choices = [choice]
-                    yield chunk
+
+                    # Detecting model-specific formatting tool call starts
+                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                        tool_call_mode = True
+
+                        # Adjust full_content to remove tool call section
+                        if buffer.endswith(tool_calls_begin_marker):
+                            full_content = full_content[:-len(tool_calls_begin_marker)]
+                        elif tool_calls_begin_marker in (buffer + token):
+                            idx = (buffer + token).find(tool_calls_begin_marker)
+                            full_content = full_content[:-(len(buffer) - idx)]
+                        buffer = ""
+
+                        # Send the current cumulative text content (if any)
+                        if full_content:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {"content": full_content},
+                                "finish_reason": None
+                            }]
+                            yield chunk
+                            full_content = ""
+
+                    # Accumulation of content in non-tool call mode
+                    if not tool_call_mode:
+                        full_content += token
+                        buffer += token
+                        # Keep the buffer at a reasonable size
+                        if len(buffer) > 200:
+                            buffer = buffer[-200:]
+                    else:
+                        # In tool call mode, continue to collect tool call related text
+                        buffer += token
+
+                        # If the tool call end marker is found
+                        if tool_calls_end_marker in buffer:
+                            try:
+                                # Parsing Calling Text Extraction Tool Calling Information
+                                tool_calls = getTools(buffer)
+                                if len(tool_calls):
+                                    # reset state
+                                    tool_call_mode = False
+                                    buffer = ""
+
+                                    # Send tool call events
+                                    for idx, tool_call in enumerate(tool_calls):
+                                        # First tool call message
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "role": "assistant",
+                                                "content": None,
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "id": tool_call["id"],
+                                                    "type": "function",
+                                                    "function": {
+                                                        "name": tool_call["function"]["name"],
+                                                        "arguments": ""
+                                                    }
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                        # Sending Parameters
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "function": {"arguments": tool_call["function"]["arguments"]}
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                    # Send Completion Message
+                                    chunk.choices = [{
+                                        "index": 0,
+                                        "delta": {},
+                                        "finish_reason": "tool_calls"
+                                    }]
+                                    yield chunk
+
+                                    # No further processing after return
+                                    return
+                                else:
+                                    # JSON extraction failed, probably incomplete formatting
+                                    logger.warning("Failed to extract JSON from tool call")
+                                    tool_call_mode = False
+                                    buffer = ""
+                            except Exception as e:
+                                logger.error(f"Error processing tool call: {e}")
+                                tool_call_mode = False
+                                buffer = ""
+
+                    # Normal text output (only in non-tool call mode)
+                    if not tool_call_mode and token:
+                        if finish_reason is not None:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {},
+                                "finish_reason": finish_reason
+                            }]
+                            yield chunk
+                        else:
+                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):
+                                pass
+                            else:
+                                chunk.choices = [{
+                                    "index": 0,
+                                    "delta": {"content": token},
+                                    "finish_reason": None
+                                }]
+                                yield chunk
+
+            # If gotten this far without returning, it means that the full tool call was not detected
+            # Send Routine Completion Message
+            if not tool_call_mode:
+                chunk.choices = [{
+                    "index": 0,
+                    "delta": {},
+                    "finish_reason": "stop"
+                }]
+                yield chunk
+
        return chat_stream_response(request, inner())
    else:
-        from openai.types.chat.chat_completion import Choice
-        from openai.types.chat.chat_completion_message import ChatCompletionMessage
-
-        content = ""
+        # non streaming response processing
+        full_content = ""
        finish_reason = None
-        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+        tool_calls = []
+        buffer = ""
+        tool_call_mode = False
+
+        # Custom model special markers
+        tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+        tool_call_begin_marker = "<|tool▁call▁begin|>"
+        tool_sep_marker = "<|tool▁sep|>"
+        tool_call_end_marker = "<|tool▁call▁end|>"
+        tool_calls_end_marker = "<|tool▁calls▁end|>"
+
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
            if isinstance(res, RawUsage):
                raw_usage = res
                usage = CompletionUsage(
-                    prompt_tokens = raw_usage.prefill_count,
-                    completion_tokens = raw_usage.decode_count,
-                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    prompt_tokens=raw_usage.prefill_count,
+                    completion_tokens=raw_usage.decode_count,
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                )
-            else:
+            elif isinstance(res, tuple) and len(res) == 2:
                token, finish_reason = res
-                content = content + token
-                finish_reason = finish_reason
-
-        choice = Choice(
-            index = 0,
-            finish_reason = finish_reason,
-            message = ChatCompletionMessage(
-                content=content,
-                role="assistant"
-            ))
-
-        chat_completion = ChatCompletion(
-            id = id,
-            choices = [choice],
-            created = int(time()),
-            model = Config().model_name,
-            object = 'chat.completion',
-            usage = usage
-        )
-
-        return chat_completion
+
+                # Detecting the start of model-specific formatting tool calls
+                if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                    tool_call_mode = True
+
+                    # Adjust full_content to remove tool call section
+                    if buffer.endswith(tool_calls_begin_marker):
+                        full_content = full_content[:-len(tool_calls_begin_marker)]
+                    elif tool_calls_begin_marker in (buffer + token):
+                        idx = (buffer + token).find(tool_calls_begin_marker)
+                        full_content = full_content[:-(len(buffer) - idx)]
+                    buffer = ""
+
+                # Accumulation of content in non-tool call mode
+                if not tool_call_mode:
+                    full_content += token
+                    buffer += token
+                    # Keep the buffer at a reasonable size
+                    if len(buffer) > 200:
+                        buffer = buffer[-200:]
+                else:
+                    # In tool call mode, continue to collect tool call related text
+                    buffer += token
+
+                    # If the tool call end marker is found
+                    if tool_calls_end_marker in buffer:
+                        try:
+                            # Parsing Calling Text Extraction Tool Calling Information
+                            full_tool_call = buffer
+
+                            # Extract function name
+                            function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+                            function_name_end = full_tool_call.find("\n", function_name_start)
+                            function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+                            # Extract JSON Parameters - Extracts the content between ```json and ```.
+                            json_pattern = r'```json\s*(.*?)\s*```'
+                            json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+                            if json_match:
+                                arguments_str = json_match.group(1).strip()
+                                # Generate tool call IDs
+                                tool_call_id = f"call_{uuid4().hex[:24]}"
+
+                                # Add to tool call list
+                                tool_calls.append({
+                                    "id": tool_call_id,
+                                    "index": 0,
+                                    "type": "function",
+                                    "function": {
+                                        "name": function_name,
+                                        "arguments": arguments_str
+                                    }
+                                })
+
+                                # If the tool call is successfully parsed, set the reason for completion
+                                finish_reason = "tool_calls"
+
+                                # reset state
+                                tool_call_mode = False
+                                buffer = ""
+                            else:
+                                # JSON extraction failed, probably incomplete formatting
+                                logger.warning("Failed to extract JSON from tool call")
+                                tool_call_mode = False
+                                buffer = ""
+                        except Exception as e:
+                            logger.error(f"Error processing tool call: {e}")
+                            tool_call_mode = False
+                            buffer = ""
+
+        # Build Response
+        response = {
+            "id": id,
+            "object": "chat.completion",
+            "created": int(time()),
+            "model": Config().model_name,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": None if tool_calls else full_content,
+                    "tool_calls": tool_calls if tool_calls else None
+                },
+                "finish_reason": finish_reason or "stop"
+            }],
+            "usage": usage.__dict__,
+            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
+        }
+
+        return response
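For reference, a minimal sketch of the DeepSeek-style markup that getTools() parses out of the accumulated buffer; the function name and arguments below are invented for illustration, and the extraction steps mirror the ones in the endpoint above.

import re
from uuid import uuid4

# Illustrative model output using the special tool-call markers.
sample_buffer = (
    "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    '```json {"city": "Beijing"}\n```'
    "<|tool▁call▁end|><|tool▁calls▁end|>"
)

# Same steps as getTools(): the function name follows <|tool▁sep|>, the arguments sit in the ```json fence.
name_start = sample_buffer.find("<|tool▁sep|>") + len("<|tool▁sep|>")
function_name = sample_buffer[name_start:sample_buffer.find("\n", name_start)].strip()
arguments = re.search(r'```json\s*(.*?)\s*```', sample_buffer, re.DOTALL).group(1).strip()

tool_call = {
    "id": f"call_{uuid4().hex[:24]}",
    "type": "function",
    "function": {"name": function_name, "arguments": arguments},
}
print(tool_call)  # {'id': 'call_...', 'type': 'function', 'function': {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}}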
balance_serve.py (BalanceServeInterface backend):

@@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):
    # thread_related
    last_request_id: Optional[str] = None
    ever_generated_ids: Set[int] = set()
+
    def __init__(self, args: ConfigArgs = default_args):
        self.args = args
        self.queue_map:dict[int,asyncio.Queue] = {}
@@ -282,7 +283,21 @@ class BalanceServeInterface(BackendInterfaceBase):
            p.start()
            processes.append(p)
        start_event.wait()
+
+    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+        """Get sampling parameters and handle default values and edge cases"""
+        if temperature is None:
+            temperature = Config().temperature
+        if top_p is None:
+            top_p = Config().top_p
+
+        if temperature == 0:
+            temperature = 0.0001
+        if top_p == 0:
+            top_p = 0.0001
+
+        return temperature, top_p
+
    def run_queue_proxy(self):
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
@@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
        else:
            raise ValueError("local_messages should be List or str")
@@ -352,12 +366,9 @@ class BalanceServeInterface(BackendInterfaceBase):
                [input_ids, token_thinks], dim=1
            )
-
        profiler.pause_timer("tokenize")
-
        profiler.create_and_start_timer("prefill")
-
        query_add = sched_ext.QueryAdd()
        query_add.query_token = input_ids[0].tolist()
@@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):
        #@TODO add server
        stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
        query_add.stop_criteria = stop_criteria
-        if temperature == 0:
-            temperature = 0.0001
+        temperature, top_p = self.get_sampling_params(temperature, top_p)
+
        query_add.sample_options.temperature = temperature
-        if top_p == 0:
-            top_p = 0.0001
        query_add.sample_options.top_p = top_p
        query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
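A standalone sketch of the parameter normalization the new get_sampling_params helper performs; the fallback defaults of 0.6 and 1.0 stand in for Config().temperature and Config().top_p and are assumptions here, not values taken from this diff.

from typing import Optional

def get_sampling_params(temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
    # Fall back to server-wide defaults when the request leaves the fields unset.
    temperature = 0.6 if temperature is None else temperature  # stands in for Config().temperature
    top_p = 1.0 if top_p is None else top_p                    # stands in for Config().top_p
    # Exact zeros are clamped to a small epsilon before reaching the scheduler's sampler.
    if temperature == 0:
        temperature = 0.0001
    if top_p == 0:
        top_p = 0.0001
    return temperature, top_p

print(get_sampling_params())           # (0.6, 1.0): both fall back to defaults
print(get_sampling_params(0, 0))       # (0.0001, 0.0001): zero values are clamped
print(get_sampling_params(0.7, 0.95))  # (0.7, 0.95): explicit values pass through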
ktransformers.py (KTransformersInterface backend):

@@ -1,4 +1,5 @@
import torch
+from typing import Optional, List
import asyncio
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
from ktransformers.server.backend.interfaces.transformers import (
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
        device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
        return torch.tensor([self.seq_length - 1], device=device)

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                yield v

                # return this inference raw usage
transformers.py (TransformersInterface backend):

@@ -1,4 +1,7 @@
from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
from transformers import (
    LlamaTokenizer,
    AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
            self.last_request_id = thread_id
            return True

-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
        self.streamer.reset()
        self.profiler.create_and_start_timer("tokenize")
+
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
+
        if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
        elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
            input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
        else:
            raise ValueError("local_messages should be List or str")

@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
            )
-
        self.profiler.pause_timer("tokenize")

        self.profiler.create_and_start_timer("prefill")

        if Config().user_force_think:
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
            yield think, None

        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
-            # output think token after prefill done
            if t is not None:
                print(t, end="",flush=True)
                yield t, None
        self.profiler.pause_timer("prefill")

        self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="",flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Start collecting tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check if we're starting a tool call
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract function name if we're collecting tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                    # Check if we've completed the arguments JSON
+                                    if brackets_count == 0:
+                                        # Send argument chunk
+                                        yield {
+                                            'tool_call': {
+                                                'id': tool_call_id,
+                                                'type': 'function',
+                                                'function': {
+                                                    'name': function_name,
+                                                    'arguments': collected_arguments
+                                                }
+                                            },
+                                            'argument_chunk': collected_arguments,
+                                            'last_chunk': True,
+                                            'prompt_tokens': 176,
+                                            'completion_tokens': 20
+                                        }
+                                        # Reset for next potential tool call
+                                        collected_tokens = ""
+                                        is_collecting_tool_call = False
+                                        is_function_name_collected = False
+                                        function_name = ""
+                                        collected_arguments = ""
+                                        brackets_count = 0
+                                        break
+
+                # Handle finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="",flush=True)
+                    yield t, finish_reason
+            print("")
+
        self.profiler.pause_timer("decode")
        self.report_last_time_performance()
config.py (server config):

@@ -133,7 +133,7 @@ class Config(metaclass=Singleton):
        self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
        self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
        self.top_k = self.model.get("top_k", 50)
-        self.top_p = self.model.get("top_p", 0.8)
+        self.top_p = self.model.get("top_p", 1.0)
        self.top_a = self.model.get("top_a", 0.0)
        self.skew = self.model.get("skew", 0.0)
        self.typical = self.model.get("typical", 0.0)
chat.py (endpoint schemas, ktransformers/server/schemas/endpoints/chat.py):

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
from typing_extensions import Literal
from enum import Enum

@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
from openai.types.completion_usage import CompletionUsage
from openai.types.chat.chat_completion_chunk import Choice

+from uuid import uuid4
+
+from pydantic import BaseModel, Field

class Role(Enum):
    system = 'system'
@@ -17,26 +20,57 @@ class Role(Enum):
    tool = 'tool'
    function = 'function'


class Message(BaseModel):
-    content: str
-    role:Role
+    content: Optional[str] = None
+    role: Role
    name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None

    def to_tokenizer_message(self):
-        return {'content':self.content,'role':self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message
+
+
+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)
+
+
+class ToolFunction(BaseModel):
+    function: FunctionDefinition
+
+
+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition
+

class ChatCompletionCreate(BaseModel):
    messages: List[Message]
-    model : str
-    stream : bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    model: str
+    stream: bool = False
+    temperature: Optional[float] = Field(default=0.6)
    top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0

    def get_tokenizer_messages(self):
        return [m.to_tokenizer_message() for m in self.messages]


class ChatCompletionChunk(BaseModel):
    id: str
    choices: List[Choice]
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
    system_fingerprint: Optional[str] = None
    usage: Optional[CompletionUsage] = None
-
-
    def to_stream_reply(self):
        return f"data: {self.model_dump_json()}\n\n"


class RawUsage(BaseModel):
    tokenize_time: float
    prefill_time: float
    decode_time: float
    prefill_count: int
    decode_count: int
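With the extended Message model, assistant tool calls and tool results can be carried through get_tokenizer_messages() into the chat template. A small sketch of that round trip follows; the get_weather call, its arguments, and the model name are illustrative only.

from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, Message, Role

create = ChatCompletionCreate(
    model="local-model",
    messages=[
        Message(role=Role.user, content="What's the weather in Beijing?"),
        # Assistant turn carrying the tool call the server emitted earlier.
        Message(role=Role.assistant, content=None, tool_calls=[{
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
        }]),
        # Tool turn returning the function result, keyed by tool_call_id.
        Message(role=Role.tool, content='{"temp_c": 21}', tool_call_id="call_abc123"),
    ],
)

# to_tokenizer_message() now emits only the fields that are set, so tool_calls and
# tool_call_id survive the trip into the tokenizer's chat template.
for m in create.get_tokenizer_messages():
    print(m)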