Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-10 06:14:58 +00:00
fix-hopper-flashinfer
commit b0318fc01c
parent 38333cf129

3 changed files with 6 additions and 3 deletions
@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf
+            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            backend = "fa2",
         )

     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):

@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
             kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
-            bsz_tensor=self.bsz_tensor_buf
+            bsz_tensor=self.bsz_tensor_buf,
+            backend = "fa2",
         )

     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):

@@ -100,7 +100,8 @@ class MLAWrapper():
             kv_indptr=self.kv_indptr_buf,
             kv_indices=self.kv_indices_buf,
             kv_len_arr=self.kv_len_arr_buf,
-            bsz_tensor=self.batch_size_tensor_buf
+            bsz_tensor=self.batch_size_tensor_buf,
+            backend = "fa2",
         )
         self.need_plan = True
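
All three hunks make the same change: rather than letting flashinfer auto-select an attention backend (the behavior the commit title suggests misbehaves on Hopper), each construction of the MLA wrapper now pins backend = "fa2" explicitly. Below is a minimal sketch of the resulting call, based only on the kwargs visible in this diff; the buffer sizes, dtypes, and max_batch_size are illustrative assumptions rather than values from the ktransformers sources, and bsz_tensor is omitted because it appears to be a kwarg of the flashinfer build that ktransformers targets rather than stock flashinfer.

# A minimal sketch, assuming stock flashinfer; sizes/dtypes are illustrative.
import torch
import flashinfer

max_batch_size = 32
device = "cuda:0"

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
qo_indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
kv_indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
kv_indices_buf = torch.empty(65536, dtype=torch.int32, device=device)
kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)

wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
    workspace_buffer,
    use_cuda_graph=True,
    qo_indptr=qo_indptr_buf,
    kv_indptr=kv_indptr_buf,
    kv_indices=kv_indices_buf,
    kv_len_arr=kv_len_arr_buf,
    backend="fa2",  # pin FA2; the default auto-selection picks a different kernel on Hopper (SM90)
)

Note that with use_cuda_graph=True, flashinfer expects these index buffers to be pre-allocated and passed at construction time so that graph capture can reuse fixed device pointers, which is why every call site in the diff hands in persistent *_buf tensors.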