mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
fix server warmup
This commit is contained in:
parent
9f1da18630
commit
f029588b61
2 changed files with 17 additions and 14 deletions
|
@ -73,13 +73,13 @@ class KTransformersInterface(TransformersInterface):
|
|||
|
||||
self._infer_lock = asyncio.Lock()
|
||||
|
||||
def decode_one_tokens(self, i):
|
||||
def decode_one_tokens(self):
|
||||
device_map = self.model.gguf_loader.tensor_device_map
|
||||
torch_device = get_device("blk.0.self_attn", device_map)
|
||||
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
|
||||
global warm_uped
|
||||
if self.args.use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
|
||||
warm_uped = True
|
||||
if self.args.use_cuda_graph and warm_uped == True:
|
||||
|
||||
if not hasattr(self, "cuda_graph_runner"):
|
||||
self.cuda_graph_runner = CUDAGraphRunner()
|
||||
self.cuda_graph_runner.capture(
|
||||
|
@ -102,6 +102,9 @@ class KTransformersInterface(TransformersInterface):
|
|||
logits = logits[0, -1, :]
|
||||
return self.logits_to_token(logits)
|
||||
|
||||
if self.args.use_cuda_graph:
|
||||
warm_uped = True
|
||||
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1, self.seq_length)).to(torch_device)
|
||||
logits = self.model(
|
||||
|
|
|
@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
self.ever_generated_ids.add(last)
|
||||
return last
|
||||
|
||||
def decode_one_tokens(self, i):
|
||||
def decode_one_tokens(self):
|
||||
if self.use_static_cache:
|
||||
mask = torch.ones((1, self.seq_length)).to(self.args.device)
|
||||
logits = self.model(
|
||||
|
@ -299,7 +299,7 @@ class TransformersInterface(BackendInterfaceBase):
|
|||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
|
||||
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
next_token = self.decode_one_tokens(i)
|
||||
next_token = self.decode_one_tokens()
|
||||
self.profiler.inc("decode")
|
||||
if next_token == self.tokenizer.eos_token_id:
|
||||
assert self.args.batch_size == 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue