use compile for gate, slight performance improvement

This commit is contained in:
Atream 2025-03-14 12:43:28 +00:00
parent 6c4ed59175
commit a889288fc1
9 changed files with 155 additions and 37 deletions

View file

@ -1,25 +1,14 @@
from typing import Optional
from typing import Any, Union from torch import nn
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
import torch import torch
import sys, os import torch.nn.functional as F
import os
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.operators.linear import KTransformersLinear
from ktransformers.util.custom_gguf import GGUFLoader from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import time
# class Base(BaseInjectedModule, ABC): # class Base(BaseInjectedModule, ABC):
@ -100,8 +89,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
gguf_loader: GGUFLoader, gguf_loader: GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module = None, orig_module: nn.Module = None,
prefill_device: str = "cuda",
generate_device: str = "cuda", generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs, **kwargs,
): ):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
@ -131,3 +120,133 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.weight = None self.weight = None
if self.e_score_correction_bias is not None: if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None self.e_score_correction_bias = None
# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
# This is used by the Deepseek-V2 and Deepseek-V3 model
#@torch.compile(dynamic=True)
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "sigmoid",
e_score_correction_bias: Optional[torch.Tensor] = None):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
if scoring_func == "softmax":
scores = torch.softmax(gating_output, dim=-1)
elif scoring_func == "sigmoid":
scores = gating_output.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(tmp_scores,
k=topk,
dim=-1,
sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_ids.to(torch.long), topk_weights.to(torch.float32)
class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda",
prefill_op: str| None = "KLinearMarlin",
use_quant: bool = False,
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
self.generate_op = generate_op
self.prefill_op = prefill_op
self.is_windows = os.name == 'nt'
self.use_quant = use_quant
if not self.is_windows and use_quant:
self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
gguf_loader, config, self.gate_linear, #orig_module
generate_device, generate_op, prefill_device, prefill_op)
else:
self.gate_linear = None
def forward(self, hidden_states) -> torch.Tensor:
if self.is_windows:
return self.orig_module.forward(hidden_states)
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
if self.use_quant:
logits = self.gate_linear.forward(logits)
else:
logits = F.linear(
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
)
return grouped_topk(hidden_states, logits,
self.top_k, self.norm_topk_prob,
self.n_group, self.topk_group)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
if not self.is_windows and self.use_quant:
self.gate_linear.load(self.orig_module.weight)
def unload(self):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None

View file

@ -477,7 +477,6 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
gguf_loader: GGUFLoader, gguf_loader: GGUFLoader,
config: PretrainedConfig, config: PretrainedConfig,
orig_module: nn.Module, orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda", generate_device: str = "cuda",
generate_op: str| None = "KLinearMarlin", generate_op: str| None = "KLinearMarlin",
prefill_device: str = "cuda", prefill_device: str = "cuda",

View file

@ -26,7 +26,7 @@
- match: - match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"

View file

@ -147,7 +147,7 @@
name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$" name: "^model\\.layers\\.([0-9]|1[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"
@ -157,7 +157,7 @@
name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$" name: "^model\\.layers\\.(1[5-9]|2[0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:1" generate_device: "cuda:1"
prefill_device: "cuda:1" prefill_device: "cuda:1"
@ -167,7 +167,7 @@
name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$" name: "^model\\.layers\\.(3[0-9]|4[0-4])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:2" generate_device: "cuda:2"
prefill_device: "cuda:2" prefill_device: "cuda:2"
@ -177,7 +177,7 @@
name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$" name: "^model\\.layers\\.(4[5-9]|5[0-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:3" generate_device: "cuda:3"
prefill_device: "cuda:3" prefill_device: "cuda:3"

View file

@ -278,7 +278,7 @@
name: "^model\\.layers\\.([0-7])\\.mlp\\.gate$" name: "^model\\.layers\\.([0-7])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"
@ -288,7 +288,7 @@
name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.gate$" name: "^model\\.layers\\.(8|9|1[0-5])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:1" generate_device: "cuda:1"
prefill_device: "cuda:1" prefill_device: "cuda:1"
@ -298,7 +298,7 @@
name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.gate$" name: "^model\\.layers\\.(1[6-9]|2[0-3])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:2" generate_device: "cuda:2"
prefill_device: "cuda:2" prefill_device: "cuda:2"
@ -308,7 +308,7 @@
name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.gate$" name: "^model\\.layers\\.(2[4-9]|3[0-1])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:3" generate_device: "cuda:3"
prefill_device: "cuda:3" prefill_device: "cuda:3"
@ -318,7 +318,7 @@
name: "^model\\.layers\\.(3[2-9])\\.mlp\\.gate$" name: "^model\\.layers\\.(3[2-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:4" generate_device: "cuda:4"
prefill_device: "cuda:4" prefill_device: "cuda:4"
@ -328,7 +328,7 @@
name: "^model\\.layers\\.(4[0-7])\\.mlp\\.gate$" name: "^model\\.layers\\.(4[0-7])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:5" generate_device: "cuda:5"
prefill_device: "cuda:5" prefill_device: "cuda:5"
@ -338,7 +338,7 @@
name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.gate$" name: "^model\\.layers\\.(4[8-9]|5[0-5])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:6" generate_device: "cuda:6"
prefill_device: "cuda:6" prefill_device: "cuda:6"
@ -348,7 +348,7 @@
name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.gate$" name: "^model\\.layers\\.(5[6-9]|60)\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:7" generate_device: "cuda:7"
prefill_device: "cuda:7" prefill_device: "cuda:7"

View file

@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace: replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"
@ -18,7 +18,7 @@
name: "^model\\.layers\\.([3456][0-9])\\." name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace: replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:1" generate_device: "cuda:1"
prefill_device: "cuda:1" prefill_device: "cuda:1"

View file

@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\." name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace: replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3 class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"

View file

@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"
@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
kwargs: kwargs:
generate_device: "cuda:1" generate_device: "cuda:1"
prefill_device: "cuda:1" prefill_device: "cuda:1"

View file

@ -38,7 +38,7 @@
- match: - match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace: replace:
class: ktransformers.operators.gate.KMoEGate class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs: kwargs:
generate_device: "cuda:0" generate_device: "cuda:0"
prefill_device: "cuda:0" prefill_device: "cuda:0"