support chunk prefill, support 139K context for 24G VRAM

Atream 2025-03-01 11:28:25 +00:00
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions
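What this change does: rather than pushing the entire prompt through the model in a single forward pass, prefill now walks the prompt in slices of `chunk_prefill_size` tokens, feeding each slice (with its matching `cache_position` range) into the static KV cache. Peak activation memory then scales with the chunk size instead of the prompt length, which is what makes a 139K-token context workable on 24G of VRAM. Below is a minimal, self-contained sketch of the pattern, not this repository's API: `forward_chunk`, `embed`, `CHUNK`, and the tensor shapes are illustrative stand-ins.

import torch

# Hypothetical stand-in for the real model call; a real implementation
# would invoke the model with a static KV cache, e.g.
# model(inputs_embeds=..., cache_position=..., past_key_values=cache, use_cache=True)
CHUNK = 512  # plays the role of args.chunk_prefill_size

def forward_chunk(inputs_embeds: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
    # Return dummy logits of the right shape (batch, chunk_len, vocab).
    return torch.randn(inputs_embeds.shape[0], inputs_embeds.shape[1], 32000)

input_ids = torch.randint(0, 32000, (1, 2000))  # a "long" prompt
embed = torch.nn.Embedding(32000, 64)           # stand-in embedding layer
seq_len = input_ids.shape[1]
cache_position = torch.arange(seq_len)

logits = None
chunk_start = 0
while chunk_start < seq_len:
    chunk_end = min(chunk_start + CHUNK, seq_len)
    # Only one chunk's worth of activations is alive at a time;
    # the KV cache still accumulates entries for every position.
    chunk_embeds = embed(input_ids[:, chunk_start:chunk_end])
    logits = forward_chunk(chunk_embeds, cache_position[chunk_start:chunk_end])
    chunk_start += CHUNK

# After the loop, `logits` holds the last chunk's output; only its final
# position is needed to sample the first generated token.
print(logits.shape)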

@@ -111,12 +111,10 @@ class KTransformersInterface(TransformersInterface):
             warm_uped = True
         if self.use_static_cache:
-            mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
                 self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
-                attention_mask=mask,
                 return_dict=False,
                 use_cache=True,
             )[0]
@@ -167,41 +165,54 @@ class KTransformersInterface(TransformersInterface):
         self.ever_generated_ids.clear()
         self.profiler.set_counter("prefill", input_ids_length)
         logger.debug(f"input_ids: {input_ids.shape}")
         logger.debug(f"generate_ids: {self.generated_ids.shape}")
         former_seq_length = self.seq_length
         self.seq_length += input_ids_length
-        expected_length = self.seq_length + self.args.max_new_tokens + 1
+        expected_length = min(self.seq_length + self.args.max_new_tokens + 1, self.args.cache_lens)
         delta_length = expected_length - self.generated_ids.shape[-1]
         if delta_length > 0:
             new_generate_ids = torch.zeros(
                 self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
             )
             self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
+        else:
+            logger.warning("seq_length is bigger than cache_lens, exiting")
+            exit(0)
         logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
         cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
         self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
-        mask = torch.ones((1, self.seq_length)).to(device)
         if not (type(self) is TransformersInterface):
             input_ids = input_ids.to("cpu")
-        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
-        torch.cuda.set_device(device)
-        if flashinfer_enabled:
-            MLAWrapperSingleton.need_plan_all()
-        if self.use_static_cache:
-            logits = self.model(
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                past_key_values=self.cache,
-                return_dict=False,
-                use_cache=True,
-                attention_mask=mask,
-            )[0]
-        else:
-            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+        def chunk_prefill(input_ids, cache_position):
+            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
+            torch.cuda.set_device(device)
+            if flashinfer_enabled:
+                MLAWrapperSingleton.need_plan_all()
+            if self.use_static_cache:
+                logits = self.model(
+                    inputs_embeds=inputs_embeds,
+                    cache_position=cache_position,
+                    past_key_values=self.cache,
+                    return_dict=False,
+                    use_cache=True,
+                )[0]
+            else:
+                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
+            return logits
+        chunk_start = 0
+        while chunk_start < input_ids_length:
+            chunk_end = min(chunk_start + self.args.chunk_prefill_size, input_ids_length)
+            if self.cache is not None:
+                self.cache.cur_idx = cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])
+            chunk_start += self.args.chunk_prefill_size
+        if flashinfer_enabled:
+            MLAWrapperSingleton.reset_buffer()
         self.prepare_logits_wrapper(input_ids, device, temperature, top_p)
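Two things are worth noting about the loop above. Each chunk's forward pass still attends to every previously cached position, so total prefill compute is essentially unchanged; only the transient activation footprint shrinks. And the new `min(..., self.args.cache_lens)` cap, together with the added `else` branch, turns a prompt that (with `max_new_tokens`) would exceed the static cache into a logged warning and an explicit exit, rather than letting positions run past the fixed-length cache.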