Restore CPU offloading capability

This commit is contained in:
Aubrey Li 2025-03-21 10:04:31 +08:00
parent 05f6cede37
commit f4d52d1f0c
3 changed files with 192 additions and 2 deletions

View file

@ -650,7 +650,10 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
if (os.name == 'nt'
or get_compute_capability() < 8
or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
or device_manager.gpu_vendor != GPUVendor.NVIDIA):
# print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(