mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-15 09:39:42 +00:00
93 lines
4.2 KiB
Python
93 lines
4.2 KiB
Python
from time import time
|
|
from typing import Optional,List
|
|
from uuid import uuid4
|
|
|
|
from ktransformers.server.models.assistants.messages import Message
|
|
from ktransformers.server.models.assistants.threads import Thread
|
|
from ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject
|
|
from ktransformers.server.schemas.base import ObjectID, Order
|
|
from ktransformers.server.schemas.conversation import ThreadPreview
|
|
from ktransformers.server.utils.sql_utils import SQLUtil
|
|
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
|
|
from ktransformers.server.config.log import logger
|
|
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
|
|
|
|
class ThreadsDatabaseManager:
|
|
def __init__(self) -> None:
|
|
self.sql_util = SQLUtil()
|
|
self.message_manager = MessageDatabaseManager()
|
|
self.assistant_maanager = AssistantDatabaseManager()
|
|
|
|
def db_create_thread(self, thread: ThreadCreate):
|
|
thread_id = str(uuid4())
|
|
db_messages = []
|
|
with self.sql_util.get_db() as db:
|
|
if thread.messages is not None:
|
|
logger.debug("Creating messages first for thread")
|
|
for message in thread.messages:
|
|
db_message: Message = MessageDatabaseManager.create_db_message_by_core(
|
|
message)
|
|
db_message.role = "user"
|
|
db_message.thread_id = thread_id
|
|
db.add(db_message)
|
|
db_messages.append(db_message)
|
|
|
|
db_thread = Thread(
|
|
**thread.model_dump(exclude="messages"),
|
|
id=str(uuid4()),
|
|
created_at=int(time()),
|
|
messages=db_messages,
|
|
)
|
|
|
|
self.sql_util.db_add_commit_refresh(db, db_thread)
|
|
thread_obj = ThreadObject.model_validate(db_thread.__dict__)
|
|
|
|
if 'assistant_id' in thread.meta_data:
|
|
# assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db)
|
|
assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])
|
|
logger.info(
|
|
f'Append this related thread to assistant {assistant.id}')
|
|
assistant.append_related_threads([thread_obj.id])
|
|
assistant.sync_db(db)
|
|
return thread_obj
|
|
|
|
def db_get_thread_by_id(self, thread_id: ObjectID):
|
|
with self.sql_util.get_db() as db:
|
|
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
|
return ThreadObject.model_validate(db_thread.__dict__)
|
|
|
|
def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:
|
|
with self.sql_util.get_db() as db:
|
|
query = db.query(Thread).order_by(order.to_sqlalchemy_order()(
|
|
Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))
|
|
|
|
if limit is not None:
|
|
db_threads = query.limit(limit)
|
|
else:
|
|
db_threads = query.all()
|
|
|
|
return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]
|
|
|
|
def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:
|
|
threads = self.db_list_threads(limit, order)
|
|
previews = []
|
|
for thread in threads:
|
|
messages = self.message_manager.db_list_messages_of_thread(
|
|
thread.id, limit=2, order=Order.ASC)
|
|
if len(messages) == 2:
|
|
message = messages[0]
|
|
assistant = self.assistant_maanager.db_get_assistant_by_id(
|
|
messages[1].assistant_id)
|
|
else:
|
|
message = None
|
|
assistant = None
|
|
previews.append(ThreadPreview(
|
|
assistant=assistant, thread=thread, first_message=message))
|
|
return previews
|
|
|
|
def db_delete_thread_by_id(self, thread_id: ObjectID):
|
|
with self.sql_util.get_db() as db:
|
|
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
|
db.delete(db_thread)
|
|
# TODO delete related messages and runs and other stuff or just gc
|
|
db.commit()
|