diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 5e7391f..35c8093 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -435,7 +435,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                                   q_nope.dtype, compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
-
 
             """
             k = (
                 torch.cat([compressed_kv, k_pe], dim=-1)
@@ -465,7 +464,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
 
             attn_output = self.o_proj(attn_output)
-
+
             return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
diff --git a/ktransformers/operators/flashinfer_wrapper.py b/ktransformers/operators/flashinfer_wrapper.py
index 864b33e..f8ea3ce 100644
--- a/ktransformers/operators/flashinfer_wrapper.py
+++ b/ktransformers/operators/flashinfer_wrapper.py
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
-
+
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -189,7 +189,14 @@ class MLAWrapperSingleton():
     @classmethod
     def reset_buffer(cls):
         for device, wrapper in cls.wrappers.items():
-            wrapper.qo_indptr_buf[1] = 1
+            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.
+
+    @classmethod
+    def update_buffer(cls, max_pages):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
+            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
+            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 
 
 if __name__ == "__main__":
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
index 03c85a0..ea75b30 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
@@ -293,6 +293,7 @@
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      absorb_for_prefill: False
 
 # GPU 1: layers 15–29
 - match:
@@ -302,6 +303,7 @@
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
+      absorb_for_prefill: False
 
 # GPU 2: layers 30–44
 - match:
@@ -311,6 +313,7 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+      absorb_for_prefill: False
 
 # GPU 3: layers 45–60
 - match:
@@ -320,6 +323,7 @@
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
+      absorb_for_prefill: False
 
 
 # === Overall Model Replacement with Transfer Map ===
diff --git a/ktransformers/util/utils.py b/ktransformers/util/utils.py
index 64b9131..3f5ad8e 100644
--- a/ktransformers/util/utils.py
+++ b/ktransformers/util/utils.py
@@ -177,6 +177,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
     if use_flashinfer_mla:
+        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
        MLAWrapperSingleton.need_plan_all()
 
     logits = model(
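
# Usage note (not part of the patch): a minimal sketch of the call order this
# change introduces before planning FlashInfer MLA. It assumes only what the
# diff shows (MLAWrapperSingleton.update_buffer / need_plan_all, a paged cache
# exposing max_pages, and max_batch_size == 1); the wrapper function, its name,
# and the model forward arguments are illustrative, not the repo's API.
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton

def plan_flashinfer_mla(model, inputs_embeds, past_key_values, use_flashinfer_mla=True):
    if use_flashinfer_mla:
        # Resize each device's kv_indptr / kv_indices buffers so the next plan()
        # can address every page the cache may now hold (single-batch case,
        # matching the "assert max_batch_size=1" comments in the patch).
        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
        # Flag all per-device wrappers so the next forward pass re-plans with
        # the updated buffers.
        MLAWrapperSingleton.need_plan_all()
    # Forward call shown only to indicate where prefill_and_generate proceeds;
    # the exact arguments follow the surrounding ktransformers code.
    return model(inputs_embeds=inputs_embeds, past_key_values=past_key_values)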