Apply magikRUKKOLA's patch from issue #1417

Jesse CreateThis 2025-07-06 19:45:06 +00:00
parent 890b0f1622
commit 8c8cb207aa
2 changed files with 25 additions and 11 deletions

@@ -42,18 +42,22 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
     def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
         self.use_cuda_graph = use_cuda_graph
-        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+        # Increase buffer sizes to be safe
+        self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
         self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
         self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
-        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
+        # Make sure this buffer is large enough
+        self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
         self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
         self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
-            qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            qo_indptr=self.qo_indptr_buf,
+            kv_indptr=self.paged_kv_indptr_buf,
+            kv_indices=self.paged_kv_indices_buf,
+            kv_len_arr=self.paged_kv_len_buf,
             bsz_tensor=self.bsz_tensor_buf,
             backend = "fa2",
         )
@@ -145,4 +149,4 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         minibatch = batch.minibatch
         self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                           minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
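
Not part of the commit, but for reviewers who want to exercise the new sizing in isolation: the sketch below mirrors the allocation that the patched init_wrapper ends up with (256 MiB workspace, 2x margin on the paged-KV indices) as a standalone helper. The helper name make_mla_wrapper is illustrative only, and the constructor kwargs simply copy the call in the hunk above, so they assume the flashinfer build this repository uses.

import torch
import flashinfer

def make_mla_wrapper(device, max_batch_size, max_pages, use_cuda_graph=True):
    # Sizes mirror the patched init_wrapper: 256 MiB workspace and a
    # 2x safety margin on the paged-KV indices buffer.
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device)
    qo_indptr_buf = torch.empty((max_batch_size + 2,), dtype=torch.int32, device=device)
    paged_kv_indptr_buf = torch.empty((max_batch_size + 2,), dtype=torch.int32, device=device)
    paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
    paged_kv_len_buf = torch.empty((max_batch_size + 1,), dtype=torch.int32, device=device)
    bsz_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)
    return flashinfer.mla.BatchMLAPagedAttentionWrapper(
        workspace_buffer, use_cuda_graph=use_cuda_graph,
        qo_indptr=qo_indptr_buf,
        kv_indptr=paged_kv_indptr_buf,
        kv_indices=paged_kv_indices_buf,
        kv_len_arr=paged_kv_len_buf,
        bsz_tensor=bsz_tensor_buf,
        backend="fa2",
    )

The extra headroom is cheap: doubling paged_kv_indices_buf costs one additional int32 (4 bytes) per page, so even a very large max_pages only adds a few megabytes on top of the fixed 256 MiB workspace.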