from typing import List, Optional

from fastapi import APIRouter

from ktransformers.server.exceptions import not_implemented
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.utils.create_interface import get_thread_context_manager

router = APIRouter()
message_manager = MessageDatabaseManager()

@router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject)
async def create_message(thread_id: str, msg: MessageCreate):
    # Persist the new message as in_progress, then hand it to the thread's
    # backend context (if one is active) so generation can pick it up.
    message = message_manager.db_create_message(
        thread_id, msg, MessageObject.Status.in_progress)
    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.put_user_message(message)
    return message

@router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject])
async def list_messages(
    thread_id: str,
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
    run_id: Optional[str] = None,
):
    # Note: after, before and run_id are accepted for OpenAI API compatibility
    # but are not used for filtering yet; only limit and order are applied.
    return message_manager.db_list_messages_of_thread(thread_id, limit, order)

@router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def retrieve_message(thread_id: ObjectID, message_id: ObjectID):
    return message_manager.db_get_message_by_id(thread_id, message_id)

@router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):
    raise not_implemented('modify message')

@router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_message(thread_id: ObjectID, message_id: ObjectID):
    # Remove the message from the live thread context first (if any),
    # then delete it from the database.
    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.delete_user_message(message_id)
    message_manager.db_delete_message_by_id(thread_id, message_id)
    return DeleteResponse(id=message_id, object='thread.message.deleted')
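
# Usage sketch (not part of this module): these routes only define the message
# sub-paths, so the server is expected to mount this router under a thread-scoped
# prefix. The prefix and import path below are assumptions for illustration,
# not the project's actual wiring.
#
#     from fastapi import FastAPI
#     from ktransformers.server.api.openai.assistants.messages import router as messages_router
#
#     app = FastAPI()
#     app.include_router(messages_router, prefix="/v1/threads")
#
#     # A client would then create and list messages for a thread, e.g.:
#     #   POST /v1/threads/{thread_id}/messages   (body: MessageCreate JSON)
#     #   GET  /v1/threads/{thread_id}/messages?limit=20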