fix server warmup

This commit is contained in:
Xie Weiyu 2025-02-18 11:39:45 +08:00
parent 9f1da18630
commit f029588b61
2 changed files with 17 additions and 14 deletions

View file

@ -219,7 +219,7 @@ class TransformersInterface(BackendInterfaceBase):
self.ever_generated_ids.add(last)
return last
def decode_one_tokens(self, i):
def decode_one_tokens(self):
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(self.args.device)
logits = self.model(
@ -299,7 +299,7 @@ class TransformersInterface(BackendInterfaceBase):
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens(i)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id:
assert self.args.batch_size == 1