mirror of https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00

commit f35e8d41d8 (parent 494469d4c5)
support chunk prefill, support 139K context for 24G VRAM

10 changed files with 227 additions and 83 deletions
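The core of the change: instead of running one forward pass over the whole prompt, prefill is split into fixed-size chunks that reuse the same static KV cache, so peak activation memory is bounded by chunk_prefill_size (default 16384) rather than by prompt length. Below is a minimal sketch of that pattern, not the ktransformers implementation; the model_step callable is a hypothetical stand-in for one forward pass that writes KV entries for the given positions and returns logits.

# Hypothetical sketch of chunked prefill (model_step is an assumed stand-in).
import torch

def chunked_prefill(input_ids: torch.Tensor, chunk_size: int, model_step):
    seq_length = input_ids.shape[1]
    cache_position = torch.arange(seq_length, dtype=torch.int32)
    logits = None
    chunk_start = 0
    while chunk_start < seq_length:
        chunk_end = min(chunk_start + chunk_size, seq_length)
        # Only this slice of the prompt is materialized as activations;
        # the KV cache already holds entries for positions [0, chunk_start).
        logits = model_step(input_ids[:, chunk_start:chunk_end],
                            cache_position[chunk_start:chunk_end])
        chunk_start += chunk_size
    # Logits from the last chunk cover the final prompt position, which is
    # what the first generated token is sampled from.
    return logits

The real loop added in this commit does the same slicing, and additionally points past_key_values.cur_idx at the active positions before each call to its chunk_prefill helper.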
@@ -110,7 +110,7 @@ def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
         module.load()
 
 def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,
-                         mode = 'normal', force_think: bool = False, use_flashinfer_mla = False,
+                         mode = 'normal', force_think: bool = False, chunk_prefill_size = 16384, use_flashinfer_mla = False,
                          num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):
     import os
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -124,7 +124,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
 
     tokens = []
 
-    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, use_cuda_graph: bool = True):
+    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):
         if cuda_graph_runner is None:
            use_cuda_graph = False
         if use_cuda_graph:
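decode_one_tokens now receives logits_warper and generation_config so each decode step can apply the configured warping and sampling policy rather than always taking the argmax. A small sketch of that selection step, mirroring the sampling block that appears later in this diff for the first token (the helper name select_next_token is illustrative, not part of the codebase):

import torch
import torch.nn as nn

def select_next_token(inputs, logits, logits_warper, generation_config):
    # Warp the raw logits (temperature / top-k / top-p) as configured.
    next_token_scores = logits_warper(inputs, logits[:, -1, :])
    if generation_config.do_sample:
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    # Greedy fallback when sampling is disabled.
    return torch.argmax(next_token_scores, dim=-1)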
@@ -152,24 +152,8 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
             next_token = torch.argmax(next_token_scores, dim=-1)
         return next_token
 
-    torch.cuda.set_device(torch_device)
-    with torch.no_grad():
-        stream = TextStreamer(tokenizer)
-        if mode != 'long_context':
-            past_key_values = StaticCache(
-                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
-            )
-        else:
-            past_key_values = None
-        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
-        generated_ids = torch.zeros(
-            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
-        )
-        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
-        if past_key_values != None:
-            past_key_values.cur_idx=cache_position
-        start_time = time.time()
-
+    # TODO: use CUDA Graph for chunk prefill, may get small improvement
+    def chunk_prefill(inputs, cache_position, past_key_values):
         if mode == "long_context":
             inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
         else:
@@ -181,6 +165,20 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         logits = model(
             inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
         )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
+
+        return logits
+
+    torch.cuda.set_device(torch_device)
+    with torch.no_grad():
+
+        stream = TextStreamer(tokenizer)
+        if mode != 'long_context':
+            past_key_values = StaticCache(
+                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype
+            )
+        else:
+            past_key_values = None
+
         generation_config, model_kwargs = model._prepare_generation_config(
             None, do_sample=True
             # change this to modify generate config
@@ -194,12 +192,29 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         logits_warper = (
             model._get_logits_warper(generation_config)
         )
 
+        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)
+        generated_ids = torch.zeros(
+            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
+        )
+        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
+        start_time = time.time()
+
+        chunk_start = 0
+        while chunk_start < seq_length:
+            chunk_end = min(chunk_start + chunk_prefill_size, seq_length)
+            if past_key_values != None:
+                past_key_values.cur_idx=cache_position[chunk_start:chunk_end]
+            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)
+            chunk_start += chunk_prefill_size
+
         next_token_scores = logits_warper(inputs, logits[:, -1, :])
         if generation_config.do_sample:
             probs = nn.functional.softmax(next_token_scores, dim=-1)
             next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
         else:
             next_token = torch.argmax(next_token_scores, dim=-1)
 
         first_token_time = time.time() - start_time
 
         if use_flashinfer_mla:
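With the default chunk_prefill_size of 16384, a prompt near the 139K-token figure from the commit message is prefilled in roughly nine passes of the while loop above, so per-pass activation memory stays bounded while the KV cache grows; that bound is what the 24G VRAM figure in the commit message relies on. A quick back-of-the-envelope check (plain Python, assuming "139K" means 139 * 1024 tokens):

import math

chunk_prefill_size = 16384          # default added to prefill_and_generate
prompt_tokens = 139 * 1024          # assumed reading of "139K context"

num_chunks = math.ceil(prompt_tokens / chunk_prefill_size)
print(num_chunks)  # -> 9: the chunked prefill loop runs nine times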
@@ -208,7 +223,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
         prefill_count = seq_length
         prefill_time = first_token_time
         if force_think:
-            print("<think>\n")
+            print("<think>")
         print(stream.put(next_token.item()), end="", flush=True)
         generated_ids[:, seq_length] = next_token
         tokens.append(int(next_token))
@@ -230,7 +245,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
                 warm_uped = True
                 cuda_graph_runner = CUDAGraphRunner()
                 cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
-            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
+            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)
             inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
             generated_ids[:, cache_position] = next_token.int()
             tokens.append(int(next_token))