Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-10 06:14:58 +00:00)
fix-update-flashinfer_wrapper_local_chat
commit 477ac28a9c (parent 5474be5299)
4 changed files with 15 additions and 4 deletions
@@ -435,7 +435,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                                   q_nope.dtype,
                                   compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
 
             """
             k = (
                 torch.cat([compressed_kv, k_pe], dim=-1)
@@ -189,7 +189,14 @@ class MLAWrapperSingleton():
     @classmethod
     def reset_buffer(cls):
         for device, wrapper in cls.wrappers.items():
-            wrapper.qo_indptr_buf[1] = 1
+            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.
 
+    @classmethod
+    def update_buffer(cls, max_pages):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
+            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
+            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 
 if __name__ == "__main__":
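The new update_buffer classmethod grows the per-device page-index buffers in place. Below is a minimal sketch of the pattern with a plain stand-in object instead of the real flashinfer-backed wrapper; the SimpleNamespace, the register helper, and the fixed max_batch_size of 1 are assumptions for illustration, not the project's actual class.

import torch
from types import SimpleNamespace

class ToyMLAWrapperSingleton:
    # device -> wrapper, mirroring cls.wrappers in the diff above
    wrappers = {}

    @classmethod
    def register(cls, device, max_pages):
        # Stand-in for the real per-device MLA wrapper; _kv_indices_buf mimics the
        # underlying flashinfer wrapper attribute referenced in the diff.
        inner = SimpleNamespace(_kv_indices_buf=None)
        cls.wrappers[device] = SimpleNamespace(
            qo_indptr_buf=torch.tensor([0, 1], dtype=torch.int32, device=device),
            kv_indptr_buf=torch.tensor([0, max_pages], dtype=torch.int32, device=device),
            kv_indices_buf=torch.arange(0, max_pages, dtype=torch.int32, device=device),
            wrapper=inner,
        )

    @classmethod
    def update_buffer(cls, max_pages):
        # Same shape of change as the diff: bump kv_indptr[1] and rebuild kv_indices
        # so they cover the enlarged paged KV cache (assumes max_batch_size == 1).
        for device, wrapper in cls.wrappers.items():
            wrapper.kv_indptr_buf[1] = max_pages
            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf

ToyMLAWrapperSingleton.register("cpu", max_pages=4)
ToyMLAWrapperSingleton.update_buffer(max_pages=8)  # cache grew from 4 to 8 pages
print(ToyMLAWrapperSingleton.wrappers["cpu"].kv_indices_buf)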
@@ -293,6 +293,7 @@
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      absorb_for_prefill: False
 
 # GPU 1: layers 15–29
 - match:
@@ -302,6 +303,7 @@
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
+      absorb_for_prefill: False
 
 # GPU 2: layers 30–44
 - match:
@@ -311,6 +313,7 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+      absorb_for_prefill: False
 
 # GPU 3: layers 45–60
 - match:
@@ -320,6 +323,7 @@
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
+      absorb_for_prefill: False
 
 # === Overall Model Replacement with Transfer Map ===
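All four per-GPU rule blocks gain the same absorb_for_prefill: False entry, passed to the injected attention module alongside generate_device and prefill_device. As a hedged illustration of how kwargs from such a rule entry can reach a module constructor, the loader, the ToyAttention class, and the exact rule contents below are assumptions, not the project's real optimize-rule loader.

import yaml

# Illustrative rule snippet only; the match pattern and class path are examples.
rule_yaml = r"""
- match:
    name: '^model\.layers\.([0-9]|1[0-4])\.self_attn$'
  replace:
    class: ktransformers.operators.attention.KDeepseekV2Attention
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False
"""

class ToyAttention:
    def __init__(self, generate_device, prefill_device, absorb_for_prefill=False, **extra):
        # absorb_for_prefill presumably selects whether the matrix-absorbed MLA path
        # is also used during prefill; False keeps the non-absorbed prefill path.
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.absorb_for_prefill = absorb_for_prefill

rules = yaml.safe_load(rule_yaml)
module = ToyAttention(**rules[0]["replace"]["kwargs"])
print(module.prefill_device, module.absorb_for_prefill)  # cuda:0 False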
@@ -177,6 +177,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
     if use_flashinfer_mla:
+        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
         MLAWrapperSingleton.need_plan_all()
 
     logits = model(
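With this change the wrapper buffers are resized to the cache's current max_pages before need_plan_all() marks every device wrapper for re-planning, so the subsequent model(...) call plans against an up-to-date page count. A rough sketch of that ordering with stand-in objects follows; the page size, ToyPagedCache, and StubSingleton are illustrative, not the real KV cache or MLAWrapperSingleton.

import math
from dataclasses import dataclass

@dataclass
class ToyPagedCache:
    seq_len: int
    page_size: int = 64  # illustrative page size, not necessarily ktransformers' default

    @property
    def max_pages(self) -> int:
        # Pages needed to hold the longest cached sequence.
        return math.ceil(self.seq_len / self.page_size)

class StubSingleton:
    # Records calls in order, standing in for MLAWrapperSingleton.
    calls = []

    @classmethod
    def update_buffer(cls, max_pages):
        cls.calls.append(("update_buffer", max_pages))

    @classmethod
    def need_plan_all(cls):
        cls.calls.append(("need_plan_all",))

def before_forward(cache, singleton, use_flashinfer_mla=True):
    # Mirrors the ordering introduced by the diff: resize buffers first,
    # then request a fresh plan for every device wrapper.
    if use_flashinfer_mla:
        singleton.update_buffer(cache.max_pages)
        singleton.need_plan_all()

before_forward(ToyPagedCache(seq_len=200), StubSingleton)
print(StubSingleton.calls)  # [('update_buffer', 4), ('need_plan_all',)]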