Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 15:29:39 +00:00)
Commit 8704c09192 — Allow temperature and top_p from requests
Parent: 4b5991e77e

4 changed files with 18 additions and 12 deletions
@@ -14,7 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
-
+from typing import Optional
 
 warm_uped = False
 
@@ -207,7 +207,7 @@ class KTransformersInterface(TransformersInterface):
         device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
         return torch.tensor([self.seq_length - 1], device=device)
 
-    async def inference(self, local_messages, thread_id: str):
+    async def inference(self, local_messages, thread_id: str, temperature: Optional[float], top_p: Optional[float]):
         async with self._infer_lock:
-            async for v in super().inference(local_messages, thread_id):
+            async for v in super().inference(local_messages, thread_id, temperature, top_p):
                 yield v
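With the widened `inference()` signature, a serving endpoint can thread per-request sampling parameters through to the backend instead of relying only on server-wide defaults. The sketch below shows a hypothetical caller under that assumption; `ChatRequest`, `stream_chat`, and the `interface` object are illustrative names, not part of this commit or the actual ktransformers server code.

```python
# Minimal sketch (assumptions, not the real ktransformers endpoint) of
# forwarding per-request temperature/top_p into the new inference() API.
from typing import AsyncIterator, Optional

from pydantic import BaseModel


class ChatRequest(BaseModel):
    # Optional per-request overrides; None leaves the choice of a
    # fallback (e.g. a server default) to the backend.
    messages: list
    temperature: Optional[float] = None
    top_p: Optional[float] = None


async def stream_chat(interface, request: ChatRequest, thread_id: str) -> AsyncIterator[str]:
    # Pass temperature/top_p straight through, matching the signature
    # introduced in this diff: inference(local_messages, thread_id,
    # temperature, top_p).
    async for token in interface.inference(
        request.messages, thread_id, request.temperature, request.top_p
    ):
        yield token
```

Keeping the parameters `Optional` means existing clients that omit them are unaffected, while the `_infer_lock` in the diff still serializes concurrent requests around the shared model state.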