fix-hopper-flashinfer

This commit is contained in:
Atream 2025-04-29 11:06:34 +08:00
parent 38333cf129
commit b0318fc01c
3 changed files with 6 additions and 3 deletions

View file

@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf
+            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            backend = "fa2",
         )
     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):

View file

@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
             kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
-            bsz_tensor=self.bsz_tensor_buf
+            bsz_tensor=self.bsz_tensor_buf,
+            backend = "fa2",
         )
     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):

View file

@@ -100,7 +100,8 @@ class MLAWrapper():
                 kv_indptr=self.kv_indptr_buf,
                 kv_indices=self.kv_indices_buf,
                 kv_len_arr=self.kv_len_arr_buf,
-                bsz_tensor=self.batch_size_tensor_buf
+                bsz_tensor=self.batch_size_tensor_buf,
+                backend = "fa2",
             )
             self.need_plan = True