Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 06:14:58 +00:00)
feat: add prefix cache for server

parent c515cc49a5
commit bb0ccc7b1a

5 changed files with 132 additions and 55 deletions
@@ -198,14 +198,28 @@ class TransformersInterface(BackendInterfaceBase):
         self.seq_length += 1
         return self.streamer.put(new_tokens)
 
-    def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature if self.args.temperature!=0 else logits
-
-        for token_idx in self.ever_generated_ids:
-            if logits[token_idx] < 0:
-                logits[token_idx] *= self.args.repetition_penalty
-            else:
-                logits[token_idx] /= self.args.repetition_penalty
+    def prepare_logits_wrapper(self, inputs, device):
+        generation_config, model_kwargs = self.model._prepare_generation_config(
+            None, max_length=self.args.max_new_tokens,
+            do_sample=True,
+            top_k=self.args.top_k,
+            top_p=self.args.top_p,
+            temperature=self.args.temperature,
+            repetition_penalty=self.args.repetition_penalty # change this to modify generate config
+        )
+        self.inputs = inputs
+        self.generation_config = generation_config
+        try: # transformers==4.43
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config,device=device)
+            )
+        except:
+            self.logits_warper = (
+                self.model._get_logits_warper(generation_config)
+            )
+
+    def logits_to_token(self, logits: torch.Tensor):
+        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))
 
         probs = torch.nn.functional.softmax(logits, dim=-1)
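Context for the hunk above: the hand-rolled temperature division and repetition-penalty loop in the old logits_to_token are replaced by a Hugging Face logits warper built once per request in prepare_logits_wrapper. _prepare_generation_config and _get_logits_warper are private transformers methods whose signatures shifted around transformers==4.43, hence the try/except. Below is a minimal sketch of the equivalent sampling path using only public transformers classes; the build_warper helper and all parameter values are illustrative, not repo code:

    import torch
    from transformers import (
        LogitsProcessorList,
        RepetitionPenaltyLogitsProcessor,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

    def build_warper(temperature: float, top_k: int, top_p: float,
                     repetition_penalty: float) -> LogitsProcessorList:
        # Same knobs the hunk passes to _prepare_generation_config.
        return LogitsProcessorList([
            RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty),
            TemperatureLogitsWarper(temperature),
            TopKLogitsWarper(top_k=top_k),
            TopPLogitsWarper(top_p=top_p),
        ])

    warper = build_warper(temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1)
    input_ids = torch.tensor([[1, 2, 3]])   # (batch, seq), as in self.inputs.view(1, -1)
    logits = torch.randn(1, 32000)          # (batch, vocab) last-position logits
    scores = warper(input_ids, logits)      # same call shape as the new logits_to_token
    probs = torch.nn.functional.softmax(scores, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)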
@@ -239,31 +253,51 @@ class TransformersInterface(BackendInterfaceBase):
     @torch.no_grad
     def prefill(self, input_ids: torch.Tensor, is_new: bool):
         input_ids_length = input_ids.shape[-1]
-        self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
 
-        if is_new:
-            self.cache.reset()
-            self.ever_generated_ids.clear()
-            former_seq_length = 0
-            self.seq_length = input_ids_length
-            self.generated_ids = torch.zeros(
-                self.args.batch_size,
-                self.seq_length + self.args.max_new_tokens + 1,
-                dtype=torch.int,
-                device=self.args.device,
-            )
-        else:
-            logger.debug(f"generate_ids: {self.generated_ids.shape}")
-            former_seq_length = self.seq_length
-            self.seq_length += input_ids_length
-            expected_length = self.seq_length + self.args.max_new_tokens + 1
-            delta_length = expected_length - self.generated_ids.shape[-1]
-            if delta_length > 0:
-                new_generate_ids = torch.zeros(
-                    self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
-                )
-                self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        same_prefix = 0
+        flat_input_ids = input_ids.flatten()
+
+        if getattr(self, 'generated_ids', None) is None:
+            self.generated_ids = torch.zeros(
+                self.args.batch_size,
+                input_ids.shape[-1] + self.args.max_new_tokens + 1,
+                dtype=torch.int,
+                device=self.args.device,
+            )
+            self.seq_length = 1
+
+        flat_prev_ids = self.generated_ids.flatten()
+        for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
+            if flat_input_ids[i] == flat_prev_ids[i]:
+                same_prefix += 1
+            else:
+                break
+
+        logger.debug(f"same prefix len: {same_prefix}")
+        self.cache.remove_suffix(same_prefix)
+        self.seq_length = same_prefix
+        self.generated_ids = self.generated_ids[..., :same_prefix]
+        input_ids = input_ids[..., same_prefix:]
+        input_ids_length = input_ids.shape[-1]
+
+        self.ever_generated_ids.clear()
+        self.profiler.set_counter("prefill", input_ids_length)
+        logger.debug(f"input_ids: {input_ids.shape}")
+
+        logger.debug(f"generate_ids: {self.generated_ids.shape}")
+        former_seq_length = self.seq_length
+        self.seq_length += input_ids_length
+        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        delta_length = expected_length - self.generated_ids.shape[-1]
+        if delta_length > 0:
+            new_generate_ids = torch.zeros(
+                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
+            )
+            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
 
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
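This hunk is the core of the prefix cache. Instead of resetting the KV cache on every new thread, prefill now compares the incoming prompt against the tokens already stored in generated_ids, trims the cache back to the longest shared prefix via cache.remove_suffix, and prefills only the remaining suffix. Here is a standalone sketch of just the matching step; split_on_cached_prefix is a hypothetical helper, not repo code, and the "- 1" bound mirrors the diff, guaranteeing at least one token is always re-prefilled so there is something to compute logits from:

    import torch

    def split_on_cached_prefix(prev_ids: torch.Tensor, new_ids: torch.Tensor) -> tuple[int, torch.Tensor]:
        """Return (same_prefix, suffix_still_to_prefill)."""
        flat_prev = prev_ids.flatten()
        flat_new = new_ids.flatten()
        same_prefix = 0
        # Stop one short of the shorter sequence, as the diff does, so at
        # least one token always remains for the real prefill pass.
        for i in range(min(flat_prev.shape[0], flat_new.shape[0]) - 1):
            if flat_new[i] == flat_prev[i]:
                same_prefix += 1
            else:
                break
        return same_prefix, new_ids[..., same_prefix:]

    cached = torch.tensor([[1, 2, 3, 4, 5]])   # tokens already in the KV cache
    prompt = torch.tensor([[1, 2, 3, 9, 9]])   # new request sharing a 3-token prefix
    n, rest = split_on_cached_prefix(cached, prompt)
    print(n, rest)  # 3 tensor([[9, 9]])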
@@ -285,6 +319,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
 
+        self.prepare_logits_wrapper(input_ids, device)
         next_token = self.logits_to_token(logits[0, -1, :])
         yield self.append_new_tokens(next_token)
@@ -315,6 +350,7 @@ class TransformersInterface(BackendInterfaceBase):
         return True
 
     async def inference(self, local_messages, thread_id: str):
+        self.streamer.reset()
         self.profiler.create_and_start_timer("tokenize")
         if isinstance(local_messages, List):
             input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
@@ -325,7 +361,7 @@ class TransformersInterface(BackendInterfaceBase):
         else:
             raise ValueError("local_messages should be List or str")
         if Config().user_force_think:
-            token_thinks = torch.tensor([self.tokenizer.encode("<think>\\n",add_special_tokens=False)],device=input_ids.device)
+            token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
             input_ids = torch.cat(
                 [input_ids, token_thinks], dim=1
             )
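The one-character change above fixes an escaping bug: in the old source the string "<think>\\n" contains a literal backslash followed by the letter n, so the tokenizer never received an actual newline. A quick illustration:

    s_old = "<think>\\n"   # what the old code built: backslash + 'n'
    s_new = "<think>\n"    # what the fix builds: a real newline
    print(len(s_old), repr(s_old[-2:]))  # 9 '\\n'
    print(len(s_new), repr(s_new[-1]))   # 8 '\n'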
@@ -333,11 +369,14 @@ class TransformersInterface(BackendInterfaceBase):
         self.profiler.pause_timer("tokenize")
 
         self.profiler.create_and_start_timer("prefill")
-        if Config().user_force_think:
-            t = "<think>\n"
-            print(t,end="",flush=True)
-            yield t
+
         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
+            # output think token after prefill done
+            if Config().user_force_think:
+                think = '<think>\n'
+                print(think, end="",flush=True)
+                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
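Note on this last hunk: the forced <think>\n marker used to be yielded before prefill started; it is now yielded inside the prefill loop, so it reaches the client only once prefill has produced its first token (matching the added "output think token after prefill done" comment). A toy sketch of the resulting stream order, with stand-in functions rather than repo code:

    import time

    def prefill():
        # Stand-in for the long prefill pass; the real method yields its
        # first decoded token only after prefill completes.
        time.sleep(0.1)
        yield "first-token"

    def stream(user_force_think: bool = True):
        for t in prefill():
            # As in the diff: the marker is emitted after prefill is done.
            if user_force_think:
                yield "<think>\n"
            if t is not None:
                yield t

    for chunk in stream():
        print(repr(chunk))  # prints '<think>\n' then 'first-token'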