mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-07 04:59:55 +00:00
103 lines
3.9 KiB
Python
103 lines
3.9 KiB
Python
from typing import Optional
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi.testclient import TestClient
|
|
|
|
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
|
|
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
|
|
from ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject
|
|
from ktransformers.server.schemas.base import DeleteResponse, Order
|
|
from ktransformers.server.config.log import logger
|
|
|
|
|
|
# All routes in this module are mounted under the "/assistants" prefix.
router = APIRouter(prefix="/assistants")

# Module-level database managers. `assistant_manager` backs the endpoints in
# this module; `runs_manager` is not referenced in this section (presumably
# used by importers of this module -- verify before removing).
assistant_manager, runs_manager = AssistantDatabaseManager(), RunsDatabaseManager()
|
|
|
|
|
|
@router.post("/", tags=['openai'])
async def create_assistant(
    assistant: AssistantCreate,
):
    """Create an assistant from the request body.

    Returns the newly created assistant in its API-facing representation
    (``as_api_response()``).
    """
    created = assistant_manager.db_create_assistant(assistant)
    return created.as_api_response()
|
|
|
|
|
|
@router.get("/", tags=['openai'])
async def list_assistants(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants in their API-facing representation.

    ``limit`` and ``order`` are passed through to the database manager.

    NOTE(review): ``after`` and ``before`` are accepted (presumably for
    OpenAI API compatibility) but are NOT forwarded, so cursor-based
    pagination is currently not applied.
    """
    responses = []
    for stored in assistant_manager.db_list_assistants(limit, order):
        responses.append(stored.as_api_response())
    return responses
|
|
|
|
# List assistants together with their status (non-OpenAI extension endpoint, tagged 'openai-ext').
|
|
|
|
|
|
@router.get("/status", tags=['openai-ext'])
async def list_assistants_with_status(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants as the raw objects returned by the database manager.

    Unlike ``list_assistants``, this does not call ``as_api_response()``, so
    the response carries the stored objects directly.

    This route must be registered before ``/{assistant_id}``. Otherwise the
    literal path "/status" would be captured as an assistant id.

    NOTE(review): ``after`` and ``before`` are accepted but not forwarded.
    """
    stored_assistants = assistant_manager.db_list_assistants(limit, order)
    return stored_assistants
|
|
|
|
|
|
@router.get("/{assistant_id}", tags=['openai'])
async def retrieve_assistant(
    assistant_id: str,
):
    """Fetch a single assistant by id and return its API-facing representation."""
    found = assistant_manager.db_get_assistant_by_id(assistant_id)
    return found.as_api_response()
|
|
|
|
|
|
@router.post("/{assistant_id}", tags=['openai'])
async def modify_assistant(
    assistant_id: str,
    assistant: AssistantModify,
):
    """Apply the fields in ``assistant`` to the stored assistant ``assistant_id``.

    Returns the updated assistant in its API-facing representation.
    """
    updated = assistant_manager.db_update_assistant_by_id(assistant_id, assistant)
    return updated.as_api_response()
|
|
|
|
|
|
@router.delete("/{assistant_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_assistant(assistant_id: str):
    """Delete the assistant ``assistant_id`` and return a deletion acknowledgement."""
    assistant_manager.db_delete_assistant_by_id(assistant_id)
    # The acknowledgement echoes the requested id with the OpenAI-style object tag.
    return DeleteResponse(object="assistant.deleted", id=assistant_id)
|
|
|
|
|
|
@router.get("/{assistant_id}/related_thread", tags=['openai'])
async def get_related_thread(assistant_id: ObjectID):
    """Return the ids of threads related to the assistant ``assistant_id``.

    The lookup is delegated to the stored assistant's ``get_related_threads_ids()``.
    """
    return assistant_manager.db_get_assistant_by_id(assistant_id).get_related_threads_ids()
|
|
|
|
|
|
def create_default_assistant():
    """Seed the database with a default "KT Assistant" if none exist yet.

    If at least one assistant is already stored, this does nothing. Otherwise
    it creates an assistant with a fixed system prompt, sets its build status
    to ``completed``, and calls ``sync_db()`` on it (presumably to persist the
    status change -- confirm against the assistant object implementation).
    """
    if assistant_manager.db_count_assistants() != 0:
        # Assistants already exist; nothing to seed and nothing to log.
        return

    # Log only when an assistant is actually about to be created.
    logger.info('Creating default assistant')
    default_assistant = assistant_manager.db_create_assistant(
        AssistantCreate(
            name="KT Assistant",
            model="default model",
            # Adjacent literals concatenate into the exact same prompt as before.
            instructions=(
                "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                "Please ensure that your responses are socially unbiased and positive in nature."
            ),
        )
    )
    # The default assistant needs no build step, so mark it completed directly.
    default_assistant.build_status.status = AssistantBuildStatus.Status.completed
    default_assistant.sync_db()
|
|
|
|
|
|
# Unit-test client. It wraps the router directly rather than a full FastAPI
# application, so requests use paths relative to the "/assistants" prefix.
client = TestClient(app=router)
|
|
|
|
|
|
def test_create_assistant():
    """Round-trip check: create an assistant via POST, then fetch it via GET."""
    payload = AssistantCreate(model="awesome model", instructions="hello")

    res = client.post("/", json=payload.model_dump(mode="json"))
    assert res.status_code == 200
    created = AssistantObject.model_validate(res.json())
    assert created.model == payload.model
    assert created.instructions == payload.instructions

    res = client.get(f"/{created.id}")
    # Check the HTTP status before parsing, so a failed lookup is reported as
    # such rather than as an opaque pydantic validation error.
    assert res.status_code == 200
    fetched = AssistantObject.model_validate(res.json())
    assert created == fetched
|