kvcache-ai-ktransformers/ktransformers/server/api/openai/assistants/messages.py
2024-07-27 16:06:58 +08:00

54 lines
2.4 KiB
Python

from typing import List, Optional
from fastapi import APIRouter
from ktransformers.server.exceptions import not_implemented
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.utils.create_interface import get_thread_context_manager
router = APIRouter()
message_manager = MessageDatabaseManager()
@router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject)
async def create_message(thread_id: str, msg: MessageCreate):
message = message_manager.db_create_message(
thread_id, msg, MessageObject.Status.in_progress)
ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
if ctx is not None:
ctx.put_user_message(message)
return message
@router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject])
async def list_messages(
thread_id: str,
limit: Optional[int] = 20,
order: Order = Order.DESC,
after: Optional[str] = None,
before: Optional[str] = None,
run_id: Optional[str] = None,
):
return message_manager.db_list_messages_of_thread(thread_id, limit, order)
@router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def retrieve_message(thread_id: ObjectID, message_id: ObjectID):
return message_manager.db_get_message_by_id(thread_id, message_id)
@router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):
#raise not_implemented('modify message not implemented')
raise not_implemented('modify message')
@router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_message(thread_id: ObjectID, message_id: ObjectID):
ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
if ctx is not None:
ctx.delete_user_message(message_id)
message_manager.db_delete_message_by_id(thread_id, message_id)
return DeleteResponse(id=message_id, object='thread.message.deleted')