From ee24eb8dc3626b26275a00b82715d9ebb42a9a46 Mon Sep 17 00:00:00 2001
From: ceerrep
Date: Mon, 17 Feb 2025 18:08:45 +0800
Subject: [PATCH] fix: fix server for triton kernel

---
 ktransformers/server/backend/interfaces/ktransformers.py | 9 ++++++++-
 ktransformers/server/main.py                             | 3 ---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/ktransformers/server/backend/interfaces/ktransformers.py b/ktransformers/server/backend/interfaces/ktransformers.py
index efc23b9..86b97d7 100644
--- a/ktransformers/server/backend/interfaces/ktransformers.py
+++ b/ktransformers/server/backend/interfaces/ktransformers.py
@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 
+warm_uped = False
+
 
 class KTransformersThreadContext(TransformersThreadContext):
     pass
 
@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
+
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        torch.cuda.set_device(torch_device)
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
         logits = logits[0, -1, :]
+        warm_uped = True
         return self.logits_to_token(logits)
 
 
@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        torch.cuda.set_device(device)
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,
diff --git a/ktransformers/server/main.py b/ktransformers/server/main.py
index fc1f51a..f536f9c 100644
--- a/ktransformers/server/main.py
+++ b/ktransformers/server/main.py
@@ -106,9 +106,6 @@ def custom_openapi(app):
 
 def main():
     cfg = Config()
-    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
-    cfg.use_cuda_graph = False
-
     arg_parser = ArgumentParser(cfg)
 
     # Initialize messages