mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-14 09:09:42 +00:00
213 lines
6 KiB
Python
213 lines
6 KiB
Python
from enum import Enum
|
|
from typing import ForwardRef, List, Optional, Union,Callable
|
|
|
|
import torch
|
|
from pydantic import BaseModel, PrivateAttr, model_validator
|
|
|
|
from ktransformers.server.exceptions import not_implemented
|
|
from ktransformers.server.config.log import logger
|
|
from ktransformers.server.models.assistants.messages import Message
|
|
from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime
|
|
from ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch
|
|
from ktransformers.server.utils.sql_utils import SQLUtil
|
|
|
|
|
|
class IncompleteDetails(BaseModel):
    # Payload type of `MessageBase.incomplete_details`.

    # The only field: a reason given as a plain string.
    reason: str
|
|
|
|
|
|
class ContentType(Enum):
    # Tag carried by every content part (`ContentObject.type`); each member's
    # value is the same string as its name.

    image_file = "image_file"
    image_url = "image_url"
    text = "text"
|
|
|
|
|
|
class ContentObject(BaseModel):
    # Common base of the content-part models in this module.

    # Tag identifying which kind of part this is.
    type: ContentType
|
|
|
|
|
|
class ImageFile(BaseModel):
    # Payload of an `ImageFileObject`: an image referenced by file id.

    # Identifier of the file.
    file_id: str
    # Detail setting, kept as a plain string (allowed values are not defined here).
    detail: str
|
|
|
|
|
|
class ImageFileObject(ContentObject):
    # Content part wrapping an `ImageFile` payload.

    image_file: ImageFile
|
|
|
|
|
|
class ImageUrl(BaseModel):
    # Payload of an `ImageUrlObject`: an image referenced by URL.

    # Location of the image, kept as a plain string.
    url: str
    # Detail setting, kept as a plain string (allowed values are not defined here).
    detail: str
|
|
|
|
|
|
class ImageUrlObject(ContentObject):
    # Content part wrapping an `ImageUrl` payload.

    image_url: ImageUrl
|
|
|
|
|
|
class Annotation(BaseModel):
    # Element type of `Text.annotations`.
    # NOTE(review): the single `todo` field looks like a placeholder schema — confirm.

    todo: str
|
|
|
|
|
|
class Text(BaseModel):
    # Payload of a `TextObject`: the text itself plus its annotations.

    # The text; `TextObject.filter_append` grows this string in place.
    value: str
    # Annotations attached to the text; the declared default is an empty list.
    annotations: List[Annotation] = Field(default=[])
|
|
|
|
|
|
class TextObject(ContentObject):
    # Text content part, plus three bookkeeping fields declared with
    # exclude=True (presumably to keep them out of serialized output — confirm
    # against the project's `Field`).

    text: Text
    # Number of `filter_append` calls applied so far.
    delta_index: int = Field(default=0, exclude=True)
    # Not read or written anywhere in this class.
    special_tokens_on: bool = Field(default=False, exclude=True)
    # Not read or written anywhere in this class.
    last_two: str = Field(default='', exclude=True)

    def filter_append(self, text: str):
        """Append ``text`` to the stored text and count the delta.

        Args:
            text: Fragment to add to the end of ``self.text.value``.

        Returns:
            Always ``True``; nothing is filtered out here.
        """
        payload = self.text
        payload.value = payload.value + text
        self.delta_index = self.delta_index + 1
        return True
|
|
|
|
|
|
|
|
# Any content part a message can carry; members are tried in the order listed.
Content = Union[
    ImageFileObject,
    ImageUrlObject,
    TextObject,
]
|
|
|
|
|
|
class Attachment(BaseModel):
    # A file attached to a message, with the tools declared for it.
    # Both fields are optional and default to None.

    # Identifier of the attached file.
    file_id: Optional[str] = Field(default=None)
    # Tool descriptors (code interpreter and/or file search).
    tools: Optional[List[Union[CodeInterpreter, FileSearch]]] = Field(default=None)
|
|
|
|
|
|
class Role(Enum):
    # Author of a message: the user or the assistant.

    user = "user"
    assistant = "assistant"

    def is_user(self) -> bool:
        """Return True when this member is ``Role.user``."""
        # Enum members are singletons, so identity and equality agree.
        return self is Role.user
|
|
|
|
|
|
class MessageCore(BaseModel):
    # Fields shared by every message representation: author role, content
    # parts, attachments and metadata.

    role: Role
    content: List[Content]
    # Defaults to None so the field may be omitted, as on MessageCreate;
    # without a default, pydantic v2 treats an Optional field as required.
    attachments: Optional[List[Attachment]] = None
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Expose an incoming ``meta_data`` entry under the ``metadata`` key too.

        Runs before field validation on the raw input. Presumably needed
        because ``MetadataField`` is populated from ``metadata`` — confirm in
        ``schemas.base``.

        Args:
            values: Raw input handed to the model; not necessarily a dict.

        Returns:
            A shallow copy of ``values`` with ``metadata`` set when it is a
            dict containing ``meta_data``; otherwise ``values`` unchanged.
        """
        # Raw input can be any object; only a dict can be remapped. Anything
        # else is passed through so pydantic reports its own validation error
        # instead of a TypeError escaping from this hook.
        if isinstance(values, dict) and 'meta_data' in values:
            # Build a new dict rather than mutating the caller's mapping.
            values = {**values, 'metadata': values['meta_data']}
        return values
|
|
|
|
|
|
class MessageBase(MessageCore):
    # MessageCore plus thread membership, lifecycle status and the optional
    # run/assistant linkage.

    class Status(Enum):
        # Lifecycle states of a message; each value equals its member name.
        created = "created"  # only used for stream
        in_progress = "in_progress"
        incomplete = "incomplete"
        completed = "completed"

    # Identifier of the thread this message belongs to.
    thread_id: str
    status: Status
    incomplete_details: Optional[IncompleteDetails] = None
    # Integer timestamps (unit not defined in this module).
    completed_at: Optional[int] = None
    incomplete_at: Optional[int] = None

    assistant_id: Optional[str] = None
    # Defaults to None like assistant_id; without a default, pydantic v2
    # treats an Optional field as required.
    run_id: Optional[str] = None
|
|
|
|
|
|
# Placeholder so the name exists when it is used as a return annotation inside
# MessageObject; the real MessageStreamResponse class defined later in this
# module rebinds the name.
MessageStreamResponse = ForwardRef('MessageStreamResponse')
|
|
|
|
class MessageObject(MessageBase, ObjectWithCreatedTime):
    # Full message record: the MessageBase fields plus whatever
    # ObjectWithCreatedTime contributes, with a cached encoding of the content.

    # Cache for the encoded content; stays None until get_encoded_content runs.
    _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None)

    def get_text_content(self) -> str:
        """Concatenate the text of every content part, in order.

        Returns:
            The joined text of all parts (empty string when there are none).

        Raises:
            The error built by ``not_implemented`` as soon as a part whose
            type is not ``ContentType.text`` is encountered.
        """
        fragments = []
        for part in self.content:
            if part.type != ContentType.text:
                raise not_implemented("Content other than text")
            fragments.append(part.text.value)
        return "".join(fragments)

    async def get_encoded_content(self, encode_fn: Callable):
        """Async generator producing the encoded form of this message.

        On the first run the message text is encoded, then each attached file
        is encoded and concatenated along the last dimension; ``None`` is
        yielded after every file. The final item yielded is the cached
        encoding. When the cache is already filled, that single item is the
        only one yielded.

        Args:
            encode_fn: Called as ``encode_fn(text, role)``; its results are
                concatenated with ``torch.cat``.
        """
        if self._encoded_content is not None:
            yield self._encoded_content
            return

        message_text = self.get_text_content()
        logger.info(f'encoding {self.role.value} message({self.status.value}): {message_text}')
        self._encoded_content = encode_fn(message_text, self.role)

        for attached in self.get_attached_files():
            logger.info(f'encoding file: {attached.filename}')
            # The cache is updated file by file, before each intermediate yield.
            self._encoded_content = torch.cat(
                [self._encoded_content, encode_fn(await attached.get_str(), self.role)],
                dim=-1,
            )
            yield None

        yield self._encoded_content

    def get_attached_files(self):
        """Stub that always raises ``NotImplementedError``.

        Marked in the original code as "should be replaced"; presumably a real
        implementation is supplied elsewhere — confirm.
        """
        raise NotImplementedError  # should be replaced

    def append_message_delta(self, text: str):
        """Stub that always raises ``NotImplementedError``.

        Marked in the original code as "should be replaced"; presumably a real
        implementation is supplied elsewhere — confirm.
        """
        raise NotImplementedError  # should be replaced

    def sync_db(self):
        """Write this message to the database.

        Dumps the model in JSON mode, builds a ``Message`` row from the
        result and passes it to ``SQLUtil.db_merge_commit`` inside a session
        obtained from ``SQLUtil.get_db``.
        """
        sql_utils = SQLUtil()
        row_fields = self.model_dump(mode="json")
        db_message = Message(**row_fields)
        with sql_utils.get_db() as db:
            sql_utils.db_merge_commit(db, db_message)

    def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse:
        """Update ``status`` for ``event`` and wrap the message for streaming.

        A ``created`` event moves the message to ``in_progress``; any other
        event becomes the new status as-is. The returned response carries the
        original ``event`` in both cases.
        """
        created = MessageObject.Status.created
        self.status = MessageObject.Status.in_progress if event == created else event
        return MessageStreamResponse(message=self, event=event)
|
|
|
|
|
|
class MessageStreamResponse(BaseModel):
    # Pairs a message with the status event that is being streamed for it.

    message: MessageObject
    event: MessageObject.Status

    def to_stream_reply(self):
        """Render this response as one event-stream frame.

        Returns:
            A string with an ``event: thread.message.<status>`` line, a
            ``data:`` line holding the message as JSON, and a trailing blank
            line.
        """
        event_name = f"thread.message.{self.event.value}"
        payload = self.message.model_dump_json()
        return f"event: {event_name}\ndata: {payload}\n\n"
|
|
|
|
|
|
class MessageCreate(BaseModel):
    # Input for creating a message. `content` may be a bare string or a list
    # of content parts; `to_core` normalises it to a list.

    role: Role = Field(default=Role.user)
    content: Union[str, List[Content]]
    attachments: Optional[List[Attachment]] = None
    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Expose an incoming ``meta_data`` entry under the ``metadata`` key too.

        Runs before field validation on the raw input. Presumably needed
        because ``MetadataField`` is populated from ``metadata`` — confirm in
        ``schemas.base``.

        Args:
            values: Raw input handed to the model; not necessarily a dict.

        Returns:
            A shallow copy of ``values`` with ``metadata`` set when it is a
            dict containing ``meta_data``; otherwise ``values`` unchanged.
        """
        # Raw input can be any object; only a dict can be remapped. Anything
        # else is passed through so pydantic reports its own validation error
        # instead of a TypeError escaping from this hook.
        if isinstance(values, dict) and 'meta_data' in values:
            # Build a new dict rather than mutating the caller's mapping.
            values = {**values, 'metadata': values['meta_data']}
        return values

    def to_core(self) -> MessageCore:
        """Convert this request into a ``MessageCore``.

        Returns:
            A ``MessageCore`` with the same role, attachments and metadata.
            String content is wrapped in a single ``TextObject``; list content
            is used as given (the same list object, not a copy).

        Raises:
            ValueError: If ``content`` is neither a string nor a list.
        """
        core = MessageCore(
            role=self.role,
            content=[],
            attachments=self.attachments,
            meta_data=self.meta_data,
        )
        if isinstance(self.content, str):
            core.content = [TextObject(type="text", text=Text(value=self.content, annotations=[]))]
        elif isinstance(self.content, list):
            core.content = self.content
        else:
            # Only reachable if `content` was reassigned after validation.
            raise ValueError("Invalid content type")
        return core
|
|
|
|
|
|
class MessageModify(BaseModel):
    # Input for modifying a message; metadata is the only field.

    meta_data: Metadata = MetadataField

    @model_validator(mode='before')
    @classmethod
    def convert_meta_data(cls, values):
        """Expose an incoming ``meta_data`` entry under the ``metadata`` key too.

        Runs before field validation on the raw input. Presumably needed
        because ``MetadataField`` is populated from ``metadata`` — confirm in
        ``schemas.base``.

        Args:
            values: Raw input handed to the model; not necessarily a dict.

        Returns:
            A shallow copy of ``values`` with ``metadata`` set when it is a
            dict containing ``meta_data``; otherwise ``values`` unchanged.
        """
        # Raw input can be any object; only a dict can be remapped. Anything
        # else is passed through so pydantic reports its own validation error
        # instead of a TypeError escaping from this hook.
        if isinstance(values, dict) and 'meta_data' in values:
            # Build a new dict rather than mutating the caller's mapping.
            values = {**values, 'metadata': values['meta_data']}
        return values
|