kvcache-ai-ktransformers/ktransformers/server/schemas/assistants/assistants.py

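"""Pydantic schemas for the assistants API: assistant create/modify requests,
the stored assistant object, and its build-status tracking."""
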
from enum import Enum
from time import time
from typing import AsyncIterable, Callable, Dict, List, Optional, Union
from asyncio import Lock, Queue
from fastapi.logger import logger
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
import torch
from ktransformers.server.config.config import Config
from ktransformers.server.models.assistants.assistants import Assistant
from ktransformers.server.models.assistants.threads import Thread
from ktransformers.server.schemas.assistants.messages import Role
from ktransformers.server.schemas.assistants.runs import RunObject, RunStreamResponse, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectID
from ktransformers.server.schemas.assistants.tool import (
    Tool,
    CodeInterpreter,
    FileSearch,
    RelatedThreads,
    FuntionTool,
    ToolResource,
    CodeInterpreterResource,
    FileSearchResource,
    RelatedThreadsResource,
    ToolType,
)
from ktransformers.server.utils.sql_utils import SQLUtil


class AssistantBase(BaseModel):
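    """Fields shared by the assistant create, modify, and stored-object schemas."""
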
    name: Optional[str] = Field(None, description='The name of the assistant.')
    description: Optional[str] = Field(None, description='The description of the assistant.')
    instructions: Optional[str] = Field(None, description='Instructions prepended to the LLM input.')

    tools: List[Tool] = Field([], max_length=128)

    @field_validator('tools', mode='before')
    @classmethod
    def validate_tools(cls, value):
        if not isinstance(value, list):
            raise ValueError('Invalid type for tools')
        parsed_tools = []
        for tool in value:
            if 'type' not in tool:
                raise ValueError('Invalid type for tools')
            if tool['type'] == 'code_interpreter':
                parsed_tools.append(CodeInterpreter(**tool))
            elif tool['type'] == 'file_search':
                parsed_tools.append(FileSearch(**tool))
            elif tool['type'] == 'related_threads':
                parsed_tools.append(RelatedThreads(**tool))
            elif tool['type'] == 'function':
                parsed_tools.append(FuntionTool(**tool))
            else:
                raise ValueError('Invalid type for tools')
        return parsed_tools

    tool_resources: List[ToolResource] = Field([], max_length=128)

    @field_validator('tool_resources', mode='before')
    @classmethod
    def validate_tool_resources(cls, value):
        if not isinstance(value, list):
            raise ValueError('Invalid type for tool resources')
        parsed_resources = []
        for tool_re in value:
            # The concrete resource type is inferred from its distinguishing key.
            if 'file_ids' in tool_re:
                parsed_resources.append(CodeInterpreterResource(**tool_re))
            elif 'vector_stores' in tool_re:
                parsed_resources.append(FileSearchResource(**tool_re))
            elif 'thread_ids' in tool_re:
                parsed_resources.append(RelatedThreadsResource(**tool_re))
            else:
                raise ValueError('Invalid type for tool resources')
        return parsed_resources

    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        # Accept input that carries `meta_data` (e.g. DB rows) by copying it to
        # the `metadata` alias expected during validation.
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values

    temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1.0)
    top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1.0)
    response_format: Union[str, Dict[str, str]] = "auto"


class AssistantCreate(AssistantBase):
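    """Request body for creating an assistant; the model name is required."""
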
    model: str


class AssistantBuildStatus(BaseModel):
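    """Progress of an assistant build (parsing, prefilling, dumping) plus disk-usage
    figures, streamable to clients as server-sent events."""
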
    class Status(Enum):
        not_build = "not_build"
        in_queue = "in_queue"
        parsing = "parsing"
        prefilling = "prefilling"
        dumping = "dumping"
        completed = "completed"
        paused = "paused"

    _lock: Lock = PrivateAttr(default_factory=Lock)
    _queue: Optional[Queue] = PrivateAttr(None)

    status: Status = Field(default=Status.not_build)

    total_file_count: int = Field(default=0)
    parsed_file_count: int = Field(default=0)
    prefilling_current: int = Field(default=0)
    prefilling_total: int = Field(default=0)

    build_started_time: Optional[int] = Field(default=None)
    build_completed_time: Optional[int] = Field(default=None)

    # in megabytes
    assistant_usage: int = Field(default=0, description='in megabytes')
    assistant_total_usage: int = Field(default=0)
    disk_free_space: int = Field(default=0)
    disk_total_space: int = Field(default=0)

    def to_stream_reply(self) -> str:
        # Serialize the current status as a server-sent event frame.
        return f"event: assistant.build.status\ndata: {self.model_dump_json()}\n\n"


class AssistantObject(AssistantBase, ObjectWithCreatedTime):
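    """An assistant as stored by the server: base fields plus the model name, build
    status, cached encoded instructions, and helpers for related threads and DB sync."""
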
    model: Optional[str] = Field(default=Config().model_name)
    related_threads_objects: Optional[List] = Field(None, exclude=True)
    _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None)
    build_status: AssistantBuildStatus = Field(default_factory=AssistantBuildStatus)

    def as_api_response(self):
        # The build status is internal bookkeeping and is excluded from API responses.
        return self.model_dump(exclude={'build_status'})

    def get_related_threads_ids(self) -> List[ObjectID]:
        ids = []
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                ids += tool_re.thread_ids or []
        return ids

    def get_related_threads_objects(self) -> List:
        # raise NotImplementedError # should be replaced
        sql_utils = SQLUtil()
        if self.related_threads_objects is None:
            with sql_utils.get_db() as db:
                db_threads = db.query(Thread).all()
                threads = [ThreadObject.model_validate(t.__dict__) for t in db_threads]
                self.related_threads_objects = [
                    thread for thread in threads
                    if thread.is_related_threads and thread.meta_data['assistant_id'] == self.id
                ]
        # logger.debug(f'Found {len(self.related_threads_objects)} related threads')
        return self.related_threads_objects

    def append_related_threads(self, thread_ids: List[ObjectID]):
        # logger.debug(f'{self.tools} {self.tool_resources}')
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                tool_re.thread_ids += thread_ids
                return
        # No related-threads tool registered yet: add the tool and its resource.
        self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS))
        self.tool_resources.append(RelatedThreadsResource(thread_ids=thread_ids))

    async def update_build_status(self, events: AsyncIterable) -> AsyncIterable:
        async for event in events:
            # logger.debug(event)
            if isinstance(event, RunStreamResponse):
                if event.event == RunObject.Status.completed:
                    self.build_status.status = AssistantBuildStatus.Status.completed
                    self.build_status.build_completed_time = int(time())
                    self.sync_db()
                    yield self.build_status.model_copy()
            elif isinstance(event, dict):
                # logger.debug('dict')
                if 'stage' in event:
                    if event['stage'] == 'prefill':
                        self.build_status.status = AssistantBuildStatus.Status.prefilling
                        self.build_status.prefilling_current = event['curr_progress']
                        self.build_status.prefilling_total = event['max_progress']
                    if event['stage'] == 'parse':
                        self.build_status.status = AssistantBuildStatus.Status.parsing
                        self.build_status.parsed_file_count = event['curr_progress']
                        self.build_status.total_file_count = event['max_progress']
                    yield self.build_status.model_copy()

    def get_build_status(self) -> AssistantBuildStatus:
        return self.build_status

    def sync_db(self) -> None:
        # raise NotImplementedError # should be replaced
        sql_utils = SQLUtil()
        db_assistant = Assistant(
            **self.model_dump(mode='json'),
        )
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_assistant)

    def get_encoded_instruction(self, encode_fn: Callable) -> torch.Tensor:
        # Encode the instruction once and cache the resulting tensor.
        if self._encoded_instruction is None:
            logger.info(f'encoding assistant instruction: {self.instructions}')
            self._encoded_instruction = encode_fn(self.instructions, Role.user)
        return self._encoded_instruction


class AssistantModify(AssistantBase):
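    """Request body for modifying an assistant; every field, including the model, is optional."""
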
    model: Optional[str] = None


# Non API Backend