Initial commit

This commit is contained in:
chenxl 2024-07-27 16:06:58 +08:00
commit 18c42e67df
247 changed files with 53775 additions and 0 deletions

View file

@ -0,0 +1,15 @@
from fastapi import APIRouter
from .assistants import router as assistants_router,create_default_assistant
from .endpoints.chat import router as chat_router
from .legacy import router as legacy_router
# Top-level API router: every sub-router below is exposed under the '/v1' prefix.
router = APIRouter(prefix='/v1')
router.include_router(assistants_router)
router.include_router(chat_router)
router.include_router(legacy_router)
def post_db_creation_operations():
    """Run once after the database is created: seed the default assistant."""
    create_default_assistant()

View file

@ -0,0 +1,14 @@
from fastapi import APIRouter
from .assistants import router as assistants_router, create_default_assistant
from .messages import router as messages_router
from .runs import router as runs_router
from .threads import router as threads_router
router = APIRouter()
# Nest runs and messages under the threads router (URLs like
# /threads/{thread_id}/runs and /threads/{thread_id}/messages) before mounting.
threads_router.include_router(runs_router)
threads_router.include_router(messages_router)
router.include_router(assistants_router)
router.include_router(threads_router)

View file

@ -0,0 +1,103 @@
from typing import Optional
from fastapi import APIRouter
from fastapi.testclient import TestClient
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject
from ktransformers.server.schemas.base import DeleteResponse, Order
from ktransformers.server.config.log import logger
router = APIRouter(prefix="/assistants")
# Module-level singletons shared by all endpoint handlers in this module.
assistant_manager = AssistantDatabaseManager()
runs_manager = RunsDatabaseManager()
@router.post("/", tags=['openai'])
async def create_assistant(
    assistant: AssistantCreate,
):
    """Create a new assistant record and return it in API-response form."""
    created = assistant_manager.db_create_assistant(assistant)
    return created.as_api_response()
@router.get("/", tags=['openai'])
async def list_assistants(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants in API-response form (after/before cursors are accepted but unused)."""
    responses = []
    for stored in assistant_manager.db_list_assistants(limit, order):
        responses.append(stored.as_api_response())
    return responses
# Extension endpoint: like list_assistants but returns the raw DB objects,
# which include the build status.
@router.get("/status", tags=['openai-ext'])
async def list_assistants_with_status(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """List assistants with their build status (after/before cursors are accepted but unused)."""
    assistants = assistant_manager.db_list_assistants(limit, order)
    return assistants
@router.get("/{assistant_id}", tags=['openai'])
async def retrieve_assistant(
    assistant_id: str,
):
    """Fetch one assistant by id and return it in API-response form."""
    stored = assistant_manager.db_get_assistant_by_id(assistant_id)
    return stored.as_api_response()
@router.post("/{assistant_id}", tags=['openai'])
async def modify_assistant(
    assistant_id: str,
    assistant: AssistantModify,
):
    """Apply the given modifications to an assistant and return the updated record."""
    updated = assistant_manager.db_update_assistant_by_id(assistant_id, assistant)
    return updated.as_api_response()
@router.delete("/{assistant_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_assistant(assistant_id: str):
    """Delete an assistant and acknowledge with the OpenAI deletion envelope."""
    assistant_manager.db_delete_assistant_by_id(assistant_id)
    acknowledgement = DeleteResponse(id=assistant_id, object="assistant.deleted")
    return acknowledgement
@router.get("/{assistant_id}/related_thread", tags=['openai'])
async def get_related_thread(assistant_id: ObjectID):
    """Return the ids of threads related to this assistant."""
    return assistant_manager.db_get_assistant_by_id(assistant_id).get_related_threads_ids()
def create_default_assistant():
    """Seed the database with a default assistant when none exists yet.

    Idempotent: returns immediately if at least one assistant is already
    stored. The created assistant is marked build-completed and synced.
    """
    if assistant_manager.db_count_assistants() != 0:
        return
    # Fix: the original logged 'Creating default assistant' unconditionally,
    # even when nothing was created; log only on the creation path.
    logger.info('Creating default assistant')
    default_assistant = assistant_manager.db_create_assistant(AssistantCreate(
        name="KT Assistant",
        model="default model",
        # Adjacent string literals concatenate implicitly; no '+' needed.
        instructions=(
            "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
            "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
            "Please ensure that your responses are socially unbiased and positive in nature.")))
    default_assistant.build_status.status = AssistantBuildStatus.Status.completed
    default_assistant.sync_db()
# unit test
# NOTE(review): TestClient is handed the APIRouter directly rather than a
# FastAPI app, and this router carries the '/assistants' prefix, so posting
# to "/" may not match the prefixed routes — confirm this test actually runs.
client = TestClient(router)
def test_create_assistant():
    """Round-trip: create an assistant via POST, then fetch it back and compare."""
    ass_create = AssistantCreate(model="awesome model", instructions="hello")
    res = client.post("/", json=ass_create.model_dump(mode="json"))
    assert res.status_code == 200
    assistant = AssistantObject.model_validate(res.json())
    assert assistant.model == ass_create.model
    assert assistant.instructions == ass_create.instructions
    # Retrieve the same assistant and verify it equals the created one.
    res = client.get(f"/{assistant.id}")
    ass1 = AssistantObject.model_validate(res.json())
    assert assistant == ass1

View file

@ -0,0 +1,54 @@
from typing import List, Optional
from fastapi import APIRouter
from ktransformers.server.exceptions import not_implemented
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.utils.create_interface import get_thread_context_manager
router = APIRouter()
# Shared DB accessor for message CRUD; used by all endpoints below.
message_manager = MessageDatabaseManager()
@router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject)
async def create_message(thread_id: str, msg: MessageCreate):
    """Persist a message on a thread and forward it to the live thread context, if any."""
    stored = message_manager.db_create_message(
        thread_id, msg, MessageObject.Status.in_progress)
    manager = get_thread_context_manager()
    ctx: Optional[ThreadContext] = await manager.get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.put_user_message(stored)
    return stored
@router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject])
async def list_messages(
    thread_id: str,
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
    run_id: Optional[str] = None,
):
    """List a thread's messages (after/before/run_id filters are accepted but unused)."""
    messages = message_manager.db_list_messages_of_thread(thread_id, limit, order)
    return messages
@router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def retrieve_message(thread_id: ObjectID, message_id: ObjectID):
    """Fetch a single message by thread id and message id."""
    message = message_manager.db_get_message_by_id(thread_id, message_id)
    return message
@router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):
    """Not supported yet; always responds with a not-implemented error."""
    raise not_implemented('modify message')
@router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_message(thread_id: ObjectID, message_id: ObjectID):
    """Remove a message, notifying the live thread context (if any) first."""
    manager = get_thread_context_manager()
    ctx: Optional[ThreadContext] = await manager.get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.delete_user_message(message_id)
    message_manager.db_delete_message_by_id(thread_id, message_id)
    return DeleteResponse(id=message_id, object='thread.message.deleted')

View file

@ -0,0 +1,99 @@
from typing import List, Optional
from fastapi import APIRouter, Request
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.schemas.assistants.runs import RunCreate,RunObject,RunThreadCreate,RunModify,RunSubmit
from ktransformers.server.schemas.assistants.streaming import api_stream_response
from ktransformers.server.utils.create_interface import get_thread_context_manager
from ktransformers.server.schemas.base import Order
from ktransformers.server.config.log import logger
from ktransformers.server.exceptions import internal_server_error
router = APIRouter()
# Shared DB accessor for run objects; used by all run endpoints below.
runs_manager = RunsDatabaseManager()
@router.post("/{thread_id}/runs",tags=['openai'])
async def create_run(request: Request, thread_id: str, run_create: RunCreate):
    """Create a run on a thread and execute it.

    Streaming: the run row is created inside the generator so the 'created'
    event is the first SSE chunk, followed by the events yielded by the
    thread context's work loop.
    Non-streaming: drives the work loop to completion, then returns the run.
    """
    if run_create.stream:
        async def inner():
            run = runs_manager.db_create_run(thread_id, run_create)
            yield run.stream_response_with_event(event=RunObject.Status.created)
            ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)
            async for event in ctx.work():
                yield event
        return api_stream_response(request, inner())
    else:
        run = runs_manager.db_create_run(thread_id, run_create)
        ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)
        # Events are intentionally discarded in non-stream mode; the loop just
        # drives the context until the run finishes.
        async for event in ctx.work():
            pass
        return run
@router.post("/runs",tags=['openai'], response_model=RunObject)
async def create_thread_and_run(run_thread: RunThreadCreate):
    """Stub: create-thread-and-run is not implemented yet."""
    raise NotImplementedError
@router.get("/{thread_id}/runs",tags=['openai'], response_model=List[RunObject])
async def list_runs(
    thread_id: str,
    limit: Optional[int] = 20,
    order: Optional[Order] = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    """Stub: listing a thread's runs is not implemented yet."""
    raise NotImplementedError
@router.get("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject)
async def retrieve_run(
    thread_id: str,
    run_id: str,
):
    """Fetch a run by id, verifying it belongs to the given thread.

    Fix: the original used a bare `assert` for this check, which is stripped
    under `python -O` and would raise an opaque AssertionError otherwise;
    raise an explicit server error instead.
    """
    runobj = runs_manager.db_get_run(run_id)
    if runobj.thread_id != thread_id:
        raise internal_server_error(f'Run {run_id} does not belong to thread {thread_id}')
    return runobj
@router.post("/{thread_id}/runs/{run_id}",tags=['openai'], response_model=RunObject)
async def modify_run(
    thread_id: str,
    run_id: str,
    run: RunModify,
):
    """Stub: modifying a run is not implemented yet."""
    raise NotImplementedError
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=['openai'],response_model=RunObject)
async def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit):
    """Stub: tool-output submission is not implemented yet."""
    raise NotImplementedError
@router.post("/{thread_id}/runs/{run_id}/cancel",tags=['openai'], response_model=RunObject)
async def cancel_run(thread_id: str, run_id: str):
    """Request cancellation of a run.

    If the thread has a live context currently executing this run, mark it as
    cancelling and return it; otherwise just return the stored run object.
    """
    ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        if ctx.run is None:
            # Bug fix: the original logged ctx.run.id on this branch, which
            # dereferences None and raises AttributeError before the intended
            # error could be returned. Also use warning() — warn() is deprecated.
            logger.warning(f'Context of thread {thread_id} is expected to have a run, but none is attached')
            raise internal_server_error('ctx do not have run')
        if ctx.run.id == run_id:
            logger.info(f'Cancelling thread: {thread_id} and run: {run_id}')
            # NOTE(review): the returned event is discarded; presumably this call
            # flips the run into the cancelling state — confirm.
            ctx.run.stream_response_with_event(RunObject.Status.cancelling)
            return ctx.run
        run = runs_manager.db_get_run(run_id)
        logger.info(f'Run {run_id} not in this thread context')
        return run
    run = runs_manager.db_get_run(run_id)
    logger.info(f'Run {run_id} not in context manager')
    return run

View file

@ -0,0 +1,36 @@
from typing import List,Optional
from fastapi import APIRouter
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager,Order,ObjectID
from ktransformers.server.schemas.assistants.threads import ThreadObject,ThreadCreate,ThreadModify
from ktransformers.server.schemas.base import DeleteResponse
from ktransformers.server.schemas.conversation import ThreadPreview
router = APIRouter(prefix='/threads')
# Shared DB accessor for thread CRUD; used by all endpoints below.
threads_manager = ThreadsDatabaseManager()
@router.post("/",tags=['openai'], response_model=ThreadObject)
async def create_thread(thread: ThreadCreate):
    """Create a new conversation thread."""
    created = threads_manager.db_create_thread(thread)
    return created
@router.get("/", tags=['openai-ext'],response_model=List[ThreadPreview])
async def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC):
    """Extension endpoint: list thread previews, most recent first by default."""
    previews = threads_manager.db_list_threads_preview(limit, order)
    return previews
@router.get("/{thread_id}",tags=['openai'], response_model=ThreadObject)
async def retrieve_thread(thread_id: ObjectID):
    """Fetch a single thread by id."""
    thread = threads_manager.db_get_thread_by_id(thread_id)
    return thread
@router.post("/{thread_id}",tags=['openai'], response_model=ThreadObject)
async def modify_thread(thread_id: ObjectID, thread: ThreadModify):
    """Stub: modifying a thread is not implemented yet."""
    raise NotImplementedError
@router.delete("/{thread_id}",tags=['openai'], response_model=DeleteResponse)
async def delete_thread(thread_id: ObjectID):
    """Delete a thread and acknowledge with the OpenAI deletion envelope."""
    threads_manager.db_delete_thread_by_id(thread_id=thread_id)
    acknowledgement = DeleteResponse(id=thread_id, object='thread.deleted')
    return acknowledgement

View file

@ -0,0 +1,34 @@
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
from ktransformers.server.backend.base import BackendInterfaceBase
# Router for the OpenAI-compatible chat endpoints.
router = APIRouter()
@router.post('/chat/completions',tags=['openai'])
async def chat_completion(request:Request,create:ChatCompletionCreate):
    """OpenAI-compatible chat completion endpoint (streaming and non-streaming).

    Tokens from the backend interface are either streamed as
    'chat.completion.chunk' SSE events or accumulated into one
    'chat.completion' object.
    """
    completion_id = str(uuid4())  # renamed: the original shadowed the builtin `id`
    interface: BackendInterfaceBase = get_interface()
    # model_dump(mode='json') is equivalent to json.loads(model_dump_json())
    # and matches the serialization style used elsewhere in this codebase.
    input_message = [m.model_dump(mode='json') for m in create.messages]
    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(id=completion_id,object='chat.completion.chunk',created=int(time()))
            async for token in interface.inference(input_message,completion_id):
                chunk.set_token(token)
                yield chunk
        return chat_stream_response(request,inner())
    else:
        # Bug fix: the non-streaming object type was 'chat.completion.chunk';
        # per the OpenAI API it must be 'chat.completion'.
        comp = ChatCompletionObject(id=completion_id,object='chat.completion',created=int(time()))
        async for token in interface.inference(input_message,completion_id):
            comp.append_token(token)
        return comp

View file

@ -0,0 +1,6 @@
from fastapi import APIRouter
from . import completions
# Aggregates the legacy (pre-assistants) endpoint routers.
router = APIRouter()
router.include_router(completions.router)

View file

@ -0,0 +1,33 @@
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
# Router for the legacy text-completion endpoint.
router = APIRouter()
@router.post("/completions",tags=['openai'])
async def create_completion(request:Request,create:CompletionCreate):
    """Legacy OpenAI text-completion endpoint (streaming and non-streaming)."""
    completion_id = str(uuid4())  # renamed: the original shadowed the builtin `id`
    interface = get_interface()
    print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
    if create.stream:
        async def inner():
            async for token in interface.inference(create.prompt,completion_id):
                d = {'choices':[{'delta':{'content':token}}]}
                yield f"data:{json.dumps(d)}\n\n"
            # Terminal chunk. NOTE(review): OpenAI emits finish_reason 'stop' and
            # a final 'data: [DONE]' sentinel — confirm clients accept this form.
            d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}
            yield f"data:{json.dumps(d)}\n\n"
        return stream_response(request,inner())
    else:
        comp = CompletionObject(id=completion_id,object='text_completion',created=int(time()))
        async for token in interface.inference(create.prompt,completion_id):
            comp.append_token(token)
        return comp