Refactor the chat interface to support tool calling and parameter processing

Defined new data structures in chat.py to replace the types previously imported from the OpenAI package, adding support for tool calling.

Implemented logic for extracting and processing tool calls, enabling dynamic function invocation during conversations.

Added a get_sampling_params method in balance_serve.py to retrieve sampling parameters, handling default values and edge cases.

Updated ktransformers.py and transformers.py so that tool definitions can be passed through inference().

Changed the default value of top_p in config.py from 0.8 to 1.0 to increase generation diversity.

Extended the message model in chat.py so that messages can carry tool call information.

These changes enhance the system's flexibility and functionality, enabling more complex interaction patterns.
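As an illustration of the new interaction pattern, here is a minimal sketch (not part of the commit) of a tools-enabled request against the chat completions endpoint. The base URL, port, model name, and the get_weather tool are placeholders; adjust them to your deployment.

import requests

# Hypothetical server address; adjust to your ktransformers deployment.
BASE_URL = "http://localhost:10002/v1"

payload = {
    "model": "my-model",  # placeholder model name
    "stream": False,
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Beijing right now?"}
    ],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Query current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]
            }
        }
    }]
}

resp = requests.post(f"{BASE_URL}/chat/completions", json=payload, timeout=300).json()
message = resp["choices"][0]["message"]
if message.get("tool_calls"):
    # The server parsed the model's tool-call markup into OpenAI-style tool_calls.
    for call in message["tool_calls"]:
        print(call["function"]["name"], call["function"]["arguments"])
else:
    print(message["content"])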
sean.su 2025-04-14 15:23:37 +08:00
parent 038db30ec9
commit 8699109129
6 changed files with 574 additions and 99 deletions

chat.py (OpenAI chat completions endpoint)

@@ -1,19 +1,61 @@
 import json
 from time import time
 from uuid import uuid4
+from typing import Dict, List, Optional, Any, Literal, Union
+from pydantic import BaseModel, Field
+import re
 from fastapi import APIRouter
 from fastapi.requests import Request
 from ktransformers.server.utils.create_interface import get_interface
 from ktransformers.server.schemas.assistants.streaming import chat_stream_response
 from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate
-from ktransformers.server.schemas.endpoints.chat import RawUsage
+from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
 from ktransformers.server.backend.base import BackendInterfaceBase
 from ktransformers.server.config.config import Config
+from ktransformers.server.config.log import logger
 from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
-from openai.types.chat import ChatCompletion
-from openai.types.completion_usage import CompletionUsage
+
+# Define our own data structures instead of importing them from the OpenAI package
+class CompletionUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+    completion_tokens_details: Optional[Dict[str, Any]] = None
+
+class Choice(BaseModel):
+    index: int
+    message: Optional[Dict[str, Any]] = None
+    finish_reason: Optional[str] = None
+    logprobs: Optional[Any] = None
+    delta: Optional[Dict[str, Any]] = None
+    content_filter_results: Optional[Dict[str, Any]] = None
+
+class ChatCompletion(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[Choice]
+    usage: Optional[CompletionUsage] = None
+    system_fingerprint: Optional[str] = None
+    prompt_filter_results: Optional[List[Dict[str, Any]]] = None
+
+# Only used when constructing non-streaming responses
+class ChatCompletionMessageToolCallFunction(BaseModel):
+    name: str
+    arguments: str
+
+class ChatCompletionMessageToolCall(BaseModel):
+    id: str
+    type: str
+    function: ChatCompletionMessageToolCallFunction
+
+class ChatCompletionMessage(BaseModel):
+    role: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 router = APIRouter()
@@ -21,90 +63,375 @@ router = APIRouter()
 async def list_models():
     return {"data": [{"id": Config().model_name, "name": Config().model_name}], "object": "list"}
 
+def getTools(buffer):
+    tool_calls_begin_marker = "<tool▁calls▁begin>"
+    tool_call_begin_marker = "<tool▁call▁begin>"
+    tool_sep_marker = "<tool▁sep>"
+    tool_call_end_marker = "<tool▁call▁end>"
+    tool_calls_end_marker = "<tool▁calls▁end>"
+
+    extracted_tools = []
+    working_buffer = buffer
+
+    # Iterate over all function calls
+    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:
+        # Find a complete function call
+        start_index = working_buffer.find(tool_call_begin_marker)
+        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)
+
+        if start_index == -1 or end_index == -1 or start_index > end_index:
+            logger.warning("Not a complete function call")
+            break
+
+        # Extract the full function call
+        full_tool_call = working_buffer[start_index:end_index]
+
+        # Remove this function call from the working buffer to prevent duplicate processing
+        working_buffer = working_buffer.replace(full_tool_call, "", 1)
+
+        # Extract the function name
+        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+        function_name_end = full_tool_call.find("\n", function_name_start)
+        function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+        # Extract the JSON arguments
+        json_pattern = r'```json\s*(.*?)\s*```'
+        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+        if json_match:
+            arguments_str = json_match.group(1).strip()
+
+            # Generate a tool call ID
+            tool_call_id = f"call_{uuid4().hex[:24]}"
+
+            # Add to the tool call list
+            extracted_tools.append({
+                "id": tool_call_id,
+                "type": "function",
+                "function": {
+                    "name": function_name,
+                    "arguments": arguments_str
+                }
+            })
+            logger.info(f"Extracted function: {function_name}")
+        else:
+            logger.warning(f"Unable to extract arguments for function: {function_name}")
+
+    logger.info(f"Extracted {len(extracted_tools)} function call(s) in total")
+    return extracted_tools
+
 @router.post('/chat/completions', tags=['openai'])
-async def chat_completion(request:Request,create:ChatCompletionCreate):
-    id = str(uuid4())
-
-    interface: BackendInterfaceBase = get_interface()
-    # input_ids = interface.format_and_tokenize_input_ids(id,messages=create.get_tokenizer_messages())
-    input_message = [json.loads(m.model_dump_json()) for m in create.messages]
-
-    if Config().api_key != '':
-        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
-
-    if create.stream:
-        from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
-        async def inner():
-            chunk = ChatCompletionChunk(
-                id = id,
-                choices = [],
-                object = 'chat.completion.chunk',
-                created = int(time()),
-                model = Config().model_name,
-            )
-            async for res in interface.inference(input_message,id, create.temperature, create.top_p):
-                if isinstance(res, RawUsage):
-                    # at the end of inference, interface.inference() will return the usage of inference
-                    raw_usage = res
-                    chunk.choices = []
-                    chunk.usage = CompletionUsage(
-                        prompt_tokens = raw_usage.prefill_count,
-                        completion_tokens = raw_usage.decode_count,
-                        total_tokens = raw_usage.prefill_count + raw_usage.decode_count
-                    )
-                    yield chunk
-                else:
-                    token, finish_reason = res
-                    choice = Choice(
-                        index = 0,
-                        delta = ChoiceDelta(content=token, role=None, tool_calls=None),
-                        finish_reason = finish_reason,
-                        logprobs = None,
-                    )
-                    chunk.choices = [choice]
-                    yield chunk
-        return chat_stream_response(request, inner())
-    else:
-        from openai.types.chat.chat_completion import Choice
-        from openai.types.chat.chat_completion_message import ChatCompletionMessage
-
-        content = ""
-        finish_reason = None
-        async for res in interface.inference(input_message,id,create.temperature,create.top_p):
-            if isinstance(res, RawUsage):
-                raw_usage = res
-                usage = CompletionUsage(
-                    prompt_tokens = raw_usage.prefill_count,
-                    completion_tokens = raw_usage.decode_count,
-                    total_tokens = raw_usage.prefill_count + raw_usage.decode_count
-                )
-            else:
-                token, finish_reason = res
-                content = content + token
-                finish_reason = finish_reason
-
-        choice = Choice(
-            index = 0,
-            finish_reason = finish_reason,
-            message = ChatCompletionMessage(
-                content=content,
-                role="assistant"
-            ))
-
-        chat_completion = ChatCompletion(
-            id = id,
-            choices = [choice],
-            created = int(time()),
-            model = Config().model_name,
-            object = 'chat.completion',
-            usage = usage
-        )
-
-        return chat_completion
+async def chat_completion(request: Request, create: ChatCompletionCreate):
+    id = str(uuid4().hex)
+
+    # 1. Use the system prompt to let the model know how to use the tools
+    enhanced_messages = list(create.messages)
+
+    # If tools are provided and the first message is a system (or user) message,
+    # append instructions on how to call the tools
+    if create.tools and len(create.tools) > 0 and (enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user):
+        tool_instructions = "你可以使用function_call函数调用功能目前你可以使用以下工具\n\n"
+        for tool in create.tools:
+            tool_instructions += f" \"function\":{{\"name\" : {tool.function.name},\"description\" : {tool.function.description} , \"parameters\" : {tool.function.parameters}}}\n"
+        # Tool usage guidelines that encourage JSON output
+        tool_instructions += "name为函数名称description为函数功能的描述parameters中含有函数需要使用的参数和参数的描述, 其中required为必要参数\n"
+        tool_instructions += "工具仅在用户明确提出,或者你认为需要调用工具的时候调用,注意,当需要高度实时性的信息比如时间或者最近的事情等,优先调用工具来获取!。当确实调用工具的关键信息时,你可以先向用户索取关键信息再调用工具\n"
+        tool_instructions += "\n当你需要使用工具时,请以下列格式输出,格式为:\n"
+        tool_instructions += '<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>name\n```json {"参数名": "参数值","参数名2": "参数值2"...}\n```<tool▁call▁end><tool▁calls▁end>\n'
+        tool_instructions += '示例: \n<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>the_function_name_will_be_called\n```json {"arg1": "value1","arg2": "value2"}\n```<tool▁call▁end><tool▁calls▁end>\n'
+        tool_instructions += "这样可以调用名为\"the_function_name_will_be_called\",并将value1和value2传入参数arg1,arg2\n"
+        tool_instructions += "不要尝试解释你在做什么,直接输出工具函数调用即可。确保函数调用语句格式正确且完整。"
+        enhanced_messages[0].content = enhanced_messages[0].content + "\n\n" + tool_instructions
+
+    # Process the request
+    interface: BackendInterfaceBase = get_interface()
+    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]
+
+    if Config().api_key != '':
+        assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
+
+    if create.stream:
+        async def inner():
+            chunk = ChatCompletionChunk(
+                id=id,
+                choices=[],
+                object='chat.completion.chunk',
+                created=int(time()),
+                model=Config().model_name,
+                system_fingerprint=f"fp_{uuid4().hex[:12]}",
+            )
+
+            # Collect the full output of the model, with special handling for tool calls
+            full_content = ""
+            buffer = ""             # Temporarily stores the current block of text
+            tool_call_mode = False  # Marks whether a tool call is being processed
+            tool_calls = []         # Stores all detected tool calls
+
+            # Custom model special tokens
+            tool_calls_begin_marker = "<tool▁calls▁begin>"
+            tool_call_begin_marker = "<tool▁call▁begin>"
+            tool_sep_marker = "<tool▁sep>"
+            tool_call_end_marker = "<tool▁call▁end>"
+            tool_calls_end_marker = "<tool▁calls▁end>"
+
+            async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+                if isinstance(res, RawUsage):
+                    # At the end of inference, interface.inference() returns the usage
+                    raw_usage = res
+                    chunk.choices = []
+                    chunk.usage = CompletionUsage(
+                        prompt_tokens=raw_usage.prefill_count,
+                        completion_tokens=raw_usage.decode_count,
+                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                    )
+                    yield chunk
+                elif isinstance(res, tuple) and len(res) == 2:
+                    token, finish_reason = res
+
+                    # Detect the start of a model-specific tool call
+                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                        tool_call_mode = True
+
+                        # Trim the tool call section from full_content
+                        if buffer.endswith(tool_calls_begin_marker):
+                            full_content = full_content[:-len(tool_calls_begin_marker)]
+                        elif tool_calls_begin_marker in (buffer + token):
+                            idx = (buffer + token).find(tool_calls_begin_marker)
+                            full_content = full_content[:-(len(buffer) - idx)]
+                        buffer = ""
+
+                        # Send the accumulated text content (if any)
+                        if full_content:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {"content": full_content},
+                                "finish_reason": None
+                            }]
+                            yield chunk
+                            full_content = ""
+
+                    # Accumulate content in non-tool-call mode
+                    if not tool_call_mode:
+                        full_content += token
+                        buffer += token
+                        # Keep the buffer at a reasonable size
+                        if len(buffer) > 200:
+                            buffer = buffer[-200:]
+                    else:
+                        # In tool call mode, keep collecting tool-call related text
+                        buffer += token
+
+                        # If the tool call end marker is found
+                        if tool_calls_end_marker in buffer:
+                            try:
+                                # Parse the collected text and extract the tool call information
+                                tool_calls = getTools(buffer)
+                                if len(tool_calls):
+                                    # Reset state
+                                    tool_call_mode = False
+                                    buffer = ""
+
+                                    # Send tool call events
+                                    for idx, tool_call in enumerate(tool_calls):
+                                        # First tool call message
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "role": "assistant",
+                                                "content": None,
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "id": tool_call["id"],
+                                                    "type": "function",
+                                                    "function": {
+                                                        "name": tool_call["function"]["name"],
+                                                        "arguments": ""
+                                                    }
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                        # Send the arguments
+                                        chunk.choices = [{
+                                            "index": 0,
+                                            "delta": {
+                                                "tool_calls": [{
+                                                    "index": idx,
+                                                    "function": {"arguments": tool_call["function"]["arguments"]}
+                                                }]
+                                            },
+                                            "finish_reason": None
+                                        }]
+                                        yield chunk
+
+                                    # Send the completion message
+                                    chunk.choices = [{
+                                        "index": 0,
+                                        "delta": {},
+                                        "finish_reason": "tool_calls"
+                                    }]
+                                    yield chunk
+
+                                    # No further processing after returning
+                                    return
+                                else:
+                                    # JSON extraction failed, probably because of incomplete formatting
+                                    logger.warning("Failed to extract JSON from tool call")
+                                    tool_call_mode = False
+                                    buffer = ""
+                            except Exception as e:
+                                logger.error(f"Error processing tool call: {e}")
+                                tool_call_mode = False
+                                buffer = ""
+
+                    # Normal text output (only in non-tool-call mode)
+                    if not tool_call_mode and token:
+                        if finish_reason is not None:
+                            chunk.choices = [{
+                                "index": 0,
+                                "delta": {},
+                                "finish_reason": finish_reason
+                            }]
+                            yield chunk
+                        else:
+                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):
+                                pass
+                            else:
+                                chunk.choices = [{
+                                    "index": 0,
+                                    "delta": {"content": token},
+                                    "finish_reason": None
+                                }]
+                                yield chunk
+
+            # If we got this far without returning, no complete tool call was detected,
+            # so send a routine completion message
+            if not tool_call_mode:
+                chunk.choices = [{
+                    "index": 0,
+                    "delta": {},
+                    "finish_reason": "stop"
+                }]
+                yield chunk
+
+        return chat_stream_response(request, inner())
+    else:
+        # Non-streaming response processing
+        full_content = ""
+        finish_reason = None
+        tool_calls = []
+        buffer = ""
+        tool_call_mode = False
+
+        # Custom model special tokens
+        tool_calls_begin_marker = "<tool▁calls▁begin>"
+        tool_call_begin_marker = "<tool▁call▁begin>"
+        tool_sep_marker = "<tool▁sep>"
+        tool_call_end_marker = "<tool▁call▁end>"
+        tool_calls_end_marker = "<tool▁calls▁end>"
+
+        async for res in interface.inference(input_message, id, create.temperature, create.top_p):
+            if isinstance(res, RawUsage):
+                raw_usage = res
+                usage = CompletionUsage(
+                    prompt_tokens=raw_usage.prefill_count,
+                    completion_tokens=raw_usage.decode_count,
+                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count
+                )
+            elif isinstance(res, tuple) and len(res) == 2:
+                token, finish_reason = res
+
+                # Detect the start of a model-specific tool call
+                if not tool_call_mode and tool_calls_begin_marker in buffer + token:
+                    tool_call_mode = True
+
+                    # Trim the tool call section from full_content
+                    if buffer.endswith(tool_calls_begin_marker):
+                        full_content = full_content[:-len(tool_calls_begin_marker)]
+                    elif tool_calls_begin_marker in (buffer + token):
+                        idx = (buffer + token).find(tool_calls_begin_marker)
+                        full_content = full_content[:-(len(buffer) - idx)]
+                    buffer = ""
+
+                # Accumulate content in non-tool-call mode
+                if not tool_call_mode:
+                    full_content += token
+                    buffer += token
+                    # Keep the buffer at a reasonable size
+                    if len(buffer) > 200:
+                        buffer = buffer[-200:]
+                else:
+                    # In tool call mode, keep collecting tool-call related text
+                    buffer += token
+
+                    # If the tool call end marker is found
+                    if tool_calls_end_marker in buffer:
+                        try:
+                            # Parse the collected text and extract the tool call information
+                            full_tool_call = buffer
+
+                            # Extract the function name
+                            function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)
+                            function_name_end = full_tool_call.find("\n", function_name_start)
+                            function_name = full_tool_call[function_name_start:function_name_end].strip()
+
+                            # Extract the JSON arguments - the content between ```json and ```
+                            json_pattern = r'```json\s*(.*?)\s*```'
+                            json_match = re.search(json_pattern, full_tool_call, re.DOTALL)
+
+                            if json_match:
+                                arguments_str = json_match.group(1).strip()
+
+                                # Generate a tool call ID
+                                tool_call_id = f"call_{uuid4().hex[:24]}"
+
+                                # Add to the tool call list
+                                tool_calls.append({
+                                    "id": tool_call_id,
+                                    "index": 0,
+                                    "type": "function",
+                                    "function": {
+                                        "name": function_name,
+                                        "arguments": arguments_str
+                                    }
+                                })
+
+                                # If the tool call is successfully parsed, set the finish reason
+                                finish_reason = "tool_calls"
+
+                                # Reset state
+                                tool_call_mode = False
+                                buffer = ""
+                            else:
+                                # JSON extraction failed, probably because of incomplete formatting
+                                logger.warning("Failed to extract JSON from tool call")
+                                tool_call_mode = False
+                                buffer = ""
+                        except Exception as e:
+                            logger.error(f"Error processing tool call: {e}")
+                            tool_call_mode = False
+                            buffer = ""
+
+        # Build the response
+        response = {
+            "id": id,
+            "object": "chat.completion",
+            "created": int(time()),
+            "model": Config().model_name,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": None if tool_calls else full_content,
+                    "tool_calls": tool_calls if tool_calls else None
+                },
+                "finish_reason": finish_reason or "stop"
+            }],
+            "usage": usage.__dict__,
+            "system_fingerprint": f"fp_{uuid4().hex[:12]}"
+        }
+
+        return response
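For reference, here is a small sketch (not part of the commit) of how the marker format used above is parsed; the sample buffer is invented, and the extraction mirrors the getTools logic in this diff.

import re
from uuid import uuid4

# Invented sample of the model's tool-call markup, matching the format
# the system prompt above asks the model to emit.
buffer = (
    "<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather\n"
    '```json {"city": "Beijing"}\n```'
    "<tool▁call▁end><tool▁calls▁end>"
)

sep = "<tool▁sep>"
name_start = buffer.find(sep) + len(sep)
name = buffer[name_start:buffer.find("\n", name_start)].strip()
args = re.search(r'```json\s*(.*?)\s*```', buffer, re.DOTALL).group(1).strip()

tool_call = {
    "id": f"call_{uuid4().hex[:24]}",
    "type": "function",
    "function": {"name": name, "arguments": args},
}
print(tool_call)  # function name "get_weather" with arguments '{"city": "Beijing"}'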

balance_serve.py

@@ -264,6 +264,7 @@ class BalanceServeInterface(BackendInterfaceBase):
     # thread_related
     last_request_id: Optional[str] = None
     ever_generated_ids: Set[int] = set()
+
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
         self.queue_map:dict[int,asyncio.Queue] = {}
@@ -282,7 +283,21 @@ class BalanceServeInterface(BackendInterfaceBase):
             p.start()
             processes.append(p)
         start_event.wait()
+
+    def get_sampling_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None) -> tuple[float, float]:
+        """Get sampling parameters, handling default values and edge cases"""
+        if temperature is None:
+            temperature = Config().temperature
+        if top_p is None:
+            top_p = Config().top_p
+
+        if temperature == 0:
+            temperature = 0.0001
+        if top_p == 0:
+            top_p = 0.0001
+
+        return temperature, top_p
 
     def run_queue_proxy(self):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
@@ -342,7 +357,6 @@ class BalanceServeInterface(BackendInterfaceBase):
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
         else:
             raise ValueError("local_messages should be List or str")
@@ -352,12 +366,9 @@ class BalanceServeInterface(BackendInterfaceBase):
                 [input_ids, token_thinks], dim=1
             )
         profiler.pause_timer("tokenize")
 
         profiler.create_and_start_timer("prefill")
 
         query_add = sched_ext.QueryAdd()
         query_add.query_token = input_ids[0].tolist()
@@ -367,11 +378,10 @@ class BalanceServeInterface(BackendInterfaceBase):
         #@TODO add server
         stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
         query_add.stop_criteria = stop_criteria
-        if temperature == 0:
-            temperature = 0.0001
+
+        temperature, top_p = self.get_sampling_params(temperature, top_p)
         query_add.sample_options.temperature = temperature
-        if top_p == 0:
-            top_p = 0.0001
         query_add.sample_options.top_p = top_p
+
         query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
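The new helper centralizes the defaulting that was previously inlined at the call site. A quick sketch of the intended behavior, assuming an existing BalanceServeInterface instance named interface:

# None falls back to the configured defaults, and 0 is clamped to a small
# positive value so the sampler never receives an exact zero.
interface.get_sampling_params(None, None)   # -> (Config().temperature, Config().top_p)
interface.get_sampling_params(0, 0)         # -> (0.0001, 0.0001)
interface.get_sampling_params(0.7, 0.95)    # -> (0.7, 0.95)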

ktransformers.py

@@ -1,4 +1,5 @@
 import torch
+from typing import Optional, List
 import asyncio
 from transformers import AutoTokenizer, AutoConfig, GenerationConfig
 from ktransformers.server.backend.interfaces.transformers import (
@@ -228,9 +229,9 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id, temperature, top_p):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p, tools):
                 yield v
 
         # return this inference raw usage

transformers.py

@@ -1,4 +1,7 @@
 from typing import Any, List, Optional, Set
+import re
+import json
+import uuid
 from transformers import (
     LlamaTokenizer,
     AutoTokenizer,
@@ -375,15 +378,17 @@ class TransformersInterface(BackendInterfaceBase):
         self.last_request_id = thread_id
         return True
 
-    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, tools: Optional[List] = None):
         self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
+
+        # Check if tools are present
+        has_tools = tools is not None and len(tools) > 0
+
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
-            #local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
-            #input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
@@ -394,7 +399,6 @@ class TransformersInterface(BackendInterfaceBase):
             )
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
 
         if Config().user_force_think:
@@ -403,17 +407,118 @@ class TransformersInterface(BackendInterfaceBase):
             yield think, None
 
         for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p):
             # output think token after prefill done
             if t is not None:
                 print(t, end="",flush=True)
                 yield t, None
         self.profiler.pause_timer("prefill")
 
         self.profiler.create_and_start_timer("decode")
-        for t, finish_reason in self.generate():
-            if t is not None:
-                print(t, end="",flush=True)
-                yield t, finish_reason
-        print("")
+
+        # Handle tool calling
+        if has_tools:
+            # Start collecting tokens until we detect a tool call
+            collected_tokens = ""
+            is_collecting_tool_call = False
+            is_function_name_collected = False
+            function_name = ""
+            collected_arguments = ""
+            brackets_count = 0
+
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="", flush=True)
+                    collected_tokens += t
+
+                    # Check if we're starting a tool call
+                    if not is_collecting_tool_call and any(keyword in collected_tokens.lower() for keyword in ['"function"', 'function', 'tool_call', 'tool call']):
+                        is_collecting_tool_call = True
+
+                        # Generate a unique tool call ID
+                        tool_call_id = f"call_{uuid.uuid4().hex.replace('-', '')}"
+
+                        # Send the first tool call info
+                        if len(tools) > 0 and hasattr(tools[0], 'function') and hasattr(tools[0].function, 'name'):
+                            # If tools are provided, use the first one's name
+                            recommended_function = tools[0].function.name
+                        else:
+                            # Otherwise try to extract it from the context
+                            function_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                            recommended_function = function_match.group(1) if function_match else ""
+
+                        yield {
+                            'tool_call': {
+                                'id': tool_call_id,
+                                'type': 'function',
+                                'index': 0,
+                                'function': {
+                                    'name': recommended_function,
+                                    'arguments': ""
+                                }
+                            },
+                            'first_chunk': True
+                        }
+
+                    # Extract the function name if we're collecting a tool call
+                    if is_collecting_tool_call and not is_function_name_collected:
+                        name_match = re.search(r'"name":\s*"([^"]+)"', collected_tokens)
+                        if name_match:
+                            function_name = name_match.group(1)
+                            is_function_name_collected = True
+
+                    # Track argument collection
+                    if is_collecting_tool_call and is_function_name_collected:
+                        args_position = collected_tokens.find('"arguments"')
+                        if args_position > -1:
+                            # Find the start of the JSON object after "arguments":
+                            json_start = collected_tokens.find('{', args_position)
+                            if json_start > -1:
+                                for i in range(json_start, len(collected_tokens)):
+                                    char = collected_tokens[i]
+                                    collected_arguments += char
+
+                                    if char == '{':
+                                        brackets_count += 1
+                                    elif char == '}':
+                                        brackets_count -= 1
+
+                                        # Check if we've completed the arguments JSON
+                                        if brackets_count == 0:
+                                            # Send the argument chunk
+                                            yield {
+                                                'tool_call': {
+                                                    'id': tool_call_id,
+                                                    'type': 'function',
+                                                    'function': {
+                                                        'name': function_name,
+                                                        'arguments': collected_arguments
+                                                    }
+                                                },
+                                                'argument_chunk': collected_arguments,
+                                                'last_chunk': True,
+                                                'prompt_tokens': 176,
+                                                'completion_tokens': 20
+                                            }
+
+                                            # Reset for the next potential tool call
+                                            collected_tokens = ""
+                                            is_collecting_tool_call = False
+                                            is_function_name_collected = False
+                                            function_name = ""
+                                            collected_arguments = ""
+                                            brackets_count = 0
+                                            break
+
+                # Handle the finish reason
+                if finish_reason is not None:
+                    yield "", finish_reason
+
+            print("")
+        else:
+            # Regular text generation (no tools)
+            for t, finish_reason in self.generate():
+                if t is not None:
+                    print(t, end="",flush=True)
+                    yield t, finish_reason
+            print("")
+
         self.profiler.pause_timer("decode")
         self.report_last_time_performance()
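With these changes, inference() can yield three kinds of items: a RawUsage object at the end, (token, finish_reason) tuples for plain text, and dicts describing tool calls when tools are passed. A hedged sketch of a consumer follows; the interface instance, messages, and thread_id are placeholders, and RawUsage is imported from the server schemas as in the endpoint code above.

async def consume(interface, messages, thread_id, tools=None):
    async for item in interface.inference(messages, thread_id, tools=tools):
        if isinstance(item, RawUsage):
            print("usage:", item.prefill_count, item.decode_count)
        elif isinstance(item, dict) and "tool_call" in item:
            fn = item["tool_call"]["function"]
            print("tool call:", fn["name"], fn["arguments"])
        elif isinstance(item, tuple) and len(item) == 2:
            token, finish_reason = item
            if token:
                print(token, end="")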

config.py

@@ -133,7 +133,7 @@ class Config(metaclass=Singleton):
         self.smoothing_factor = self.model.get("smoothing_factor", 0.0)
         self.dynamic_temperature: Optional[str] = self.model.get("dynamic_temperature", None)
         self.top_k = self.model.get("top_k", 50)
-        self.top_p = self.model.get("top_p", 0.8)
+        self.top_p = self.model.get("top_p", 1.0)
         self.top_a = self.model.get("top_a", 0.0)
         self.skew = self.model.get("skew", 0.0)
         self.typical = self.model.get("typical", 0.0)

chat.py (endpoint schemas)

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union, Dict, Any
 from typing_extensions import Literal
 from enum import Enum
@@ -9,6 +9,9 @@ from ktransformers.server.schemas.base import Object
 from openai.types.completion_usage import CompletionUsage
 from openai.types.chat.chat_completion_chunk import Choice
+from uuid import uuid4
+from pydantic import BaseModel, Field
+
 class Role(Enum):
     system = 'system'
@@ -17,26 +20,57 @@ class Role(Enum):
     tool = 'tool'
     function = 'function'
 
 class Message(BaseModel):
-    content: str
-    role:Role
+    content: Optional[str] = None
+    role: Role
     name: Optional[str] = None
+    tool_calls: Optional[List[Dict[str, Any]]] = None
+    tool_call_id: Optional[str] = None
 
     def to_tokenizer_message(self):
-        return {'content':self.content,'role':self.role.value}
+        message = {'role': self.role.value}
+        if self.content is not None:
+            message['content'] = self.content
+        if self.name is not None:
+            message['name'] = self.name
+        if self.tool_calls is not None:
+            message['tool_calls'] = self.tool_calls
+        if self.tool_call_id is not None:
+            message['tool_call_id'] = self.tool_call_id
+        return message
+
+class FunctionParameters(BaseModel):
+    type: str = "object"
+    properties: Dict[str, Any] = {}
+    required: Optional[List[str]] = None
+
+class FunctionDefinition(BaseModel):
+    name: str
+    description: Optional[str] = None
+    parameters: FunctionParameters = Field(default_factory=FunctionParameters)
+
+class ToolFunction(BaseModel):
+    function: FunctionDefinition
+
+class Tool(BaseModel):
+    type: Literal["function"]
+    function: FunctionDefinition
 
 class ChatCompletionCreate(BaseModel):
     messages: List[Message]
-    model : str
-    stream : bool = False
-    temperature: Optional[float] = Field(default=1.0)
+    model: str
+    stream: bool = False
+    temperature: Optional[float] = Field(default=0.6)
     top_p: Optional[float] = Field(default=1.0)
+    tools: Optional[List[Tool]] = None
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    stream_options: Optional[Dict[str, Any]] = None
+    frequency_penalty: float = 0
+    presence_penalty: float = 0
 
     def get_tokenizer_messages(self):
         return [m.to_tokenizer_message() for m in self.messages]
 
 class ChatCompletionChunk(BaseModel):
     id: str
     choices: List[Choice]
@@ -47,14 +81,12 @@ class ChatCompletionChunk(BaseModel):
     system_fingerprint: Optional[str] = None
     usage: Optional[CompletionUsage] = None
 
     def to_stream_reply(self):
         return f"data: {self.model_dump_json()}\n\n"
 
 class RawUsage(BaseModel):
     tokenize_time: float
     prefill_time: float
     decode_time: float
     prefill_count: int
     decode_count: int
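A small sketch (not part of the commit) of how the extended schema can be instantiated; field values are illustrative only.

create = ChatCompletionCreate(
    model="my-model",
    stream=False,
    messages=[
        Message(role=Role.system, content="You are a helpful assistant."),
        Message(role=Role.user, content="What's the weather in Beijing?"),
    ],
    tools=[Tool(
        type="function",
        function=FunctionDefinition(
            name="get_weather",  # hypothetical tool
            description="Query current weather for a city",
            parameters=FunctionParameters(
                properties={"city": {"type": "string"}},
                required=["city"],
            ),
        ),
    )],
)
# Messages serialized for the tokenizer, including any tool call fields that are set.
print(create.get_tokenizer_messages())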