fix: fix server for triton kernel

ceerrep 2025-02-17 18:08:45 +08:00
parent bb1cadfff3
commit ee24eb8dc3
2 changed files with 8 additions and 4 deletions


@@ -16,6 +16,8 @@ from ktransformers.local_chat import custom_models, default_optimize_rules
 from ktransformers.util.utils import get_device
 
+warm_uped = False
+
 class KTransformersThreadContext(TransformersThreadContext):
     pass
@@ -74,10 +76,13 @@ class KTransformersInterface(TransformersInterface):
         self._infer_lock = asyncio.Lock()
 
     def decode_one_tokens(self):
+        global warm_uped
         device_map = self.model.gguf_loader.tensor_device_map
         torch_device = get_device("blk.0.self_attn", device_map)
         torch_device = "cuda:0" if torch_device == "cuda" else torch_device
-        if self.args.use_cuda_graph:
+        torch.cuda.set_device(torch_device)
+
+        if warm_uped and self.args.use_cuda_graph:
             if not hasattr(self, "cuda_graph_runner"):
                 self.cuda_graph_runner = CUDAGraphRunner()
                 self.cuda_graph_runner.capture(
@@ -113,6 +118,7 @@ class KTransformersInterface(TransformersInterface):
         else:
             logits = self.model(self.current_ids, return_dict=False)[0]
         logits = logits[0, -1, :]
 
+        warm_uped = True
         return self.logits_to_token(logits)
@@ -176,6 +182,7 @@ class KTransformersInterface(TransformersInterface):
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
         inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+        torch.cuda.set_device(device)
         if self.use_static_cache:
             logits = self.model(
                 inputs_embeds=inputs_embeds,

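Taken together, the interface changes defer CUDA graph capture until after at least one eager decode step: the module-level warm_uped flag stays False for the first token, so Triton kernels are JIT-compiled during a normal forward pass before CUDAGraphRunner.capture runs, and torch.cuda.set_device is now called explicitly in both decode and prefill so the work lands on the intended device. Below is a minimal sketch of the same warm-up-then-capture pattern using plain PyTorch CUDA graph APIs; the class, method, and attribute names are illustrative, not the ktransformers implementation.

import torch

class GraphedDecoder:
    """Illustrative warm-up-then-capture wrapper (not the ktransformers code)."""

    def __init__(self, model, use_cuda_graph=True):
        self.model = model
        self.use_cuda_graph = use_cuda_graph
        self.warmed_up = False   # plays the role of the diff's module-level warm_uped
        self.graph = None

    def decode(self, inp):
        if self.use_cuda_graph and self.warmed_up:
            if self.graph is None:
                # Capture once, only after an eager pass has already triggered
                # Triton/JIT kernel compilation outside of graph capture.
                self.static_in = inp.clone()
                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    self.static_out = self.model(self.static_in)
            self.static_in.copy_(inp)   # copy new data into the captured input buffer
            self.graph.replay()
            out = self.static_out
        else:
            out = self.model(inp)       # first call runs eagerly; kernels compile here
        self.warmed_up = True           # later calls may capture and replay the graph
        return out

Capturing a graph while kernels are still being compiled can fail or record incomplete work, which is presumably why the server now waits one decode step before capturing.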

@@ -106,9 +106,6 @@ def custom_openapi(app):
 def main():
     cfg = Config()
 
-    # Temporarily disable cuda graph by default because of a bug in the prefix cache.
-    cfg.use_cuda_graph = False
-
     arg_parser = ArgumentParser(cfg)
 
     # Initialize messages
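On the server side, main() no longer forces CUDA graphs off, so whatever use_cuda_graph value comes from the config or CLI is respected again, and the warm-up logic above decides when capture actually happens. A trivial sketch of the effect follows; the ServerConfig stand-in and its default are assumptions, not the real Config class.

from dataclasses import dataclass

@dataclass
class ServerConfig:              # stand-in for ktransformers' Config; default value assumed
    use_cuda_graph: bool = True

cfg = ServerConfig()
# Before this commit, main() overwrote the setting unconditionally:
#     cfg.use_cuda_graph = False
# After it, the configured value is left alone.
assert cfg.use_cuda_graph is True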