support chunk prefill, support 139K context for 24G VRAM

This commit is contained in:
Atream 2025-03-01 11:28:25 +00:00
parent 494469d4c5
commit f35e8d41d8
10 changed files with 227 additions and 83 deletions

View file

@ -62,6 +62,7 @@ def local_chat(
prompt_file : str | None = None,
mode: str = "normal",
force_think: bool = False,
chunk_prefill_size: int = 8192
):
torch.set_grad_enabled(False)
@ -170,12 +171,12 @@ def local_chat(
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled and get_compute_capability() >= 8:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
)
else:
generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,
)