Initial commit

This commit is contained in:
chenxl 2024-07-27 16:06:58 +08:00
commit 18c42e67df
247 changed files with 53775 additions and 0 deletions

View file

@ -0,0 +1,54 @@
from asyncio import Lock
from typing import Dict, Optional
from ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.server.backend.interfaces.transformers import TransformersThreadContext
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
lock: Lock
threads_context: Dict[ObjectID, ThreadContext]
interface: BackendInterfaceBase
def __init__(self,interface) -> None:
logger.debug(f"Creating Context Manager")
self.lock = Lock()
self.threads_context = {}
self.interface = interface
pass
async def get_context_by_run_object(self, run: RunObject) -> ThreadContext:
async with self.lock:
logger.debug(f"keys {self.threads_context.keys()}")
if run.thread_id not in self.threads_context:
logger.debug(f"new inference context {run.thread_id}")
if isinstance(self.interface, ExllamaInterface):
new_context = ExllamaThreadContext(run, self.interface)
elif isinstance(self.interface, KTransformersInterface):
new_context = KTransformersThreadContext(run, self.interface)
elif isinstance(self.interface, TransformersInterface):
new_context = TransformersThreadContext(run, self.interface)
else:
raise NotImplementedError
self.threads_context[run.thread_id] = new_context
# self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
re = self.threads_context[run.thread_id]
re.update_by_run(run)
return re
async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]:
async with self.lock:
if thread_id in self.threads_context:
logger.debug(f'found context for thread {thread_id}')
return self.threads_context[thread_id]
else:
logger.debug(f'no context for thread {thread_id}')
return None