kvcache-ai-ktransformers/ktransformers/server/api/openai/assistants/assistants.py
2024-07-27 16:06:58 +08:00

103 lines
3.9 KiB
Python

from typing import Optional
from fastapi import APIRouter
from fastapi.testclient import TestClient
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject
from ktransformers.server.schemas.base import DeleteResponse, Order
from ktransformers.server.config.log import logger
# Router for the OpenAI-compatible "/assistants" endpoints, together with the
# database managers the handlers below delegate to. Both managers are
# constructed once, at import time, in the same order as before.
# NOTE(review): runs_manager is not referenced in this section of the file.
router = APIRouter(prefix="/assistants")
assistant_manager, runs_manager = AssistantDatabaseManager(), RunsDatabaseManager()
@router.post("/", tags=['openai'])
async def create_assistant(
    assistant: AssistantCreate,
):
    """Create a new assistant via the assistant database manager.

    Args:
        assistant: Creation payload parsed from the request body.

    Returns:
        The created assistant rendered with ``as_api_response()``.
    """
    created = assistant_manager.db_create_assistant(assistant)
    return created.as_api_response()
@router.get("/", tags=['openai'])
async def list_assistants(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants in their API representation.

    Args:
        limit: Passed through to ``db_list_assistants`` (default 20).
        order: Sort order passed through to ``db_list_assistants``.
        after: Accepted for API compatibility; currently ignored.
        before: Accepted for API compatibility; currently ignored.

    Returns:
        A list with one ``as_api_response()`` result per assistant.
    """
    # NOTE(review): cursor pagination (`after` / `before`) is not forwarded to
    # the database layer, so these query parameters have no effect yet.
    rows = assistant_manager.db_list_assistants(limit, order)
    responses = []
    for row in rows:
        responses.append(row.as_api_response())
    return responses
# List assistants with their status: returns the database manager's objects directly instead of their as_api_response() form.
@router.get("/status", tags=['openai-ext'])
async def list_assistants_with_status(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants exactly as the database manager returns them.

    Unlike ``list_assistants``, no ``as_api_response()`` projection is applied.

    Args:
        limit: Passed through to ``db_list_assistants`` (default 20).
        order: Sort order passed through to ``db_list_assistants``.
        after: Accepted for API compatibility; currently ignored.
        before: Accepted for API compatibility; currently ignored.
    """
    # NOTE(review): `after` / `before` are not forwarded to the database layer.
    assistants = assistant_manager.db_list_assistants(limit, order)
    return assistants
@router.get("/{assistant_id}", tags=['openai'])
async def retrieve_assistant(
    assistant_id: str,
):
    """Fetch a single assistant by id and return its API representation.

    Args:
        assistant_id: Id taken from the request path.
    """
    # NOTE(review): if db_get_assistant_by_id returns None for an unknown id,
    # the attribute access below raises instead of producing a 404 — confirm
    # how the manager signals a missing record.
    found = assistant_manager.db_get_assistant_by_id(assistant_id)
    return found.as_api_response()
@router.post("/{assistant_id}", tags=['openai'])
async def modify_assistant(
    assistant_id: str,
    assistant: AssistantModify,
):
    """Apply a modification payload to an existing assistant.

    Args:
        assistant_id: Id of the assistant to modify, from the request path.
        assistant: Fields to update, parsed from the request body.

    Returns:
        The object returned by ``db_update_assistant_by_id``, rendered with
        ``as_api_response()``.
    """
    updated = assistant_manager.db_update_assistant_by_id(assistant_id, assistant)
    return updated.as_api_response()
@router.delete("/{assistant_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_assistant(assistant_id: str):
    """Delete an assistant by id and acknowledge the deletion.

    Args:
        assistant_id: Id of the assistant to delete, from the request path.

    Returns:
        A ``DeleteResponse`` echoing the id with object type
        ``"assistant.deleted"``.
    """
    assistant_manager.db_delete_assistant_by_id(assistant_id)
    acknowledgement = DeleteResponse(object="assistant.deleted", id=assistant_id)
    return acknowledgement
@router.get("/{assistant_id}/related_thread", tags=['openai'])
async def get_related_thread(assistant_id: ObjectID):
    """Return the thread ids related to the given assistant.

    Args:
        assistant_id: Id of the assistant, from the request path.

    Returns:
        Whatever the assistant's ``get_related_threads_ids()`` returns
        (presumably a collection of thread ids — confirm in the schema).
    """
    return assistant_manager.db_get_assistant_by_id(assistant_id).get_related_threads_ids()
def create_default_assistant():
    """Create the built-in "KT Assistant" when no assistants exist yet.

    Does nothing if the database already contains at least one assistant.
    Otherwise it creates the default assistant, marks its build status as
    completed and syncs that change back through ``sync_db()``.
    """
    if assistant_manager.db_count_assistants() != 0:
        # Assistants already exist; keep the user's data untouched.
        return
    # Log only when we actually create something; previously this message was
    # emitted on every call, even when creation was skipped.
    logger.info('Creating default assistant')
    default_assistant = assistant_manager.db_create_assistant(AssistantCreate(
        name="KT Assistant",
        model="default model",
        instructions=(
            """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """
            """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """
            """Please ensure that your responses are socially unbiased and positive in nature."""
        ),
    ))
    # The default assistant needs no build step, so mark it ready immediately
    # and persist the updated status.
    default_assistant.build_status.status = AssistantBuildStatus.Status.completed
    default_assistant.sync_db()
# unit test
# In-process client for the unit tests below; requests go straight to the
# router without a running server.
# NOTE(review): this is built at import time, so production imports of this
# module also pull in fastapi.testclient — consider moving tests to a test module.
client = TestClient(app=router)
def test_create_assistant():
    """Round-trip an assistant through create (POST /) and retrieve (GET /{id}).

    The assistant created here is deleted afterwards, so repeated runs do not
    leave test records in the assistants store.
    """
    payload = AssistantCreate(model="awesome model", instructions="hello")
    res = client.post("/", json=payload.model_dump(mode="json"))
    assert res.status_code == 200
    created = AssistantObject.model_validate(res.json())
    try:
        assert created.model == payload.model
        assert created.instructions == payload.instructions

        res = client.get(f"/{created.id}")
        # Check the HTTP status before parsing, so a failed lookup is reported
        # as such rather than as a confusing validation error.
        assert res.status_code == 200
        fetched = AssistantObject.model_validate(res.json())
        assert created == fetched
    finally:
        # Remove the assistant created by this test.
        client.delete(f"/{created.id}")