mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-05 20:19:51 +00:00
fix: fix server for triton kernel
This commit is contained in:
parent
bb1cadfff3
commit
ee24eb8dc3
2 changed files with 8 additions and 4 deletions
|
@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
|
|||
from ktransformers.util.utils import get_device
|
||||
|
||||
|
||||
warm_uped = False
|
||||
|
||||
class KTransformersThreadContext(TransformersThreadContext):
|
||||
pass
|
||||
|
||||
|
@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
|
|||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
def decode_one_tokens(self):
|
||||
global warm_uped
|
||||
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
if self.args.use_cuda_graph:
|
||||
torch.cuda.set_device(torch_device)
|
||||
if warm_uped and self.args.use_cuda_graph:
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(
|
||||
|
@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
else:
|
||||
logits = self.model(self.current_ids, return_dict=False)[0]
|
||||
logits = logits[0, -1, :]
|
||||
warm_uped = True
|
||||
|
||||
return self.logits_to_token(logits)
|
||||
|
||||
|
@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
|
|||
if not (type(self) is TransformersInterface):
|
||||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
|
||||
torch.cuda.set_device(device)
|
||||
if self.use_static_cache:
|
||||
logits = self.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
|
|
|
@ -106,9 +106,6 @@ def custom_openapi(app):
|
|||
def main():
|
||||
cfg = Config()
|
||||
|
||||
# Temporarily disable cuda graph by default because of a bug in the prefix cache.
|
||||
cfg.use_cuda_graph = False
|
||||
|
||||
arg_parser = ArgumentParser(cfg)
|
||||
|
||||
# 初始化消息
|
||||
|
|
Loading…
Add table
Reference in a new issue