Mirror of https://github.com/kvcache-ai/ktransformers.git

Initial commit
commit 18c42e67df
247 changed files with 53775 additions and 0 deletions
ktransformers/server/api/openai/__init__.py  (new file, 15 lines)
@@ -0,0 +1,15 @@
from fastapi import APIRouter

from .assistants import router as assistants_router, create_default_assistant
from .endpoints.chat import router as chat_router
from .legacy import router as legacy_router

router = APIRouter(prefix='/v1')


router.include_router(assistants_router)
router.include_router(chat_router)
router.include_router(legacy_router)


def post_db_creation_operations():
    create_default_assistant()
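A minimal sketch of how this package-level router might be mounted into an application; the FastAPI app below and the call site for post_db_creation_operations() are assumptions for illustration, not part of this commit.

from fastapi import FastAPI

from ktransformers.server.api.openai import router as openai_router, post_db_creation_operations

app = FastAPI()
app.include_router(openai_router)  # serves /v1/assistants/..., /v1/threads/..., /v1/chat/completions, /v1/completions

# hypothetical call site: run once, after the database tables have been created
post_db_creation_operations()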
ktransformers/server/api/openai/assistants/__init__.py  (new file, 14 lines)
@@ -0,0 +1,14 @@
from fastapi import APIRouter

from .assistants import router as assistants_router, create_default_assistant
from .messages import router as messages_router
from .runs import router as runs_router
from .threads import router as threads_router

router = APIRouter()

threads_router.include_router(runs_router)
threads_router.include_router(messages_router)

router.include_router(assistants_router)
router.include_router(threads_router)
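Because threads_router carries the '/threads' prefix (see threads.py below) and nests the runs and messages routers, the composed paths come out as /threads/{thread_id}/runs and /threads/{thread_id}/messages. A hedged sketch for inspecting the composed route table:

from ktransformers.server.api.openai.assistants import router

# prints every registered path, e.g. /assistants/, /threads/{thread_id}/runs, ...
for route in router.routes:
    print(route.path)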
ktransformers/server/api/openai/assistants/assistants.py  (new file, 103 lines)
@@ -0,0 +1,103 @@
from typing import Optional

from fastapi import APIRouter
from fastapi.testclient import TestClient

from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject
from ktransformers.server.schemas.base import DeleteResponse, Order
from ktransformers.server.config.log import logger


router = APIRouter(prefix="/assistants")
assistant_manager = AssistantDatabaseManager()
runs_manager = RunsDatabaseManager()


@router.post("/", tags=['openai'])
async def create_assistant(
    assistant: AssistantCreate,
):
    return assistant_manager.db_create_assistant(assistant).as_api_response()


@router.get("/", tags=['openai'])
async def list_assistants(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    return [assistant.as_api_response() for assistant in assistant_manager.db_list_assistants(limit, order)]


# list assistants together with their build status (non-standard extension)
@router.get("/status", tags=['openai-ext'])
async def list_assistants_with_status(
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    return assistant_manager.db_list_assistants(limit, order)


@router.get("/{assistant_id}", tags=['openai'])
async def retrieve_assistant(
    assistant_id: str,
):
    return assistant_manager.db_get_assistant_by_id(assistant_id).as_api_response()


@router.post("/{assistant_id}", tags=['openai'])
async def modify_assistant(
    assistant_id: str,
    assistant: AssistantModify,
):
    return assistant_manager.db_update_assistant_by_id(assistant_id, assistant).as_api_response()


@router.delete("/{assistant_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_assistant(assistant_id: str):
    assistant_manager.db_delete_assistant_by_id(assistant_id)
    return DeleteResponse(id=assistant_id, object="assistant.deleted")


@router.get("/{assistant_id}/related_thread", tags=['openai'])
async def get_related_thread(assistant_id: ObjectID):
    assistant = assistant_manager.db_get_assistant_by_id(assistant_id)
    return assistant.get_related_threads_ids()


def create_default_assistant():
    logger.info('Creating default assistant')
    if assistant_manager.db_count_assistants() == 0:
        default_assistant = assistant_manager.db_create_assistant(AssistantCreate(
            name="KT Assistant",
            model="default model",
            instructions="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                         "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                         "Please ensure that your responses are socially unbiased and positive in nature."))
        default_assistant.build_status.status = AssistantBuildStatus.Status.completed
        default_assistant.sync_db()


# unit test
client = TestClient(router)


def test_create_assistant():
    ass_create = AssistantCreate(model="awesome model", instructions="hello")

    # the router carries the "/assistants" prefix, so requests must include it
    res = client.post("/assistants/", json=ass_create.model_dump(mode="json"))

    assert res.status_code == 200
    assistant = AssistantObject.model_validate(res.json())

    assert assistant.model == ass_create.model
    assert assistant.instructions == ass_create.instructions

    res = client.get(f"/assistants/{assistant.id}")
    ass1 = AssistantObject.model_validate(res.json())
    assert assistant == ass1
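A hedged companion sketch in the same style as test_create_assistant, appended to the same module and reusing its module-level client; that AssistantModify accepts an instructions field (mirroring AssistantCreate) is an assumption.

def test_modify_and_delete_assistant():
    created = AssistantObject.model_validate(
        client.post("/assistants/", json=AssistantCreate(model="m", instructions="i").model_dump(mode="json")).json())

    # assumed field: instructions on AssistantModify
    res = client.post(f"/assistants/{created.id}", json={"instructions": "updated"})
    assert res.status_code == 200

    res = client.delete(f"/assistants/{created.id}")
    assert res.status_code == 200
    assert res.json()["object"] == "assistant.deleted"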
ktransformers/server/api/openai/assistants/messages.py  (new file, 54 lines)
@@ -0,0 +1,54 @@
from typing import List, Optional

from fastapi import APIRouter

from ktransformers.server.exceptions import not_implemented
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.utils.create_interface import get_thread_context_manager

router = APIRouter()
message_manager = MessageDatabaseManager()


@router.post("/{thread_id}/messages", tags=['openai'], response_model=MessageObject)
async def create_message(thread_id: str, msg: MessageCreate):
    message = message_manager.db_create_message(
        thread_id, msg, MessageObject.Status.in_progress)
    # if the thread is live in a backend context, push the new message to it as well
    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.put_user_message(message)
    return message


@router.get("/{thread_id}/messages", tags=['openai'], response_model=List[MessageObject])
async def list_messages(
    thread_id: str,
    limit: Optional[int] = 20,
    order: Order = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
    run_id: Optional[str] = None,
):
    return message_manager.db_list_messages_of_thread(thread_id, limit, order)


@router.get("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def retrieve_message(thread_id: ObjectID, message_id: ObjectID):
    return message_manager.db_get_message_by_id(thread_id, message_id)


@router.post("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=MessageObject)
async def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):
    raise not_implemented('modify message')


@router.delete("/{thread_id}/messages/{message_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_message(thread_id: ObjectID, message_id: ObjectID):
    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        ctx.delete_user_message(message_id)
    message_manager.db_delete_message_by_id(thread_id, message_id)
    return DeleteResponse(id=message_id, object='thread.message.deleted')
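A hedged usage sketch for these routes; the role/content fields on MessageCreate are an assumption (mirroring the OpenAI Assistants API), and the thread id is hypothetical.

from fastapi.testclient import TestClient
from ktransformers.server.api.openai.assistants.messages import router

client = TestClient(router)

# create a message in a (hypothetical) existing thread, then list the thread's messages
res = client.post("/some-thread-id/messages", json={"role": "user", "content": "hello"})
assert res.status_code == 200
print(client.get("/some-thread-id/messages").json())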
ktransformers/server/api/openai/assistants/runs.py  (new file, 99 lines)
@@ -0,0 +1,99 @@
from typing import List, Optional

from fastapi import APIRouter, Request

from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.backend.base import ThreadContext
from ktransformers.server.schemas.assistants.runs import RunCreate, RunObject, RunThreadCreate, RunModify, RunSubmit
from ktransformers.server.schemas.assistants.streaming import api_stream_response
from ktransformers.server.utils.create_interface import get_thread_context_manager
from ktransformers.server.schemas.base import Order
from ktransformers.server.config.log import logger
from ktransformers.server.exceptions import internal_server_error


router = APIRouter()
runs_manager = RunsDatabaseManager()


@router.post("/{thread_id}/runs", tags=['openai'])
async def create_run(request: Request, thread_id: str, run_create: RunCreate):
    if run_create.stream:
        async def inner():
            run = runs_manager.db_create_run(thread_id, run_create)
            yield run.stream_response_with_event(event=RunObject.Status.created)

            ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)

            async for event in ctx.work():
                yield event
        return api_stream_response(request, inner())
    else:
        run = runs_manager.db_create_run(thread_id, run_create)
        ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)
        # drain the generator so the run completes before returning
        async for event in ctx.work():
            pass
        return run


@router.post("/runs", tags=['openai'], response_model=RunObject)
async def create_thread_and_run(run_thread: RunThreadCreate):
    raise NotImplementedError


@router.get("/{thread_id}/runs", tags=['openai'], response_model=List[RunObject])
async def list_runs(
    thread_id: str,
    limit: Optional[int] = 20,
    order: Optional[Order] = Order.DESC,
    after: Optional[str] = None,
    before: Optional[str] = None,
):
    raise NotImplementedError


@router.get("/{thread_id}/runs/{run_id}", tags=['openai'], response_model=RunObject)
async def retrieve_run(
    thread_id: str,
    run_id: str,
):
    runobj = runs_manager.db_get_run(run_id)
    assert runobj.thread_id == thread_id
    return runobj


@router.post("/{thread_id}/runs/{run_id}", tags=['openai'], response_model=RunObject)
async def modify_run(
    thread_id: str,
    run_id: str,
    run: RunModify,
):
    raise NotImplementedError


@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=['openai'], response_model=RunObject)
async def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit):
    raise NotImplementedError


@router.post("/{thread_id}/runs/{run_id}/cancel", tags=['openai'], response_model=RunObject)
async def cancel_run(thread_id: str, run_id: str):
    ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id)
    if ctx is not None:
        if ctx.run is None:
            # ctx.run is None here, so log the thread rather than dereferencing the missing run
            logger.warning(f'Context for thread {thread_id} is expected to have a run in progress, but none is attached')
            raise internal_server_error('ctx does not have a run')

        if ctx.run.id == run_id:
            logger.info(f'Cancelling thread: {thread_id} and run: {run_id}')
            ctx.run.stream_response_with_event(RunObject.Status.cancelling)
            return ctx.run
        else:
            run = runs_manager.db_get_run(run_id)
            logger.info(f'Run {run_id} not in this thread context')
            return run
    else:
        run = runs_manager.db_get_run(run_id)
        logger.info(f'Run {run_id} not in context manager')
        return run
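A hedged client-side sketch for the streaming branch of create_run; the server address, thread id, and the assistant_id payload field are assumptions.

import httpx

thread_id = "some-thread-id"  # hypothetical
with httpx.stream("POST", f"http://localhost:8000/v1/threads/{thread_id}/runs",
                  json={"assistant_id": "some-assistant-id", "stream": True}) as resp:
    for line in resp.iter_lines():
        if line.startswith("data:"):  # SSE payload lines
            print(line[len("data:"):].strip())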
ktransformers/server/api/openai/assistants/threads.py  (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import List, Optional
from fastapi import APIRouter

from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager, Order, ObjectID
from ktransformers.server.schemas.assistants.threads import ThreadObject, ThreadCreate, ThreadModify
from ktransformers.server.schemas.base import DeleteResponse
from ktransformers.server.schemas.conversation import ThreadPreview

router = APIRouter(prefix='/threads')
threads_manager = ThreadsDatabaseManager()


@router.post("/", tags=['openai'], response_model=ThreadObject)
async def create_thread(thread: ThreadCreate):
    return threads_manager.db_create_thread(thread)


@router.get("/", tags=['openai-ext'], response_model=List[ThreadPreview])
async def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC):
    return threads_manager.db_list_threads_preview(limit, order)


@router.get("/{thread_id}", tags=['openai'], response_model=ThreadObject)
async def retrieve_thread(thread_id: ObjectID):
    return threads_manager.db_get_thread_by_id(thread_id)


@router.post("/{thread_id}", tags=['openai'], response_model=ThreadObject)
async def modify_thread(thread_id: ObjectID, thread: ThreadModify):
    raise NotImplementedError


@router.delete("/{thread_id}", tags=['openai'], response_model=DeleteResponse)
async def delete_thread(thread_id: ObjectID):
    threads_manager.db_delete_thread_by_id(thread_id=thread_id)
    return DeleteResponse(id=thread_id, object='thread.deleted')
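A hedged lifecycle sketch against these routes; that ThreadCreate accepts an empty body (as in the OpenAI API) is an assumption.

from fastapi.testclient import TestClient
from ktransformers.server.api.openai.assistants.threads import router

client = TestClient(router)

thread = client.post("/threads/", json={}).json()   # router prefix is /threads
print(client.get("/threads/").json())               # preview list
client.delete(f"/threads/{thread['id']}")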
ktransformers/server/api/openai/endpoints/__init__.py  (new file, 0 lines, empty)
ktransformers/server/api/openai/endpoints/chat.py  (new file, 34 lines)
@@ -0,0 +1,34 @@
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate, ChatCompletionChunk, ChatCompletionObject
from ktransformers.server.backend.base import BackendInterfaceBase

router = APIRouter()


@router.post('/chat/completions', tags=['openai'])
async def chat_completion(request: Request, create: ChatCompletionCreate):
    id = str(uuid4())

    interface: BackendInterfaceBase = get_interface()
    # input_ids = interface.format_and_tokenize_input_ids(id, messages=create.get_tokenizer_messages())

    input_message = [json.loads(m.model_dump_json()) for m in create.messages]

    if create.stream:
        async def inner():
            chunk = ChatCompletionChunk(id=id, object='chat.completion.chunk', created=int(time()))
            async for token in interface.inference(input_message, id):
                chunk.set_token(token)
                yield chunk
        return chat_stream_response(request, inner())
    else:
        # the non-streaming object type is 'chat.completion', not 'chat.completion.chunk'
        comp = ChatCompletionObject(id=id, object='chat.completion', created=int(time()))
        async for token in interface.inference(input_message, id):
            comp.append_token(token)
        return comp
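Since this route mirrors OpenAI's /v1/chat/completions, the official openai client can be pointed at the server; the base_url and model string below are assumptions.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")  # hypothetical address
resp = client.chat.completions.create(
    model="default model",
    messages=[{"role": "user", "content": "hello"}],
)
print(resp.choices[0].message.content)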
ktransformers/server/api/openai/legacy/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
from fastapi import APIRouter

from . import completions

router = APIRouter()
router.include_router(completions.router)
ktransformers/server/api/openai/legacy/completions.py  (new file, 33 lines)
@@ -0,0 +1,33 @@
import json
from time import time
from uuid import uuid4
from fastapi import APIRouter
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate, CompletionObject

router = APIRouter()


@router.post("/completions", tags=['openai'])
async def create_completion(request: Request, create: CompletionCreate):
    id = str(uuid4())

    interface = get_interface()
    print(f'COMPLETION INPUT:----\n{create.prompt}\n----')

    if create.stream:
        async def inner():
            async for token in interface.inference(create.prompt, id):
                d = {'choices': [{'delta': {'content': token}}]}
                yield f"data: {json.dumps(d)}\n\n"
            # emit a final empty delta carrying the finish_reason
            d = {'choices': [{'delta': {'content': ''}, 'finish_reason': ''}]}
            yield f"data: {json.dumps(d)}\n\n"
        return stream_response(request, inner())
    else:
        comp = CompletionObject(id=id, object='text_completion', created=int(time()))
        async for token in interface.inference(create.prompt, id):
            comp.append_token(token)
        return comp
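A hedged sketch for consuming the stream emitted above (the server address and prompt are assumptions):

import json

import httpx

with httpx.stream("POST", "http://localhost:8000/v1/completions",
                  json={"prompt": "Once upon a time", "stream": True}) as resp:
    for line in resp.iter_lines():
        if line.startswith("data:"):
            chunk = json.loads(line[len("data:"):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)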