feat: add prefix cache for server

This commit is contained in:
ceerrep 2025-02-17 00:10:55 +08:00
parent c515cc49a5
commit bb0ccc7b1a
5 changed files with 132 additions and 55 deletions

View file

@ -121,33 +121,53 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1]
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
if is_new:
self.cache.reset()
self.ever_generated_ids.clear()
former_seq_length = 0
self.seq_length = input_ids_length
self.generated_ids = torch.zeros(
self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
else:
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
same_prefix = 0
flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
self.seq_length = 1
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@ -168,6 +188,7 @@ class KTransformersInterface(TransformersInterface):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@ -179,4 +200,4 @@ class KTransformersInterface(TransformersInterface):
async def inference(self, local_messages, thread_id: str):
async with self._infer_lock:
async for v in super().inference(local_messages, thread_id):
yield v
yield v