fix-hopper-flashinfer

Atream 2025-04-29 11:06:34 +08:00
parent 38333cf129
commit b0318fc01c
3 changed files with 6 additions and 3 deletions


@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
  self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
      self.workspace_buffer, use_cuda_graph=use_cuda_graph,
      qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-     kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf
+     kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+     backend = "fa2",
  )
  def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
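
For context, a minimal sketch of how the MLA wrapper ends up being constructed with the FA2 backend after this change. The workspace buffer size and the use_cuda_graph value here are illustrative placeholders, not values from the repository; only the keyword arguments visible in the diff are taken from the actual code.

    import torch
    import flashinfer

    # Illustrative workspace buffer; the real buffer is owned by the model class.
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        workspace_buffer,
        use_cuda_graph=False,  # the model passes its own use_cuda_graph flag
        backend="fa2",         # the fix: explicitly select the FA2 backend (Hopper)
    )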