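"""Pydantic schemas for assistant runs.

These models describe the run object, its status lifecycle, and the
request/response bodies used by the assistants server; the field layout
(status values, required_action, truncation_strategy, tool_choice, ...)
closely follows the OpenAI Assistants API run resource.
"""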
from enum import Enum
from typing import Dict, List, Optional, Union, ForwardRef

from pydantic import BaseModel, Field, model_validator

from ktransformers.server.models.assistants.runs import Run
from ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.threads import ThreadCreate
from ktransformers.server.schemas.assistants.tool import Tool, ToolResource
from ktransformers.server.utils.sql_utils import SQLUtil


class ToolCall(BaseModel):
    id: str
    type: str
    function: TODO


class SubmitToolOutputs(BaseModel):
    tool_calls: List[ToolCall]


class RequiredAction(BaseModel):
    type: str
    submit_tool_outputs: TODO


class LastError(BaseModel):
    code: str
    message: str


class IncompleteDetails(BaseModel):
    reason: str


class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


class TruncationStrategy(BaseModel):
    type: str = "auto"
    last_message: Optional[int]


class ToolChoiceType(Enum):
    none = "none"
    auto = "auto"
    required = "required"


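# Base fields for a run; RunObject below extends this with the object type
# and creation-time metadata.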
class RunBase(BaseModel):
    class Status(Enum):
        created = "created"  # only stream events use this status
        queued = "queued"
        in_progress = "in_progress"
        requires_action = "requires_action"
        cancelling = "cancelling"
        cancelled = "cancelled"
        failed = "failed"
        completed = "completed"
        expired = "expired"

    thread_id: str
    assistant_id: str
    status: Status = Status.queued
    required_action: Optional[RequiredAction] = Field(None)
    last_error: Optional[LastError] = Field(None)
    expires_at: Optional[int] = Field(None)
    started_at: Optional[int] = Field(None)
    cancelled_at: Optional[int] = Field(None)
    failed_at: Optional[int] = Field(None)
    completed_at: Optional[int] = Field(None)
    incomplete_details: Optional[IncompleteDetails] = Field(None)
    model: Optional[str] = Field(None)
    instructions: Optional[str] = Field(None)
    tools: Optional[List[Tool]] = Field([])
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # Mirror an incoming `meta_data` value under the `metadata` key before validation.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    def set_compute_save(self, save: int):
        # Record the compute-save value in the run's metadata as a string.
        self.meta_data['compute_save'] = str(save)

    usage: Optional[Usage] = Field(None)
    temperature: Optional[float] = Field(None)
    top_p: Optional[float] = Field(None)
    max_propmp_tokens: Optional[int] = Field(None)
    truncation_strategy: Optional[TruncationStrategy] = Field(None)
    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(None)
    response_format: Union[str, Dict[str, str]] = "auto"


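# RunStreamResponse is referenced by RunObject before it is defined below,
# so declare a forward reference for the annotation.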
RunStreamResponse = ForwardRef('RunStreamResponse')


class RunObject(RunBase, ObjectWithCreatedTime):
    def stream_response_with_event(self, event: RunBase.Status) -> RunStreamResponse:
        # A `created` event keeps the stored status at `queued`; every other
        # event also updates the run's status before it is streamed out.
        match event:
            case RunBase.Status.created:
                self.status = RunBase.Status.queued
            case _:
                self.status = event
        return RunStreamResponse(run=self, event=event)

    def sync_db(self):
        # raise NotImplementedError  # should be replaced in crud
        # Serialize the run and merge it into the database via SQLUtil.
        sql_utils = SQLUtil()
        db_run = Run(
            **self.model_dump(mode='json'),
        )
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_run)

    def create_message_creation_step(self):
        raise NotImplementedError  # should be replaced


class RunStreamResponse(BaseModel):
    run: RunObject
    event: RunObject.Status

    def to_stream_reply(self):
        # Format the run as a server-sent events frame, e.g.
        # "event: thread.run.completed\ndata: {...}\n\n".
        return f"event: thread.run.{self.event.value}\ndata: {self.run.model_dump_json()}\n\n"


class RunCreate(BaseModel):
    assistant_id: str
    model: Optional[str] = Field(default=None)
    instructions: Optional[str] = Field(default=None)
    # TODO: Add this
    # additional_instructions: Optional[str]
    # additional_messages: Optional[List[MessageCore]]
    tools: List[Tool] = Field(default=[])
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=None)
    max_propmp_tokens: Optional[int] = Field(default=None)
    # TODO: Add this
    # max_completion_tokens: Optional[int]
    truncation_strategy: Optional[TruncationStrategy] = Field(default=None)
    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None)
    response_format: Union[str, Dict[str, str]] = Field(default="auto")


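# Request body that bundles thread creation with run parameters.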
class RunThreadCreate(BaseModel):
    assistant_id: str
    thread: Optional[ThreadCreate]
    model: Optional[str]
    instructions: Optional[str]
    tools: List[Tool]
    tool_resources: List[ToolResource]
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    temperature: Optional[float]
    top_p: Optional[float]
    stream: Optional[bool]
    max_propmp_tokens: Optional[int]
    # TODO: Add this
    # max_completion_tokens: Optional[int]
    truncation_strategy: TruncationStrategy
    tool_choice: Union[ToolChoiceType, dict]
    response_format: Union[str, Dict[str, str]] = "auto"


class RunModify(BaseModel):
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values


class ToolOutput(BaseModel):
    tool_call_id: Optional[str]
    output: Optional[str]


class RunSubmit(BaseModel):
    tool_outputs: List[ToolOutput]
    stream: Optional[bool]