Initial commit

This commit is contained in:
chenxl 2024-07-27 16:06:58 +08:00
commit 18c42e67df
247 changed files with 53775 additions and 0 deletions

View file

@ -0,0 +1,93 @@
from time import time
from typing import Optional,List
from uuid import uuid4
from ktransformers.server.models.assistants.messages import Message
from ktransformers.server.models.assistants.threads import Thread
from ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.schemas.conversation import ThreadPreview
from ktransformers.server.utils.sql_utils import SQLUtil
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
class ThreadsDatabaseManager:
def __init__(self) -> None:
self.sql_util = SQLUtil()
self.message_manager = MessageDatabaseManager()
self.assistant_maanager = AssistantDatabaseManager()
def db_create_thread(self, thread: ThreadCreate):
thread_id = str(uuid4())
db_messages = []
with self.sql_util.get_db() as db:
if thread.messages is not None:
logger.debug("Creating messages first for thread")
for message in thread.messages:
db_message: Message = MessageDatabaseManager.create_db_message_by_core(
message)
db_message.role = "user"
db_message.thread_id = thread_id
db.add(db_message)
db_messages.append(db_message)
db_thread = Thread(
**thread.model_dump(exclude="messages"),
id=str(uuid4()),
created_at=int(time()),
messages=db_messages,
)
self.sql_util.db_add_commit_refresh(db, db_thread)
thread_obj = ThreadObject.model_validate(db_thread.__dict__)
if 'assistant_id' in thread.meta_data:
# assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db)
assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])
logger.info(
f'Append this related thread to assistant {assistant.id}')
assistant.append_related_threads([thread_obj.id])
assistant.sync_db(db)
return thread_obj
def db_get_thread_by_id(self, thread_id: ObjectID):
with self.sql_util.get_db() as db:
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
return ThreadObject.model_validate(db_thread.__dict__)
def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:
with self.sql_util.get_db() as db:
query = db.query(Thread).order_by(order.to_sqlalchemy_order()(
Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))
if limit is not None:
db_threads = query.limit(limit)
else:
db_threads = query.all()
return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]
def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:
threads = self.db_list_threads(limit, order)
previews = []
for thread in threads:
messages = self.message_manager.db_list_messages_of_thread(
thread.id, limit=2, order=Order.ASC)
if len(messages) == 2:
message = messages[0]
assistant = self.assistant_maanager.db_get_assistant_by_id(
messages[1].assistant_id)
else:
message = None
assistant = None
previews.append(ThreadPreview(
assistant=assistant, thread=thread, first_message=message))
return previews
def db_delete_thread_by_id(self, thread_id: ObjectID):
with self.sql_util.get_db() as db:
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
db.delete(db_thread)
# TODO delete related messages and runs and other stuff or just gc
db.commit()