Implement multi-batch support for v2, v3, and r1 models with backend_type configured as ktransformers.

jiafei96 2025-07-09 09:09:47 +00:00
parent 890b0f1622
commit a6ab9e349c
6 changed files with 383 additions and 52 deletions

@@ -669,10 +669,12 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
+            from ktransformers.server.backend.interfaces.ktransformers import multi_batch_enabled
             if (os.name == 'nt'
                 or get_compute_capability() < 8
                 or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
-                or device_manager.gpu_vendor != GPUVendor.NVIDIA):
+                or device_manager.gpu_vendor != GPUVendor.NVIDIA
+                or multi_batch_enabled):
                 # print("for Windows or GPU before ampere, use forward_windows")
                 # only use mask in forward windows or can't flash attn
                 causal_mask = self._update_causal_mask(
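
The change gates the flash-attention fast path on a module-level flag imported at call time from the ktransformers backend interface. Below is a minimal runnable sketch of that pattern; multi_batch_enabled mirrors the name in the diff, while the setter, the simplified capability/vendor arguments, and the string placeholder for the mask are hypothetical stand-ins rather than the actual ktransformers API.

import os

# Sketch of the backend-interface module: a module-level flag that other
# components read to decide whether multi-batch mode is active.
multi_batch_enabled = False

def set_multi_batch_enabled(enabled: bool) -> None:
    # Hypothetical setter; the real backend may flip the flag elsewhere
    # when backend_type is configured as ktransformers.
    global multi_batch_enabled
    multi_batch_enabled = enabled

def choose_causal_mask(compute_capability: int, gpu_vendor: str):
    # Mirrors the condition in the diff: fall back to an explicit causal mask
    # whenever flash attention cannot be used, including multi-batch mode.
    if (os.name == "nt"
            or compute_capability < 8
            or gpu_vendor != "NVIDIA"
            or multi_batch_enabled):
        return "explicit_causal_mask"  # stands in for self._update_causal_mask(...)
    return None  # flash attention handles causality itself, no explicit mask needed

if __name__ == "__main__":
    print(choose_causal_mask(9, "NVIDIA"))  # None: fast path
    set_multi_batch_enabled(True)
    print(choose_causal_mask(9, "NVIDIA"))  # explicit mask: multi-batch disables the fast path

Deferring the import into the else branch, as the diff does, keeps the model code free of a hard dependency on the server backend and re-reads the module attribute on each forward pass, so the value reflects the backend configuration current at inference time.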