mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-04 19:50:04 +00:00
Restore CPU offloading capability
This commit is contained in:
parent
05f6cede37
commit
f4d52d1f0c
3 changed files with 192 additions and 2 deletions
|
@ -595,7 +595,10 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
|
||||
if (os.name == 'nt'
|
||||
or get_compute_capability() < 8
|
||||
or hidden_states.device.type == 'cpu'
|
||||
or device_manager.gpu_vendor != GPUVendor.NVIDIA):
|
||||
return self.forward_windows(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
|
|
@ -650,7 +650,10 @@ class KDeepseekV2Model(BaseInjectedModule):
|
|||
if per_layer_prefill_flag:
|
||||
causal_mask = None
|
||||
else:
|
||||
if os.name == 'nt' or get_compute_capability()<8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
|
||||
if (os.name == 'nt'
|
||||
or get_compute_capability() < 8
|
||||
or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())
|
||||
or device_manager.gpu_vendor != GPUVendor.NVIDIA):
|
||||
# print("for Windows or GPU before ampere, use forward_windows")
|
||||
# only use mask in forward windows or can't flash attn
|
||||
causal_mask = self._update_causal_mask(
|
||||
|
|
|
@ -0,0 +1,184 @@
|
|||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === Rotary Embedding Replacement ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === Linear Layers Replacement (excluding self_attn) ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.(?!self_attn).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.(?!self_attn).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
generate_op: "KLinearCPUInfer"
|
||||
prefill_op: "KLinearTorch"
|
||||
out_device: "cpu"
|
||||
|
||||
# === MLP (MoE) Replacement ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV2MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === MLP Gate Replacement ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseek_v3.MoEGate
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === MLP Experts Replacement ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda:0"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda:0"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cpu"
|
||||
prefill_op: "KExpertsTorch"
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cpu"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
|
||||
# === Self-Attention Replacement ===
|
||||
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "^model\\.layers\\.([12][0-9])\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
# === Overall Model Replacement with Transfer Map ===
|
||||
|
||||
- match:
|
||||
name: "^model$"
|
||||
replace:
|
||||
class: "ktransformers.operators.models.KDeepseekV2Model"
|
||||
kwargs:
|
||||
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
|
||||
transfer_map:
|
||||
10: "cpu"
|
||||
|
||||
# === Default Catch-All for Other Modules ===#
|
||||
# GPU 0: layers 0–9
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9])\\."
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
|
||||
#lmm_head on GPU 0
|
||||
- match:
|
||||
name: "^lm_head"
|
||||
class: torch.nn.Linear
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# CPU: layers 10-29
|
||||
- match:
|
||||
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
Loading…
Add table
Reference in a new issue