fix server warmup

This commit is contained in:
Xie Weiyu 2025-02-18 11:39:45 +08:00
parent 9f1da18630
commit f029588b61
2 changed files with 17 additions and 14 deletions

View file

@ -73,13 +73,13 @@ class KTransformersInterface(TransformersInterface):
self._infer_lock = asyncio.Lock()
def decode_one_tokens(self, i):
def decode_one_tokens(self):
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
global warm_uped
if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
if self.args.use_cuda_graph and warm_uped == True:
if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
@ -93,15 +93,18 @@ class KTransformersInterface(TransformersInterface):
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if self.args.use_cuda_graph:
warm_uped = True
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(torch_device)
logits = self.model(