Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-06 12:40:02 +00:00)
fix-gate-compile

Commit: 114995355b
Parent: e788248364
1 changed file with 3 additions and 4 deletions
@@ -125,7 +125,7 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
 # adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
 # This is used by the Deepseek-V2 and Deepseek-V3 model
-#@torch.compile(dynamic=True)
+@torch.compile(dynamic=True)
 def grouped_topk(hidden_states: torch.Tensor,
                  gating_output: torch.Tensor,
                  topk: int,
@@ -225,9 +225,8 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
             hidden_states.type(torch.float32), self.weight.type(torch.float32), None
         )
 
-        return grouped_topk(hidden_states, logits,
-                            self.top_k, self.norm_topk_prob,
-                            self.n_group, self.topk_group)
+        return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
+                            self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
 
     def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
         if device is None: device = self.device
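The first hunk restores the torch.compile decorator on the gate's top-k routine. A minimal sketch of what that decorator does, using a toy stand-in function (toy_gate_scores and its shapes are assumptions, not the ktransformers code): with dynamic=True, the compiled graph treats the token dimension as dynamic, so the gate is not retraced every time the number of routed tokens changes.

import torch

# Toy stand-in for the gate's scoring step, only to illustrate the decorator;
# the function name and shapes are assumed, not taken from ktransformers.
@torch.compile(dynamic=True)
def toy_gate_scores(hidden_states: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Linear scoring in float32, mirroring the float32 casts in the hunk above.
    return torch.nn.functional.linear(hidden_states.float(), weight.float())

scores = toy_gate_scores(torch.randn(4, 16), torch.randn(8, 16))  # first call compiles
scores = toy_gate_scores(torch.randn(9, 16), torch.randn(8, 16))  # new token count, reuses the dynamic graph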
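The second hunk passes two extra arguments when KMoEGateDeepSeekV3 routes tokens: a "sigmoid" scoring function and the model's e_score_correction_bias. Below is a rough, self-contained sketch of that DeepSeek-V3-style grouped top-k gating, loosely following the vllm routine the file credits; grouped_topk_sketch, its argument names, and the shapes are assumptions for illustration, not the ktransformers grouped_topk itself.

import torch

def grouped_topk_sketch(gating_output: torch.Tensor,    # [num_tokens, num_experts]
                        correction_bias: torch.Tensor,  # [num_experts]
                        topk: int,
                        num_expert_group: int,
                        topk_group: int,
                        renormalize: bool = True):
    # Sigmoid scoring (the new "sigmoid" argument) instead of DeepSeek-V2's softmax.
    scores = gating_output.sigmoid()
    # The correction bias steers expert *selection* only; the unbiased scores
    # remain the routing weights.
    scores_for_choice = scores + correction_bias

    num_tokens, num_experts = scores.shape
    # Score each expert group by the sum of its two best (biased) expert scores.
    group_scores = (scores_for_choice
                    .view(num_tokens, num_expert_group, -1)
                    .topk(2, dim=-1).values.sum(dim=-1))            # [tokens, groups]
    group_idx = group_scores.topk(topk_group, dim=-1).indices       # keep the best groups
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    score_mask = (group_mask.unsqueeze(-1)
                  .expand(num_tokens, num_expert_group,
                          num_experts // num_expert_group)
                  .reshape(num_tokens, num_experts))

    # Pick top-k experts inside the surviving groups, then read the weights
    # from the original (uncorrected) sigmoid scores.
    masked = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_idx = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_idx)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_idx

# Example: 4 tokens, 64 experts in 8 groups, route to 8 experts from the best 4 groups.
w, idx = grouped_topk_sketch(torch.randn(4, 64), torch.zeros(64),
                             topk=8, num_expert_group=8, topk_group=4)

The design point reflected in the diff is that the bias only influences which experts are chosen, while the returned routing weights still come from the plain sigmoid scores.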