mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-15 17:49:42 +00:00
Initial commit
This commit is contained in:
commit
18c42e67df
247 changed files with 53775 additions and 0 deletions
93
ktransformers/server/crud/assistants/threads.py
Normal file
93
ktransformers/server/crud/assistants/threads.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
from time import time
|
||||
from typing import Optional,List
|
||||
from uuid import uuid4
|
||||
|
||||
from ktransformers.server.models.assistants.messages import Message
|
||||
from ktransformers.server.models.assistants.threads import Thread
|
||||
from ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject
|
||||
from ktransformers.server.schemas.base import ObjectID, Order
|
||||
from ktransformers.server.schemas.conversation import ThreadPreview
|
||||
from ktransformers.server.utils.sql_utils import SQLUtil
|
||||
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
|
||||
from ktransformers.server.config.log import logger
|
||||
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
|
||||
|
||||
class ThreadsDatabaseManager:
|
||||
def __init__(self) -> None:
|
||||
self.sql_util = SQLUtil()
|
||||
self.message_manager = MessageDatabaseManager()
|
||||
self.assistant_maanager = AssistantDatabaseManager()
|
||||
|
||||
def db_create_thread(self, thread: ThreadCreate):
|
||||
thread_id = str(uuid4())
|
||||
db_messages = []
|
||||
with self.sql_util.get_db() as db:
|
||||
if thread.messages is not None:
|
||||
logger.debug("Creating messages first for thread")
|
||||
for message in thread.messages:
|
||||
db_message: Message = MessageDatabaseManager.create_db_message_by_core(
|
||||
message)
|
||||
db_message.role = "user"
|
||||
db_message.thread_id = thread_id
|
||||
db.add(db_message)
|
||||
db_messages.append(db_message)
|
||||
|
||||
db_thread = Thread(
|
||||
**thread.model_dump(exclude="messages"),
|
||||
id=str(uuid4()),
|
||||
created_at=int(time()),
|
||||
messages=db_messages,
|
||||
)
|
||||
|
||||
self.sql_util.db_add_commit_refresh(db, db_thread)
|
||||
thread_obj = ThreadObject.model_validate(db_thread.__dict__)
|
||||
|
||||
if 'assistant_id' in thread.meta_data:
|
||||
# assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db)
|
||||
assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])
|
||||
logger.info(
|
||||
f'Append this related thread to assistant {assistant.id}')
|
||||
assistant.append_related_threads([thread_obj.id])
|
||||
assistant.sync_db(db)
|
||||
return thread_obj
|
||||
|
||||
def db_get_thread_by_id(self, thread_id: ObjectID):
|
||||
with self.sql_util.get_db() as db:
|
||||
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
||||
return ThreadObject.model_validate(db_thread.__dict__)
|
||||
|
||||
def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:
|
||||
with self.sql_util.get_db() as db:
|
||||
query = db.query(Thread).order_by(order.to_sqlalchemy_order()(
|
||||
Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))
|
||||
|
||||
if limit is not None:
|
||||
db_threads = query.limit(limit)
|
||||
else:
|
||||
db_threads = query.all()
|
||||
|
||||
return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]
|
||||
|
||||
def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:
|
||||
threads = self.db_list_threads(limit, order)
|
||||
previews = []
|
||||
for thread in threads:
|
||||
messages = self.message_manager.db_list_messages_of_thread(
|
||||
thread.id, limit=2, order=Order.ASC)
|
||||
if len(messages) == 2:
|
||||
message = messages[0]
|
||||
assistant = self.assistant_maanager.db_get_assistant_by_id(
|
||||
messages[1].assistant_id)
|
||||
else:
|
||||
message = None
|
||||
assistant = None
|
||||
previews.append(ThreadPreview(
|
||||
assistant=assistant, thread=thread, first_message=message))
|
||||
return previews
|
||||
|
||||
def db_delete_thread_by_id(self, thread_id: ObjectID):
|
||||
with self.sql_util.get_db() as db:
|
||||
db_thread = db.query(Thread).filter(Thread.id == thread_id).first()
|
||||
db.delete(db_thread)
|
||||
# TODO delete related messages and runs and other stuff or just gc
|
||||
db.commit()
|
Loading…
Add table
Add a link
Reference in a new issue