kvcache-ai-ktransformers/ktransformers/server/schemas/assistants/runs.py

from enum import Enum
from typing import Dict, List, Optional, Union, ForwardRef

from pydantic import BaseModel, Field, model_validator

from ktransformers.server.models.assistants.runs import Run
from ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.threads import ThreadCreate
from ktransformers.server.schemas.assistants.tool import Tool, ToolResource
from ktransformers.server.utils.sql_utils import SQLUtil


class ToolCall(BaseModel):
    id: str
    type: str
    function: TODO


class SubmitToolOutputs(BaseModel):
    tool_calls: List[ToolCall]


class RequiredAction(BaseModel):
    type: str
    submit_tool_outputs: TODO


class LastError(BaseModel):
    code: str
    message: str


class IncompleteDetails(BaseModel):
    reason: str


class Usage(BaseModel):
    completion_tokens: int
    prompt_tokens: int
    total_tokens: int


class TruncationStrategy(BaseModel):
    type: str = "auto"
    last_message: Optional[int]


class ToolChoiceType(Enum):
    none = "none"
    auto = "auto"
    required = "required"
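# Sketch (not part of the module): the `tool_choice` fields below accept either one of the
# ToolChoiceType values or a raw dict that names a specific tool, mirroring the OpenAI
# Assistants request shape, e.g.
#   tool_choice="auto"
#   tool_choice={"type": "function", "function": {"name": "get_weather"}}   # hypothetical tool name
# The dict form is passed through untyped here (annotated simply as `dict`).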
class RunBase(BaseModel):
    class Status(Enum):
        created = "created"  # only stream events carry this status
        queued = "queued"
        in_progress = "in_progress"
        requires_action = "requires_action"
        cancelling = "cancelling"
        cancelled = "cancelled"
        failed = "failed"
        completed = "completed"
        expired = "expired"

    thread_id: str
    assistant_id: str
    status: Status = Status.queued
    required_action: Optional[RequiredAction] = Field(None)
    last_error: Optional[LastError] = Field(None)
    expires_at: Optional[int] = Field(None)
    started_at: Optional[int] = Field(None)
    cancelled_at: Optional[int] = Field(None)
    failed_at: Optional[int] = Field(None)
    completed_at: Optional[int] = Field(None)
    incomplete_details: Optional[IncompleteDetails] = Field(None)
    model: Optional[str] = Field(None)
    instructions: Optional[str] = Field(None)
    tools: Optional[List[Tool]] = Field([])
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # Accept payloads that use the "meta_data" key by mirroring it onto "metadata".
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    def set_compute_save(self, save: int):
        self.meta_data['compute_save'] = str(save)

    usage: Optional[Usage] = Field(None)
    temperature: Optional[float] = Field(None)
    top_p: Optional[float] = Field(None)
    max_propmp_tokens: Optional[int] = Field(None)
    truncation_strategy: Optional[TruncationStrategy] = Field(None)
    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(None)
    response_format: Union[str, Dict[str, str]] = "auto"
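# Sketch (assumption: MetadataField maps to the wire-format key "metadata"): the
# convert_meta_data validator above copies an incoming "meta_data" key onto "metadata",
# so payloads using either spelling populate `meta_data`. set_compute_save then stores a
# stringified value in that metadata, e.g.
#   run.set_compute_save(128)   # run.meta_data now contains {"compute_save": "128", ...}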
RunStreamResponse = ForwardRef('RunStreamResponse')


class RunObject(RunBase, ObjectWithCreatedTime):
    def stream_response_with_event(self, event: RunBase.Status) -> RunStreamResponse:
        match event:
            case RunBase.Status.created:
                # "created" is only emitted on the stream; the stored run stays "queued".
                self.status = RunBase.Status.queued
            case _:
                self.status = event
        return RunStreamResponse(run=self, event=event)

    def sync_db(self):
        # raise NotImplementedError  # should be replaced in crud
        sql_utils = SQLUtil()
        db_run = Run(
            **self.model_dump(mode='json'),
        )
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_run)

    def create_message_creation_step(self):
        raise NotImplementedError  # should be replaced
class RunStreamResponse(BaseModel):
    run: RunObject
    event: RunObject.Status

    def to_stream_reply(self):
        return f"event: thread.run.{self.event.value}\ndata: {self.run.model_dump_json()}\n\n"
class RunCreate(BaseModel):
    assistant_id: str
    model: Optional[str] = Field(default=None)
    instructions: Optional[str] = Field(default=None)
    # TODO: Add this
    # additional_instructions: Optional[str]
    # additional_messages: Optional[List[MessageCore]]
    tools: List[Tool] = Field(default=[])
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=None)
    max_propmp_tokens: Optional[int] = Field(default=None)
    # TODO: Add this
    # max_completion_tokens: Optional[int]
    truncation_strategy: Optional[TruncationStrategy] = Field(default=None)
    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None)
    response_format: Union[str, Dict[str, str]] = Field(default="auto")
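# Sketch of a request body RunCreate would accept (identifiers and model name illustrative):
#
#   RunCreate.model_validate({
#       "assistant_id": "asst_abc123",
#       "instructions": "Answer briefly.",
#       "stream": True,
#   })
#
# Optional fields left unset default to None, except tools ([]) and response_format ("auto").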
class RunThreadCreate(BaseModel):
    assistant_id: str
    thread: Optional[ThreadCreate]
    model: Optional[str]
    instructions: Optional[str]
    tools: List[Tool]
    tool_resources: List[ToolResource]
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    temperature: Optional[float]
    top_p: Optional[float]
    stream: Optional[bool]
    max_propmp_tokens: Optional[int]
    # TODO: Add this
    # max_completion_tokens: Optional[int]
    truncation_strategy: TruncationStrategy
    tool_choice: Union[ToolChoiceType, dict]
    response_format: Union[str, Dict[str, str]] = "auto"
class RunModify(BaseModel):
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values


class ToolOutput(BaseModel):
    tool_call_id: Optional[str]
    output: Optional[str]


class RunSubmit(BaseModel):
    tool_outputs: List[ToolOutput]
    stream: Optional[bool]
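# Sketch (illustrative values): the body a client sends back while a run is in the
# `requires_action` state, parsed by RunSubmit:
#
#   RunSubmit.model_validate({
#       "tool_outputs": [
#           {"tool_call_id": "call_abc123", "output": "{\"temperature\": 21}"},
#       ],
#       "stream": True,
#   })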