fix-hopper-flashinfer

This commit is contained in:
Atream 2025-04-29 11:06:34 +08:00
parent 38333cf129
commit b0318fc01c
3 changed files with 6 additions and 3 deletions

View file

@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
-bsz_tensor=self.bsz_tensor_buf
+bsz_tensor=self.bsz_tensor_buf,
+backend = "fa2",
)
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):