kvcache-ai-ktransformers/ktransformers/server/backend/context_manager.py
2024-07-27 16:06:58 +08:00

54 lines
No EOL
2.6 KiB
Python

from asyncio import Lock
from typing import Dict, Optional
from ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
    """Keep one inference context per thread id and hand it out on request.

    Contexts are created on first use from a run object and stored in
    ``threads_context``. Both lookup methods hold ``lock`` for the whole
    time they touch that mapping.
    """

    # Held by both lookup methods while they read or write ``threads_context``.
    lock: Lock
    # Thread id -> the context that was created for that thread.
    threads_context: Dict[ObjectID, ThreadContext]
    # Backend passed to every new context; its type selects the context class.
    interface: BackendInterfaceBase

    def __init__(self, interface: BackendInterfaceBase) -> None:
        """Start with an empty context mapping bound to ``interface``."""
        logger.debug("Creating Context Manager")
        self.lock = Lock()
        self.threads_context = {}
        self.interface = interface

    def _new_context(self, run: RunObject) -> ThreadContext:
        """Build a context for ``run`` with the class matching the backend.

        Raises:
            NotImplementedError: ``self.interface`` is not an instance of any
                of the interface types listed below.
        """
        # Tried top to bottom; the first isinstance match wins. The order is
        # the one the original if/elif chain used.
        # NOTE(review): presumably the order matters because some interface
        # classes derive from others -- confirm before reordering.
        candidates = (
            (ExllamaInterface, ExllamaThreadContext),
            (KTransformersInterface, KTransformersThreadContext),
            (TransformersInterface, TransformersThreadContext),
        )
        for interface_type, context_type in candidates:
            if isinstance(self.interface, interface_type):
                return context_type(run, self.interface)
        raise NotImplementedError(
            f"no thread context available for interface type "
            f"{type(self.interface).__name__}"
        )

    async def get_context_by_run_object(self, run: RunObject) -> ThreadContext:
        """Return the context for ``run.thread_id``, creating it if missing.

        The context (new or existing) has ``update_by_run(run)`` called on it
        before it is returned.

        Raises:
            NotImplementedError: a new context is needed but the backend
                interface type is not supported.
        """
        async with self.lock:
            logger.debug(f"keys {self.threads_context.keys()}")
            if run.thread_id not in self.threads_context:
                logger.debug(f"new inference context {run.thread_id}")
                self.threads_context[run.thread_id] = self._new_context(run)
            context = self.threads_context[run.thread_id]
            context.update_by_run(run)
            return context

    async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]:
        """Return the stored context for ``thread_id``, or ``None`` if absent.

        Unlike ``get_context_by_run_object`` this never creates a context.
        """
        async with self.lock:
            if thread_id not in self.threads_context:
                logger.debug(f'no context for thread {thread_id}')
                return None
            logger.debug(f'found context for thread {thread_id}')
            return self.threads_context[thread_id]