kvcache-ai-ktransformers/ktransformers/server/backend/base.py
2024-07-27 16:06:58 +08:00

162 lines
No EOL
6.4 KiB
Python

from asyncio import Queue
from enum import Enum
import sys, os
from typing import AsyncIterator, Dict, List, Optional, Tuple
import torch
from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager
from ktransformers.server.exceptions import request_error
from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler
from .args import ConfigArgs,default_args
class BackendInterfaceBase:
    '''
    Interface to inference frameworks, e.g. transformers, exllama.

    Subclasses must implement ``__init__`` and ``inference``; the base
    versions only raise NotImplementedError.
    '''
    args: ConfigArgs
    # NOTE(review): this Profiler is created once, when the class is defined,
    # so every instance that does not assign its own shares the same one.
    profiler: Profiler = Profiler()

    def __init__(self, args: ConfigArgs = default_args):
        raise NotImplementedError

    async def inference(self, local_messages, request_unique_id: Optional[str]) -> AsyncIterator[str]:
        '''
        Generate a response for ``local_messages`` and stream it back.

        ``inference`` can be called directly, or by ``ThreadContext.work``.

        local_messages:
            When called by ThreadContext, local_messages are produced by
            ThreadContext.get_local_messages(). Implementations must handle
            whatever message format their ThreadContext produces.
        request_unique_id:
            Unique id of the request, useful when caching across requests
            (ThreadContext.work passes the thread id here).
        return:
            Async iterator of str chunks, used for streaming updates.
        '''
        raise NotImplementedError

    def report_last_time_performance(self):
        '''
        Log throughput (tokens/s) and timings recorded in ``self.profiler``.

        Best-effort: if any profiler lookup or rate computation fails
        (for example a stage recorded zero seconds, giving ZeroDivisionError),
        a short "not recorded" message is logged instead of raising.
        '''
        try:
            tokenize_time = self.profiler.get_timer_sec('tokenize')
            prefill_time = self.profiler.get_timer_sec('prefill')
            decode_time = self.profiler.get_timer_sec('decode')
            prefill_count = self.profiler.get_counter('prefill')
            decode_count = self.profiler.get_counter('decode')
            prefill_rate = prefill_count / prefill_time
            decode_rate = decode_count / decode_time
        except Exception:
            # Deliberately best-effort; unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt / SystemExit.
            logger.info('Performance statistics not recorded')
            return
        logger.info(f'Performance(T/s): prefill {prefill_rate}, decode {decode_rate}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
class ThreadContext:
    '''
    A thread context holding assistant logic for one run on a thread.

    On construction it loads the thread, the assistant and the thread's
    messages (ascending order) through the database managers. ``work`` then
    drives ``interface.inference`` to produce the assistant reply as a stream
    of events. Subclasses must implement ``get_local_messages``.
    '''
    args: ConfigArgs
    # Assistant Logic
    assistant: Optional[AssistantObject] = None
    related_threads: List[ThreadObject]
    thread: ThreadObject
    # Annotation only: a class-level ``[]`` default would be a single list
    # shared by every instance. __init__ always assigns a per-instance list.
    messages: List[MessageObject]
    run: RunObject
    interface: Optional[BackendInterfaceBase] = None
    queue: Optional[Queue] = None
    # NOTE(review): created once at class definition and shared by instances.
    timer: Profiler = Profiler()

    def __init__(self, run: RunObject, interface: BackendInterfaceBase, args: ConfigArgs = default_args) -> None:
        self.args = args
        self.thread_manager = ThreadsDatabaseManager()
        self.message_manager = MessageDatabaseManager()
        self.runs_manager = RunsDatabaseManager()
        self.assistant_manager = AssistantDatabaseManager()
        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)
        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)
        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id, order=Order.ASC)
        logger.debug(f"{len(self.messages)} messages loaded from database")
        self.interface = interface
        self.update_by_run(run, args)

    def get_local_messages(self):
        '''
        Get local messages, used as the ``local_messages`` input to
        interface.inference.

        This function is intended for message preprocessing, e.g. applying a
        chat template. Must be implemented by subclasses.
        '''
        raise NotImplementedError

    def update_by_run(self, run: RunObject, args: ConfigArgs = default_args):
        '''Replace the current run and config args.'''
        self.run = run
        self.args = args

    def put_user_message(self, message: MessageObject):
        '''
        Append an in-progress user message belonging to this thread.

        Validated with ``assert`` only, so the check is skipped under ``python -O``.
        '''
        assert (
            message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress
        )
        self.messages.append(message)

    def delete_user_message(self, message_id: ObjectID):
        '''Drop every local message whose id equals ``message_id``.'''
        self.messages = [m for m in self.messages if m.id != message_id]

    async def work(self) -> AsyncIterator:
        '''
        Produce the assistant reply to the latest user message as a stream.

        Steps:
        1. Mark the latest user message completed and sync it.
        2. Create an empty assistant reply message and append it.
        3. Yield the message-created, message-in-progress and run-in-progress
           events.
        4. Yield one message delta per token from ``interface.inference``.
        5. Yield the terminal events.

        If the run's status becomes ``cancelling`` during generation, the loop
        stops; the run is then reported cancelled and the reply incomplete.
        The reply message and the run are synced only after the terminal
        events. If the consumer stops iterating early or inference raises,
        neither is synced here.

        raises:
            the error built by request_error if there are no messages or the
            latest one is not from the user; NotImplementedError if the run
            ends in a status other than cancelling / in_progress.
        '''
        logger.debug('start working')
        # Check for an empty history explicitly; indexing [-1] on an empty list
        # would otherwise escape as a bare IndexError instead of a request error.
        if not self.messages or not self.messages[-1].role.is_user():
            raise request_error('user must talk before LLM can talk')
        user_message = self.messages[-1]
        user_message.status = MessageObject.Status.completed
        user_message.sync_db()

        # Must be computed before reply_message is appended to self.messages
        # (presumably so the empty reply is not part of the model input).
        local_messages = self.get_local_messages()

        reply_message = self.message_manager.create_message_object(
            self.thread.id,
            self.run.id,
            MessageCreate(role=Role.assistant, content=""),
        )
        reply_message.assistant_id = self.assistant.id
        self.messages.append(reply_message)

        yield reply_message.stream_response_with_event(MessageObject.Status.created)
        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
        yield self.run.stream_response_with_event(RunObject.Status.in_progress)

        # The thread id doubles as the request's unique id for the backend.
        async for token in self.interface.inference(local_messages, self.thread.id):
            if self.run.status == RunObject.Status.cancelling:
                logger.warning(f'Run {self.run.id} cancelling')
                break
            yield reply_message.append_message_delta(token)

        if self.run.status == RunObject.Status.cancelling:
            yield self.run.stream_response_with_event(RunObject.Status.cancelled)
            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)
        elif self.run.status == RunObject.Status.in_progress:
            yield self.run.stream_response_with_event(RunObject.Status.completed)
            yield reply_message.stream_response_with_event(MessageObject.Status.completed)
        else:
            raise NotImplementedError(f'{self.run.status} should not appear here')

        reply_message.sync_db()
        self.run.sync_db()