kvcache-ai-ktransformers/ktransformers/server/schemas/assistants/messages.py
2024-07-27 16:06:58 +08:00

213 lines
6 KiB
Python

from enum import Enum
from typing import ForwardRef, List, Optional, Union,Callable
import torch
from pydantic import BaseModel, PrivateAttr, model_validator
from ktransformers.server.exceptions import not_implemented
from ktransformers.server.config.log import logger
from ktransformers.server.models.assistants.messages import Message
from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch
from ktransformers.server.utils.sql_utils import SQLUtil
class IncompleteDetails(BaseModel):
    """Extra detail attached to a message that did not complete."""

    reason: str  # free-form explanation string
class ContentType(Enum):
    """Kinds of content part a message can carry."""

    # Member values double as the serialized ``type`` string.
    image_file = "image_file"
    image_url = "image_url"
    text = "text"
class ContentObject(BaseModel):
    """Common base for every content part; records which kind it is."""

    type: ContentType  # tag identifying the concrete content kind
class ImageFile(BaseModel):
    """Reference to an image by file id."""

    file_id: str
    detail: str  # required; accepted values are not validated here
class ImageFileObject(ContentObject):
    """Content part wrapping an :class:`ImageFile` reference."""

    image_file: ImageFile
class ImageUrl(BaseModel):
    """Reference to an image by URL."""

    url: str
    detail: str  # required; accepted values are not validated here
class ImageUrlObject(ContentObject):
    """Content part wrapping an :class:`ImageUrl` reference."""

    image_url: ImageUrl
class Annotation(BaseModel):
    """Annotation on a text content part.

    NOTE(review): the single ``todo`` field looks like a placeholder schema
    that has not been filled in yet.
    """

    todo: str
class Text(BaseModel):
    """Text payload of a content part plus its annotations."""

    value: str
    # NOTE(review): assuming ``Field`` is pydantic's, the ``[]`` default is
    # copied per instance, so instances do not share one list.
    annotations: list[Annotation] = Field(default=[])
class TextObject(ContentObject):
    """Text content part; ``filter_append`` grows it in place."""

    text: Text
    # Bookkeeping fields below use exclude=True, so they stay out of dumps.
    delta_index: int = Field(default=0, exclude=True)  # bumped once per filter_append call
    special_tokens_on: bool = Field(default=False, exclude=True)
    last_two: str = Field(default='', exclude=True)

    def filter_append(self, text: str):
        """Append ``text`` to ``self.text.value`` and bump ``delta_index``.

        Always returns True. NOTE(review): despite the name, nothing is
        filtered here and ``special_tokens_on`` / ``last_two`` are not read.
        """
        self.text.value = self.text.value + text
        self.delta_index = self.delta_index + 1
        return True
# Any single content part of a message: an image file, an image URL, or text.
Content = Union[
    ImageFileObject,
    ImageUrlObject,
    TextObject,
]
class Attachment(BaseModel):
    """A file attached to a message, optionally scoped to specific tools."""

    file_id: str | None = Field(default=None)
    tools: list[Union[CodeInterpreter, FileSearch]] | None = Field(default=None)
class Role(Enum):
    """Author of a message."""

    user = "user"
    assistant = "assistant"

    def is_user(self) -> bool:
        """Return True only for ``Role.user``."""
        # Enum members are singletons, so identity is the equality test.
        return self is Role.user
class MessageCore(BaseModel):
    """Fields shared by every message: author, content parts, attachments, metadata."""

    role: Role
    content: List[Content]
    attachments: Optional[List[Attachment]]  # no default: required, though it may be None
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Copy a ``meta_data`` input key to ``metadata`` before validation.

        NOTE(review): presumably ``metadata`` is the alias declared by
        ``MetadataField``, letting callers pass the field by its Python name;
        confirm against schemas.base.
        """
        # A "before" validator receives the raw input, which is not always a
        # dict (e.g. an object validated from attributes). Only rewrite dicts,
        # and copy instead of mutating so the caller's dict is left untouched.
        if isinstance(values, dict) and 'meta_data' in values:
            values = {**values, 'metadata': values['meta_data']}
        return values
class MessageBase(MessageCore):
    """Message fields plus thread/run bookkeeping and a lifecycle status."""

    class Status(Enum):
        """Lifecycle state of a message."""

        created = "created"  # only used for stream
        in_progress = "in_progress"
        incomplete = "incomplete"
        completed = "completed"

    thread_id: str
    status: Status
    incomplete_details: IncompleteDetails | None = None
    completed_at: int | None = None
    incomplete_at: int | None = None
    assistant_id: str | None = None
    run_id: str | None  # no default: must be supplied, even if None
# Placeholder so MessageObject can name its stream-response type before the
# real MessageStreamResponse class is defined; that class definition later
# rebinds this module-level name.
MessageStreamResponse = ForwardRef("MessageStreamResponse")
class MessageObject(MessageBase, ObjectWithCreatedTime):
_encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None)
def get_text_content(self) -> str:
text_content = ""
for content in self.content:
if content.type == ContentType.text:
text_content += content.text.value
else:
raise not_implemented("Content other than text")
return text_content
async def get_encoded_content(self,encode_fn:Callable):
if self._encoded_content is None:
logger.info(f'encoding {self.role.value} message({self.status.value}): {self.get_text_content()}')
self._encoded_content = encode_fn(self.get_text_content(),self.role)
for f in self.get_attached_files():
logger.info(f'encoding file: {f.filename}')
self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(),self.role)],dim=-1)
yield None
yield self._encoded_content
def get_attached_files(self):
raise NotImplementedError # should be replaced
def append_message_delta(self,text:str):
raise NotImplementedError # should be replaced
def sync_db(self):
# raise NotImplementedError # should be replaced
sql_utils = SQLUtil()
db_message = Message(
**self.model_dump(mode="json"),
)
with sql_utils.get_db() as db:
sql_utils.db_merge_commit(db, db_message)
def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse:
match event:
case MessageObject.Status.created:
self.status = MessageObject.Status.in_progress
case _:
self.status = event
return MessageStreamResponse(message=self, event=event)
class MessageStreamResponse(BaseModel):
    """A message snapshot paired with the stream event that produced it."""

    message: MessageObject
    event: MessageObject.Status

    def to_stream_reply(self):
        """Render as a server-sent-event frame named ``thread.message.<event>``."""
        event_name = f"thread.message.{self.event.value}"
        payload = self.message.model_dump_json()
        return f"event: {event_name}\ndata: {payload}\n\n"
class MessageCreate(BaseModel):
    """Request body for creating a message."""

    role: Role = Field(default=Role.user)
    # A plain string is shorthand for a single text content part (see to_core).
    content: Union[str, List[Content]]
    attachments: Optional[List[Attachment]] = None
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Copy a ``meta_data`` input key to ``metadata`` before validation.

        NOTE(review): presumably ``metadata`` is the alias declared by
        ``MetadataField``; confirm against schemas.base.
        """
        # Raw "before" input is not always a dict; only rewrite dicts, and copy
        # so the caller's dict is not mutated.
        if isinstance(values, dict) and 'meta_data' in values:
            values = {**values, 'metadata': values['meta_data']}
        return values

    def to_core(self) -> MessageCore:
        """Convert this request into a :class:`MessageCore`.

        String content becomes one ``TextObject`` with no annotations. List
        content is passed through as-is.

        Raises:
            ValueError: if ``content`` is neither a str nor a list. This is only
                reachable when validation was bypassed.
        """
        core = MessageCore(
            role=self.role,
            content=[],
            attachments=self.attachments,
            meta_data=self.meta_data,
        )
        if isinstance(self.content, str):
            core.content = [TextObject(type="text", text=Text(value=self.content, annotations=[]))]
        elif isinstance(self.content, list):
            core.content = self.content
        else:
            raise ValueError("Invalid content type")
        return core
class MessageModify(BaseModel):
    """Request body for modifying a message; only its metadata is modifiable."""

    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Copy a ``meta_data`` input key to ``metadata`` before validation.

        NOTE(review): presumably ``metadata`` is the alias declared by
        ``MetadataField``; confirm against schemas.base.
        """
        # Raw "before" input is not always a dict; only rewrite dicts, and copy
        # so the caller's dict is not mutated.
        if isinstance(values, dict) and 'meta_data' in values:
            values = {**values, 'metadata': values['meta_data']}
        return values