feat: add prefix cache for server

This commit is contained in:
ceerrep 2025-02-17 00:10:55 +08:00
parent c515cc49a5
commit bb0ccc7b1a
5 changed files with 132 additions and 55 deletions

View file

@@ -174,6 +174,18 @@ class StaticCache(transformers.StaticCache):
self.key_cache[layer_idx].zero_() self.key_cache[layer_idx].zero_()
if self.value_cache[layer_idx] is not None: if self.value_cache[layer_idx] is not None:
self.value_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_()
self.past_tokens[layer_idx] = 0
def remove_suffix(self, start_pos):
    """Truncate the cache so only the first ``start_pos`` token positions remain.

    Entries at positions >= ``start_pos`` are zeroed in place and each layer's
    ``past_tokens`` counter is rewound to ``start_pos`` (prefix-cache reuse).
    """
    num_layers = len(self.key_cache)
    for idx in range(num_layers):
        # In-place ops prevent breaking the static address
        if self.is_MLA:
            # MLA stores keys/values combined in key_cache; flatten to
            # (positions, head_dim) and zero the tail rows.
            keys = self.key_cache[idx]
            flat = keys.view(-1, keys.shape[-1])
            flat[start_pos:].zero_()
        else:
            self.key_cache[idx][..., start_pos:, :].zero_()
            self.value_cache[idx][..., start_pos:, :].zero_()
        self.past_tokens[idx] = start_pos
def get_max_cache_shape(self) -> Tuple[int, int, int, int]: def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache.""" """Returns the maximum shape of the cache."""

View file

@@ -90,7 +90,8 @@ class ArgumentParser:
# user config # user config
parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key) parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think) parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
# web config # web config
parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)

View file

@@ -121,23 +121,42 @@ class KTransformersInterface(TransformersInterface):
@torch.no_grad @torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool): def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1] input_ids_length = input_ids.shape[-1]
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0") device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
if is_new: if is_new:
self.cache.reset()
self.ever_generated_ids.clear() self.ever_generated_ids.clear()
former_seq_length = 0 same_prefix = 0
self.seq_length = input_ids_length flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros( self.generated_ids = torch.zeros(
self.args.batch_size, self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1, input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int, dtype=torch.int,
device=self.args.device, device=self.args.device,
) )
self.seq_length = 1
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else: else:
break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length former_seq_length = self.seq_length
self.seq_length += input_ids_length self.seq_length += input_ids_length
@@ -148,6 +167,7 @@ class KTransformersInterface(TransformersInterface):
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
) )
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device) cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -168,6 +188,7 @@ class KTransformersInterface(TransformersInterface):
else: else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :]) next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token) yield self.append_new_tokens(next_token)

View file

@@ -198,14 +198,28 @@ class TransformersInterface(BackendInterfaceBase):
self.seq_length += 1 self.seq_length += 1
return self.streamer.put(new_tokens) return self.streamer.put(new_tokens)
def logits_to_token(self, logits: torch.Tensor): def prepare_logits_wrapper(self, inputs, device):
logits = logits / self.args.temperature if self.args.temperature!=0 else logits generation_config, model_kwargs = self.model._prepare_generation_config(
None, max_length=self.args.max_new_tokens,
do_sample=True,
top_k=self.args.top_k,
top_p=self.args.top_p,
temperature=self.args.temperature,
repetition_penalty=self.args.repetition_penalty # change this to modify generate config
)
self.inputs = inputs
self.generation_config = generation_config
try: # transformers==4.43
self.logits_warper = (
self.model._get_logits_warper(generation_config,device=device)
)
except:
self.logits_warper = (
self.model._get_logits_warper(generation_config)
)
for token_idx in self.ever_generated_ids: def logits_to_token(self, logits: torch.Tensor):
if logits[token_idx] < 0: logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
logits[token_idx] *= self.args.repetition_penalty
else:
logits[token_idx] /= self.args.repetition_penalty
probs = torch.nn.functional.softmax(logits, dim=-1) probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -239,21 +253,40 @@ class TransformersInterface(BackendInterfaceBase):
@torch.no_grad @torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool): def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1] input_ids_length = input_ids.shape[-1]
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}") logger.debug(f"input_ids: {input_ids.shape}")
if is_new: if is_new:
self.cache.reset()
self.ever_generated_ids.clear() self.ever_generated_ids.clear()
former_seq_length = 0 same_prefix = 0
self.seq_length = input_ids_length flat_input_ids = input_ids.flatten()
if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros( self.generated_ids = torch.zeros(
self.args.batch_size, self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1, input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int, dtype=torch.int,
device=self.args.device, device=self.args.device,
) )
self.seq_length = 1
flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else: else:
break
logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]
self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
logger.debug(f"generate_ids: {self.generated_ids.shape}") logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length former_seq_length = self.seq_length
self.seq_length += input_ids_length self.seq_length += input_ids_length
@@ -264,6 +297,7 @@ class TransformersInterface(BackendInterfaceBase):
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
) )
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1) self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}") logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device) cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int) self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -285,6 +319,7 @@ class TransformersInterface(BackendInterfaceBase):
else: else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0] logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :]) next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token) yield self.append_new_tokens(next_token)
@@ -315,6 +350,7 @@ class TransformersInterface(BackendInterfaceBase):
return True return True
async def inference(self, local_messages, thread_id: str): async def inference(self, local_messages, thread_id: str):
self.streamer.reset()
self.profiler.create_and_start_timer("tokenize") self.profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List): if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages) input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -325,7 +361,7 @@ class TransformersInterface(BackendInterfaceBase):
else: else:
raise ValueError("local_messages should be List or str") raise ValueError("local_messages should be List or str")
if Config().user_force_think: if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device) token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat( input_ids = torch.cat(
[input_ids, token_thinks], dim=1 [input_ids, token_thinks], dim=1
) )
@@ -333,11 +369,14 @@ class TransformersInterface(BackendInterfaceBase):
self.profiler.pause_timer("tokenize") self.profiler.pause_timer("tokenize")
self.profiler.create_and_start_timer("prefill") self.profiler.create_and_start_timer("prefill")
if Config().user_force_think:
t = "<think>\n"
print(t,end="",flush=True)
yield t
for t in self.prefill(input_ids, self.check_is_new(thread_id)): for t in self.prefill(input_ids, self.check_is_new(thread_id)):
# output think token after prefill done
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think
if t is not None: if t is not None:
print(t, end="",flush=True) print(t, end="",flush=True)
yield t yield t

View file

@@ -105,6 +105,10 @@ def custom_openapi(app):
def main(): def main():
cfg = Config() cfg = Config()
# Temporarily disable cuda graph by default because of a bug in the prefix cache.
cfg.use_cuda_graph = False
arg_parser = ArgumentParser(cfg) arg_parser = ArgumentParser(cfg)
# 初始化消息 # 初始化消息