mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 14:51:06 +00:00
Apply magikRUKKOLA's patch from issue #1417
This commit is contained in:
parent
890b0f1622
commit
8c8cb207aa
2 changed files with 25 additions and 11 deletions
|
@@ -42,18 +42,22 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
|
|||
|
||||
def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
|
||||
# Increase buffer sizes to be safe
|
||||
self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
|
||||
self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
|
||||
# Make sure this buffer is large enough
|
||||
self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
|
||||
self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
|
||||
self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
||||
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
|
||||
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
|
||||
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
|
||||
qo_indptr=self.qo_indptr_buf,
|
||||
kv_indptr=self.paged_kv_indptr_buf,
|
||||
kv_indices=self.paged_kv_indices_buf,
|
||||
kv_len_arr=self.paged_kv_len_buf,
|
||||
bsz_tensor=self.bsz_tensor_buf,
|
||||
backend = "fa2",
|
||||
)
|
||||
|
@@ -145,4 +149,4 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
|
|||
minibatch = batch.minibatch
|
||||
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
|
||||
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue