# Mirror of https://github.com/kvcache-ai/ktransformers.git
# Synced 2025-09-14 09:09:42 +00:00
# 202 lines, 8.3 KiB, Python
from enum import Enum
|
|
from time import time
|
|
from typing import AsyncIterable, Callable, Dict, List, Optional, Union
|
|
from asyncio import Lock, Queue
|
|
|
|
from fastapi import logger
|
|
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
|
|
import torch
|
|
|
|
from ktransformers.server.config.config import Config
|
|
from ktransformers.server.models.assistants.assistants import Assistant
|
|
from ktransformers.server.models.assistants.threads import Thread
|
|
from ktransformers.server.schemas.assistants.messages import Role
|
|
from ktransformers.server.schemas.assistants.runs import RunObject,RunStreamResponse,ObjectWithCreatedTime
|
|
from ktransformers.server.schemas.assistants.threads import ThreadObject
|
|
from ktransformers.server.schemas.base import Metadata,MetadataField,ObjectID
|
|
from ktransformers.server.schemas.assistants.tool import Tool,CodeInterpreter,FileSearch,RelatedThreads,FuntionTool,ToolResource,CodeInterpreterResource,FileSearchResource,RelatedThreadsResource,ToolType
|
|
from ktransformers.server.utils.sql_utils import SQLUtil
|
|
|
|
|
|
class AssistantBase(BaseModel):
    """Fields shared by ``AssistantCreate``, ``AssistantModify`` and ``AssistantObject``.

    ``tools`` and ``tool_resources`` are accepted as lists of plain dicts; the
    ``before`` validators below turn every dict into the matching schema
    object before pydantic validates the field itself.
    """

    name: Optional[str] = Field(None, description='The name of the assistant.')
    description: Optional[str] = Field(None, description='The description of the assistant.')
    instructions: Optional[str] = Field(None, description='Instructions which is added in front of the input of LLM')
    tools: List[Tool] = Field([], max_length=128)
    tool_resources: List[ToolResource] = Field([], max_length=128)
    meta_data: Metadata = MetadataField
    temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1)
    top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1)
    response_format: Union[str, Dict[str, str]] = "auto"

    @field_validator('tools', mode='before')
    @classmethod
    def validate_tools(cls, value):
        """Build the concrete tool object for every entry of ``tools``.

        Each entry must be a dict whose ``type`` equals ``code_interpreter``,
        ``file_search``, ``related_threads`` or ``function``.

        Raises:
            ValueError: if ``value`` is not a list, an entry is not a dict,
                an entry has no ``type`` key, or its type is not one of the
                four listed above.
        """
        if not isinstance(value, list):
            raise ValueError('Invalid type for tools')

        # Checked with ``==`` in this order, i.e. the same semantics as an
        # if/elif chain (no hashing of the user-supplied ``type`` value).
        tool_classes = (
            ('code_interpreter', CodeInterpreter),
            ('file_search', FileSearch),
            ('related_threads', RelatedThreads),
            ('function', FuntionTool),
        )

        parsed = []
        for tool in value:
            # A non-dict entry would make the membership test, the subscript
            # or the ``**`` unpacking raise TypeError, which pydantic v2 does
            # not report as a validation error; reject it explicitly instead.
            if not isinstance(tool, dict) or 'type' not in tool:
                raise ValueError('Invalid type for tools')
            for type_name, tool_cls in tool_classes:
                if tool['type'] == type_name:
                    parsed.append(tool_cls(**tool))
                    break
            else:
                raise ValueError('Invalid type for tools')
        return parsed

    @field_validator('tool_resources', mode='before')
    @classmethod
    def validate_tool_resources(cls, value):
        """Build the concrete resource object for every entry of ``tool_resources``.

        The resource kind is chosen by the first of the keys ``file_ids``,
        ``vector_stores`` and ``thread_ids`` that is present in the entry.

        Raises:
            ValueError: if ``value`` is not a list, an entry is not a dict,
                or an entry contains none of the three keys.
        """
        if not isinstance(value, list):
            raise ValueError('Invalid type for tool resources')

        # Key precedence matters when an entry carries more than one key.
        resource_classes = (
            ('file_ids', CodeInterpreterResource),
            ('vector_stores', FileSearchResource),
            ('thread_ids', RelatedThreadsResource),
        )

        parsed = []
        for tool_re in value:
            # Same reasoning as in validate_tools: keep failures as ValueError.
            if not isinstance(tool_re, dict):
                raise ValueError('Invalid type for tool resources')
            for key, resource_cls in resource_classes:
                if key in tool_re:
                    parsed.append(resource_cls(**tool_re))
                    break
            else:
                raise ValueError('Invalid type for tool resources')
        return parsed

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Mirror an incoming ``meta_data`` entry under the ``metadata`` key.

        NOTE(review): presumably ``MetadataField`` is populated through a
        ``metadata`` alias - confirm in ``schemas.base``.
        """
        if 'meta_data' in values:
            values['metadata'] = values['meta_data']
        return values
|
|
|
|
|
|
class AssistantCreate(AssistantBase):
    """Payload for creating an assistant: the shared fields plus a mandatory model."""

    # ``...`` marks the field as required (no default value).
    model: str = Field(...)
|
|
|
|
|
|
class AssistantBuildStatus(BaseModel):
    """Progress counters and current stage of an assistant build."""

    class Status(Enum):
        """Stages a build can be in."""

        not_build = "not_build"
        in_queue = "in_queue"
        parsing = "parsing"
        prefilling = "prefilling"
        dumping = "dumping"
        completed = "completed"
        paused = "paused"

    # Private attributes: not part of the dumped/serialized model.
    _lock: Lock = PrivateAttr(default_factory=Lock)
    _queue: Optional[Queue] = PrivateAttr(None)

    status: Status = Status.not_build

    # File parsing progress.
    total_file_count: int = 0
    parsed_file_count: int = 0

    # Prefill progress.
    prefilling_current: int = 0
    prefilling_total: int = 0

    build_started_time: Optional[int] = None
    build_completed_time: Optional[int] = None

    # in megabytes
    assistant_usage: int = Field(default=0, description='')
    assistant_total_usage: int = 0
    disk_free_space: int = 0
    disk_total_space: int = 0

    def to_stream_reply(self) -> str:
        """Format this status as an ``assistant.build.status`` event message.

        Returns:
            An ``event:`` line, a ``data:`` line carrying the JSON dump of
            this model, and a terminating blank line.
        """
        payload = self.model_dump_json()
        return f"event: assistant.build.status\ndata: {payload}\n\n"
|
|
|
|
|
|
class AssistantObject(AssistantBase, ObjectWithCreatedTime):
    """Assistant record with its build status, related threads and DB sync helpers."""

    # The default model name is read from the configuration once, at import time.
    model: Optional[str] = Field(default=Config().model_name)
    # Cache filled by get_related_threads_objects(); excluded from dumps.
    related_threads_objects: Optional[List] = Field(None, exclude=True)
    # Cache filled by get_encoded_instruction().
    _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None)
    # default_factory builds a fresh status per assistant instead of relying
    # on a single instance created while the module is being imported.
    build_status: AssistantBuildStatus = Field(default_factory=AssistantBuildStatus)

    def as_api_response(self):
        """Return the model dump without the ``build_status`` field."""
        return self.model_dump(exclude={'build_status'})

    def get_related_threads_ids(self) -> List[ObjectID]:
        """Collect the thread ids of every ``related_threads`` tool.

        NOTE(review): ``tools`` and ``tool_resources`` are paired by position
        via ``zip``; nothing in this class enforces that they are aligned.
        """
        thread_ids = []
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                thread_ids += tool_re.thread_ids or []
        return thread_ids

    def get_related_threads_objects(self) -> List:
        """Return the related-thread objects of this assistant, cached after first use.

        On the first call every ``Thread`` row is loaded and converted to a
        ``ThreadObject``; those flagged ``is_related_threads`` whose
        ``meta_data['assistant_id']`` equals this assistant's id are kept.
        """
        # raise NotImplementedError # should be replaced
        if self.related_threads_objects is None:
            # Only touch the database when the cache is empty.
            sql_utils = SQLUtil()
            with sql_utils.get_db() as db:
                db_threads = db.query(Thread).all()
                threads = [ThreadObject.model_validate(db_thread.__dict__)
                           for db_thread in db_threads]
                self.related_threads_objects = [
                    thread for thread in threads
                    if thread.is_related_threads
                    and thread.meta_data['assistant_id'] == self.id
                ]
            # logger.debug(
            #     f'Found {len(self.related_threads_objects)} related threads')
        return self.related_threads_objects

    def append_related_threads(self, thread_ids: List[ObjectID]):
        """Attach ``thread_ids`` to this assistant's ``related_threads`` tool.

        The ids are added to the first existing ``related_threads`` tool
        resource; when no such tool exists, a new tool and a matching
        resource are appended.
        """
        # logger.debug(f'{self.tools} {self.tool_resources}')
        for tool, tool_re in zip(self.tools, self.tool_resources):
            if tool.type == ToolType.RELATED_THREADS:
                # get_related_threads_ids() tolerates an unset thread_ids
                # (``or []``); do the same here instead of failing on ``+=``.
                if tool_re.thread_ids is None:
                    tool_re.thread_ids = []
                tool_re.thread_ids += thread_ids
                return

        self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS))
        self.tool_resources.append(
            RelatedThreadsResource(thread_ids=thread_ids))

    async def update_build_status(self, events: AsyncIterable) -> AsyncIterable:
        """Update ``build_status`` from a stream of build events.

        A ``RunStreamResponse`` whose event is ``RunObject.Status.completed``
        marks the build as completed, stamps the completion time and
        persists the assistant. A dict carrying a ``stage`` key updates the
        prefill (``'prefill'``) or parsing (``'parse'``) counters from its
        ``curr_progress`` / ``max_progress`` entries.

        Yields:
            A copy of ``build_status`` after each completed-run event and
            after each dict event that has a ``stage`` key.
        """
        async for event in events:
            # logger.debug(event)
            if isinstance(event, RunStreamResponse):
                if event.event == RunObject.Status.completed:
                    self.build_status.status = AssistantBuildStatus.Status.completed
                    self.build_status.build_completed_time = int(time())
                    # NOTE(review): sync_db() is a plain synchronous call
                    # made from inside this async generator.
                    self.sync_db()
                    yield self.build_status.model_copy()
            elif isinstance(event, dict):
                # logger.debug('dict')
                if 'stage' in event:
                    stage = event['stage']
                    if stage == 'prefill':
                        self.build_status.status = AssistantBuildStatus.Status.prefilling
                        self.build_status.prefilling_current = event['curr_progress']
                        self.build_status.prefilling_total = event['max_progress']
                    elif stage == 'parse':
                        self.build_status.status = AssistantBuildStatus.Status.parsing
                        self.build_status.parsed_file_count = event['curr_progress']
                        self.build_status.total_file_count = event['max_progress']
                    yield self.build_status.model_copy()

    def get_build_status(self) -> AssistantBuildStatus:
        """Return the live ``build_status`` object (not a copy)."""
        return self.build_status

    def sync_db(self) -> None:
        """Persist this assistant by merging its JSON-mode dump into the database."""
        # raise NotImplementedError # should be replaced
        sql_utils = SQLUtil()
        db_assistant = Assistant(**self.model_dump(mode='json'))
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_assistant)

    def get_encoded_instruction(self, encode_fn: Callable) -> torch.Tensor:
        """Return the encoded instructions, encoding them on first use.

        Args:
            encode_fn: called as ``encode_fn(self.instructions, Role.user)``;
                its result is cached on the instance.
        """
        if self._encoded_instruction is None:
            # ``from fastapi import logger`` binds the ``fastapi.logger``
            # module, whose Logger lives in its ``logger`` attribute. The
            # fallback keeps this working if the import is ever changed to
            # bind the Logger itself.
            log = getattr(logger, 'logger', logger)
            log.info(f'encoding assistant instruction: {self.instructions}')
            self._encoded_instruction = encode_fn(self.instructions, Role.user)
        return self._encoded_instruction
|
|
|
|
|
|
class AssistantModify(AssistantBase):
    """Payload for modifying an assistant: the shared fields plus an optional model."""

    # Defaults to None when the request does not supply a model.
    model: Optional[str] = Field(default=None)
|
|
|
|
|
|
# Non API Backend
|