mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-09 22:05:30 +00:00
Merge pull request #227 from hrz6976/main
Add a lock to server inference()
This commit is contained in:
commit
ae5d9e11a9
1 changed files with 9 additions and 1 deletions
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
import asyncio
|
||||||
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
from transformers import AutoTokenizer, AutoConfig, GenerationConfig
|
||||||
from ktransformers.server.backend.interfaces.transformers import (
|
from ktransformers.server.backend.interfaces.transformers import (
|
||||||
TransformersInterface,
|
TransformersInterface,
|
||||||
|
@ -70,6 +71,8 @@ class KTransformersInterface(TransformersInterface):
|
||||||
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
|
||||||
self.streamer = TextStreamer(self.tokenizer)
|
self.streamer = TextStreamer(self.tokenizer)
|
||||||
|
|
||||||
|
self._infer_lock = asyncio.Lock()
|
||||||
|
|
||||||
def decode_one_tokens(self):
|
def decode_one_tokens(self):
|
||||||
device_map = self.model.gguf_loader.tensor_device_map
|
device_map = self.model.gguf_loader.tensor_device_map
|
||||||
torch_device = get_device("blk.0.self_attn", device_map)
|
torch_device = get_device("blk.0.self_attn", device_map)
|
||||||
|
@ -171,4 +174,9 @@ class KTransformersInterface(TransformersInterface):
|
||||||
@property
|
@property
|
||||||
def active_cache_position(self):
|
def active_cache_position(self):
|
||||||
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
|
||||||
return torch.tensor([self.seq_length - 1], device=device)
|
return torch.tensor([self.seq_length - 1], device=device)
|
||||||
|
|
||||||
|
async def inference(self, local_messages, thread_id: str):
|
||||||
|
async with self._infer_lock:
|
||||||
|
async for v in super().inference(local_messages, thread_id):
|
||||||
|
yield v
|
Loading…
Add table
Add a link
Reference in a new issue