mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00

update function_call

parent 038db30ec9
commit a7e8d7c1af

4 changed files with 554 additions and 89 deletions
@@ -1,19 +1,61 @@
 import json
 from time import time
 from uuid import uuid4
+from typing import Dict, List, Optional, Any, Literal, Union
+from pydantic import BaseModel, Field
+import re
 from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
 from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
-from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.config.log import logger
 from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
-from openai.types.chat import ChatCompletion
-from openai.types.completion_usage import CompletionUsage
+
+# Define our own data structures instead of importing them from OpenAI
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+
+
+class Choice(BaseModel):
+    index: int
+    message: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+    logprobs: Optional[Any] = None
+    delta: Optional[Dict[str, Any]] = None
+    content_filter_results: Optional[Dict[str, Any]] = None
+
+
+class ChatCompletion(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[Choice]
+    usage: Optional[CompletionUsage] = None
+    system_fingerprint: Optional[str] = None
+    prompt_filter_results: Optional[List[Dict[str, Any]]] = None
+
+
+# Used only for non-streaming response construction
+class ChatCompletionMessageToolCallFunction(BaseModel):
+    name: str
+    arguments: str
+
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    type: str
+    function: ChatCompletionMessageToolCallFunction
+
+
+class ChatCompletionMessage(BaseModel):
+    role: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
+
 router = APIRouter()
 
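These mirrored Pydantic models keep the server self-contained instead of depending on the OpenAI SDK's response types. As a rough illustration using the models above (all field values are made up), a non-streaming response would be built and serialized like this:

from time import time
from uuid import uuid4

# A sketch only; "DeepSeek-R1" is a placeholder model name.
completion = ChatCompletion(
    id=str(uuid4().hex),
    created=int(time()),
    model="DeepSeek-R1",
    choices=[Choice(
        index=0,
        message={"role": "assistant", "content": "Hello!"},
        finish_reason="stop",
    )],
    usage=CompletionUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7),
)
print(completion.model_dump_json())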
@@ -21,90 +63,375 @@ router = APIRouter()
 async def list_models():
     return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}
 
+
+def getTools(buffer):
+    tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+    tool_call_begin_marker = "<|tool▁call▁begin|>"
+    tool_sep_marker = "<|tool▁sep|>"
+    tool_call_end_marker = "<|tool▁call▁end|>"
+    tool_calls_end_marker = "<|tool▁calls▁end|>"
+    extracted_tools = []
+    working_buffer = buffer
+
+    # Iterate over all function calls
+    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
+        # Find a complete function call
+        start_index = working_buffer.find(tool_call_begin_marker)
+        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
+
+        if start_index == -1 or end_index == -1 or start_index > end_index:
+            logger.warning("Not a function")
+            break
+
+        # Extract the full function call
+        full_tool_call = working_buffer[start_index:end_index]
+
+        # Remove this function call from the working buffer to prevent duplicate processing
+        working_buffer = working_buffer.replace(full_tool_call, "", 1)
+
+        # Extract the function name
+        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+        function_name_end = full_tool_call.find("\n", function_name_start)
+        function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+        # Extract the JSON parameters
+        json_pattern = r'```json\s*(.*?)\s*```'
+        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+        if json_match:
+            arguments_str = json_match.group(1).strip()
+            # Generate a tool call ID
+            tool_call_id = f"call_{uuid4().hex[:24]}"
+
+            # Add to the tool call list
+            extracted_tools.append({
+                "id": tool_call_id,
+                "type": "function",
+                "function": {
+                    "name": function_name,
+                    "arguments": arguments_str
+                }
+            })
+
+            logger.info(f"Get Function: {function_name}")
+        else:
+            logger.warning(f"Unable to get function, function_name: {function_name}")
+
+    logger.info(f"Total {len(extracted_tools)} Functions")
+    return extracted_tools
+
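To illustrate the marker format that getTools parses, here is a minimal sketch; the weather-lookup function name and arguments are made-up examples, not part of the commit:

# A made-up buffer in the special-token format described above.
sample_buffer = (
    "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
    '```json {"city": "Beijing"}\n```'
    "<|tool▁call▁end|><|tool▁calls▁end|>"
)

tools = getTools(sample_buffer)
# -> [{'id': 'call_<24 hex chars>', 'type': 'function',
#      'function': {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'}}]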
 @router.post('/chat/completions', tags=['openai'])
-async def chat_completion(request:Request,create:ChatCompletionCreate):
-    id = str(uuid4())
+async def chat_completion(request: Request, create: ChatCompletionCreate):
+    id = str(uuid4().hex)
+
+    # 1. Use the system prompt to let the model know how to use tools
+    enhanced_messages = list(create.messages)
+
+    # If tools are present and the first message is a system or user message,
+    # add instructions on how to use the tools to that message
+    if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
+        tool_instructions = "You can use function_call (function calling). The following tools are currently available:\n\n"
+        for tool in create.tools:
+            tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
+
+        # Tool usage guidelines, encouraging JSON output
+        tool_instructions += "name is the function's name, description describes what it does, and parameters lists the arguments the function takes together with their descriptions; required marks the mandatory arguments.\n"
+        tool_instructions += "Call a tool only when the user explicitly asks for one, or when you judge that a tool call is needed. Note: for highly time-sensitive information, such as the current time or recent events, prefer calling a tool to obtain it! If key information needed for a tool call is missing, you may first ask the user for it before calling the tool.\n"
+        tool_instructions += "\nWhen you need to use a tool, output it in the following format:\n"
+        tool_instructions += '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>name\n```json {"arg_name": "arg_value","arg_name_2": "arg_value_2"...}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
+        tool_instructions += 'Example: \n<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>the_function_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n'
+        tool_instructions += "This calls the function named \"the_function_name_will_be_called\" and passes value1 and value2 as the arguments arg1 and arg2.\n"
+        tool_instructions += "Do not try to explain what you are doing; output the tool call directly. Make sure the call is correctly and completely formatted."
+
+        enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
+
+    # Process the request
     interface: BackendInterfaceBase = get_interface()
-    # input_ids = interface.format_and_tokenize_input_ids(id,messages=create.get_tokenizer_messages())
-    input_message = [json.loads(m.model_dump_json()) for m in create.messages]
+    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]
 
     if Config().api_key != '':
         assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
 
     if create.stream:
-        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
         async def inner():
             chunk = ChatCompletionChunk(
-                id = id,
-                choices = [],
-                object = 'chat.completion.chunk',
-                created = int(time()),
-                model = Config().model_name,
+                id=id,
+                choices=[],
+                object='chat.completion.chunk',
+                created=int(time()),
+                model=Config().model_name,
+                system_fingerprint=f"fp_{uuid4().hex[:12]}",
             )
 
-            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
+            # Collect the model's full output, with special handling for tool calls
+            full_content = ""
+            buffer = ""  # Temporarily stores the current block of text
+            tool_call_mode = False  # Marks whether a tool call is being processed
+            tool_calls = []  # Stores all detected tool calls
+
+            # Model-specific special tokens
+            tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+            tool_call_begin_marker = "<|tool▁call▁begin|>"
+            tool_sep_marker = "<|tool▁sep|>"
+            tool_call_end_marker = "<|tool▁call▁end|>"
+            tool_calls_end_marker = "<|tool▁calls▁end|>"
+
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
                 if isinstance(res, RawUsage):
-                    # at the end of inference, interface.inference() will return the usage of inference
+                    # At the end of inference, interface.inference() returns the usage statistics
                     raw_usage = res
                     chunk.choices = []
                     chunk.usage = CompletionUsage(
-                        prompt_tokens = raw_usage.prefill_count,
-                        completion_tokens = raw_usage.decode_count,
-                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                        prompt_tokens=raw_usage.prefill_count,
+                        completion_tokens=raw_usage.decode_count,
+                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                    )
 
                    yield chunk
-                else:
+                elif isinstance(res, tuple) and len(res) == 2:
                     token, finish_reason = res
-                    choice = Choice(
-                        index = 0,
-                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
-                        finish_reason = finish_reason,
-                        logprobs = None,
-                    )
-                    chunk.choices = [choice]
-                    yield chunk
+
+                    # Detect the start of a model-specific tool call
+                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                        tool_call_mode = True
+
+                        # Trim the tool call section from full_content
+                        if buffer.endswith(tool_calls_begin_marker):
+                            full_content = full_content[:-len(tool_calls_begin_marker)]
+                        elif tool_calls_begin_marker in (buffer + token):
+                            idx = (buffer + token).find(tool_calls_begin_marker)
+                            full_content = full_content[:-(len(buffer) - idx)]
+                        buffer = ""
+
+                        # Send the text content accumulated so far (if any)
+                        if full_content:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {"content": full_content},
+                                "finish_reason": None
+                            }]
+                            yield chunk
+                            full_content = ""
+
+                    # Accumulate content while not in tool call mode
+                    if not tool_call_mode:
+                        full_content += token
+                        buffer += token
+                        # Keep the buffer at a reasonable size
+                        if len(buffer) > 200:
+                            buffer = buffer[-200:]
+                    else:
+                        # In tool call mode, keep collecting tool-call-related text
+                        buffer += token
+
+                        # If the tool calls end marker is found
+                        if tool_calls_end_marker in buffer:
+                            try:
+                                # Parse the collected text and extract the tool call information
+                                tool_calls = getTools(buffer)
+                                if len(tool_calls):
+                                    # Reset state
+                                    tool_call_mode = False
+                                    buffer = ""
+
+                                    # Send tool call events
+                                    for idx, tool_call in enumerate(tool_calls):
+                                        # First tool call message
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "role": "assistant",
+                                                "content": None,
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "id": tool_call["id"],
+                                                    "type": "function",
+                                                    "function": {
+                                                        "name": tool_call["function"]["name"],
+                                                        "arguments": ""
+                                                    }
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                        # Send the arguments
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "function": {"arguments": tool_call["function"]["arguments"]}
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                    # Send the completion message
+                                    chunk.choices = [{
+                                        "index": 0,
+                                        "delta": {},
+                                        "finish_reason": "tool_calls"
+                                    }]
+                                    yield chunk
+
+                                    # No further processing after return
+                                    return
+                                else:
+                                    # JSON extraction failed, probably because of incomplete formatting
+                                    logger.warning("Failed to extract JSON from tool call")
+                                    tool_call_mode = False
+                                    buffer = ""
+                            except Exception as e:
+                                logger.error(f"Error processing tool call: {e}")
+                                tool_call_mode = False
+                                buffer = ""
+
+                    # Normal text output (only in non-tool-call mode)
+                    if not tool_call_mode and token:
+                        if finish_reason is not None:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {},
+                                "finish_reason": finish_reason
+                            }]
+                            yield chunk
+                        else:
+                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):
+                                pass
+                            else:
+                                chunk.choices = [{
+                                    "index": 0,
+                                    "delta": {"content": token},
+                                    "finish_reason": None
+                                }]
+                                yield chunk
+
+            # Getting this far without returning means no complete tool call was detected;
+            # send the routine completion message
+            if not tool_call_mode:
+                chunk.choices = [{
+                    "index": 0,
+                    "delta": {},
+                    "finish_reason": "stop"
+                }]
+                yield chunk
 
         return chat_stream_response(request, inner())
     else:
-        from openai.types.chat.chat_completion import Choice
-        from openai.types.chat.chat_completion_message import ChatCompletionMessage
-        content = ""
+        # Non-streaming response processing
+        full_content = ""
         finish_reason = None
-        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
+        tool_calls = []
+        buffer = ""
+        tool_call_mode = False
+
+        # Model-specific special tokens
+        tool_calls_begin_marker = "<|tool▁calls▁begin|>"
+        tool_call_begin_marker = "<|tool▁call▁begin|>"
+        tool_sep_marker = "<|tool▁sep|>"
+        tool_call_end_marker = "<|tool▁call▁end|>"
+        tool_calls_end_marker = "<|tool▁calls▁end|>"
+
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
             if isinstance(res, RawUsage):
                 raw_usage = res
                 usage = CompletionUsage(
-                    prompt_tokens = raw_usage.prefill_count,
-                    completion_tokens = raw_usage.decode_count,
-                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
+                    prompt_tokens=raw_usage.prefill_count,
+                    completion_tokens=raw_usage.decode_count,
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
                 )
-            else:
+            elif isinstance(res, tuple) and len(res) == 2:
                 token, finish_reason = res
-                content = content + token
-                finish_reason = finish_reason
-                choice = Choice(
-                    index = 0,
-                    finish_reason = finish_reason,
-                    message = ChatCompletionMessage(
-                        content=content,
-                        role="assistant"
-                    ))
-
-        chat_completion = ChatCompletion(
-            id = id,
-            choices = [choice],
-            created = int(time()),
-            model = Config().model_name,
-            object = 'chat.completion',
-            usage = usage
-        )
-
-        return chat_completion
+
+                # Detect the start of a model-specific tool call
+                if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                    tool_call_mode = True
+
+                    # Trim the tool call section from full_content
+                    if buffer.endswith(tool_calls_begin_marker):
+                        full_content = full_content[:-len(tool_calls_begin_marker)]
+                    elif tool_calls_begin_marker in (buffer + token):
+                        idx = (buffer + token).find(tool_calls_begin_marker)
+                        full_content = full_content[:-(len(buffer) - idx)]
+                    buffer = ""
+
+                # Accumulate content while not in tool call mode
+                if not tool_call_mode:
+                    full_content += token
+                    buffer += token
+                    # Keep the buffer at a reasonable size
+                    if len(buffer) > 200:
+                        buffer = buffer[-200:]
+                else:
+                    # In tool call mode, keep collecting tool-call-related text
+                    buffer += token
+
+                    # If the tool calls end marker is found
+                    if tool_calls_end_marker in buffer:
+                        try:
+                            # Parse the collected text and extract the tool call information
+                            full_tool_call = buffer
+
+                            # Extract the function name
+                            function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+                            function_name_end = full_tool_call.find("\n", function_name_start)
+                            function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+                            # Extract the JSON parameters - the content between ```json and ```
+                            json_pattern = r'```json\s*(.*?)\s*```'
+                            json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+                            if json_match:
+                                arguments_str = json_match.group(1).strip()
+                                # Generate a tool call ID
+                                tool_call_id = f"call_{uuid4().hex[:24]}"
+
+                                # Add to the tool call list
+                                tool_calls.append({
+                                    "id": tool_call_id,
+                                    "index": 0,
+                                    "type": "function",
+                                    "function": {
+                                        "name": function_name,
+                                        "arguments": arguments_str
+                                    }
+                                })
+
+                                # The tool call parsed successfully, so set the finish reason
+                                finish_reason = "tool_calls"
+
+                                # Reset state
+                                tool_call_mode = False
+                                buffer = ""
+                            else:
+                                # JSON extraction failed, probably because of incomplete formatting
+                                logger.warning("Failed to extract JSON from tool call")
+                                tool_call_mode = False
+                                buffer = ""
+                        except Exception as e:
+                            logger.error(f"Error processing tool call: {e}")
+                            tool_call_mode = False
+                            buffer = ""
+
+        # Build the response
+        response = {
+            "id": id,
+            "object": "chat.completion",
+            "created": int(time()),
+            "model": Config().model_name,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": None if tool_calls else full_content,
+                    "tool_calls": tool_calls if tool_calls else None
+                },
+                "finish_reason": finish_reason or "stop"
+            }],
+            "usage": usage.__dict__,
+            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
+        }
+
+        return response
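A client-side sketch of how the new streaming tool-call path might be exercised, assuming the official OpenAI Python SDK; the base URL, API key, model name, and tool definition are all illustrative assumptions, not part of the commit:

from openai import OpenAI  # assumes the openai package is installed

client = OpenAI(base_url="http://localhost:10002/v1", api_key="none")  # URL is an assumption
tools = [{"type": "function", "function": {
    "name": "get_weather",  # made-up tool
    "description": "Look up the current weather for a city",
    "parameters": {"type": "object",
                   "properties": {"city": {"type": "string"}},
                   "required": ["city"]},
}}]

stream = client.chat.completions.create(
    model="DeepSeek-R1",  # placeholder model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta if chunk.choices else None
    if delta and delta.tool_calls:
        print(delta.tool_calls)                # name chunk, then arguments chunk
    elif chunk.choices and chunk.choices[0].finish_reason:
        print(chunk.choices[0].finish_reason)  # "tool_calls" or "stop"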
@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, List
 import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                 yield v
 
             # return this inference raw usage
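The wrapper above only threads the new tools parameter through to the parent implementation while holding the inference lock. A minimal caller sketch, with the thread id, message, and sampling values chosen for illustration:

import asyncio

# `interface` is assumed to be an already-constructed KTransformersInterface.
async def demo(interface):
    async for v in interface.inference(
        local_messages=[{"role": "user", "content": "hi"}],
        thread_id="demo-thread",  # illustrative id
        temperature=0.6,
        top_p=1.0,
        tools=None,  # pass a list of tools to enable tool handling
    ):
        print(v)

# asyncio.run(demo(interface))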
@@ -1,4 +1,7 @@
 from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
 from transformers import (
     LlamaTokenizer,
     AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
+
+        # Check whether tools are present
+        has_tools = tools is not None and len(tools) > 0
+
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
 
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
         )
 
         self.profiler.pause_timer("tokenize")
-
         self.profiler.create_and_start_timer("prefill")
 
         if Config().user_force_think:
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
             yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
-            # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="",flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Collect tokens until a tool call is detected
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check whether a tool call is starting
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send the first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract it from context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract the function name while collecting a tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                    # Check whether the arguments JSON is complete
+                                    if brackets_count == 0:
+                                        # Send the argument chunk
+                                        yield {
+                                            'tool_call': {
+                                                'id': tool_call_id,
+                                                'type': 'function',
+                                                'function': {
+                                                    'name': function_name,
+                                                    'arguments': collected_arguments
+                                                }
+                                            },
+                                            'argument_chunk': collected_arguments,
+                                            'last_chunk': True,
+                                            'prompt_tokens': 176,
+                                            'completion_tokens': 20
+                                        }
+                                        # Reset for the next potential tool call
+                                        collected_tokens = ""
+                                        is_collecting_tool_call = False
+                                        is_function_name_collected = False
+                                        function_name = ""
+                                        collected_arguments = ""
+                                        brackets_count = 0
+                                        break
+
+                # Handle the finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="",flush=True)
+                    yield t, finish_reason
+            print("")
+
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
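With tools enabled, inference now yields two shapes: (token, finish_reason) tuples for plain text, and plain dicts carrying a 'tool_call' key for tool events (a RawUsage object also arrives at the end, handled upstream). A hypothetical consumer that distinguishes them, not part of the commit:

async def consume(interface, messages, thread_id, tools):
    async for item in interface.inference(messages, thread_id, tools=tools):
        if isinstance(item, dict) and "tool_call" in item:
            fn = item["tool_call"]["function"]
            print("tool call:", fn["name"], fn["arguments"])
        elif isinstance(item, tuple):
            token, finish_reason = item
            if token:
                print(token, end="")
            if finish_reason is not None:
                print("\nfinished:", finish_reason)
        # other values (e.g. RawUsage) are handled by the API layer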
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
 
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
 from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice
+
+from uuid import uuid4
+from pydantic import BaseModel, Field
 
 class Role(Enum):
     system = 'system'
@@ -17,26 +20,57 @@ class Role(Enum):
     tool = 'tool'
     function = 'function'
 
 
 class Message(BaseModel):
-    content: str
-    role:Role
+    content: Optional[str] = None
+    role: Role
     name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
 
     def to_tokenizer_message(self):
-        return {'content':self.content,'role':self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message
+
+
+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None
+
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)
+
+
+class ToolFunction(BaseModel):
+    function: FunctionDefinition
+
+
+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition
+
 
 class ChatCompletionCreate(BaseModel):
     messages: List[Message]
-    model : str
-    stream : bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    model: str
+    stream: bool = False
+    temperature: Optional[float] = Field(default=0.6)
     top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
 
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
 
 class ChatCompletionChunk(BaseModel):
     id: str
     choices: List[Choice]
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
     system_fingerprint: Optional[str] = None
     usage: Optional[CompletionUsage] = None
 
-
     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
 
-
 class RawUsage(BaseModel):
     tokenize_time: float
     prefill_time: float
     decode_time: float
     prefill_count: int
     decode_count: int
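A request body exercising the extended schema might look as follows; it validates against ChatCompletionCreate as defined above (the model name and tool are illustrative):

from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate

payload = {
    "model": "DeepSeek-R1",  # placeholder
    "stream": True,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Beijing?"},
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",  # made-up tool
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
}

create = ChatCompletionCreate(**payload)
assert create.tools[0].function.name == "get_weather"
assert create.temperature == 0.6  # new default introduced by this commit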