fix rope; update moegate

This commit is contained in:
Azure 2025-02-01 18:05:45 +00:00
parent f873558a89
commit f748cd29f0
5 changed files with 54 additions and 21 deletions

View file

@ -142,37 +142,42 @@ class DeepseekV3TopkRouter(nn.Module):
self.routed_scaling_factor = config.routed_scaling_factor
self.n_group = config.n_group
self.topk_group = config.topk_group
self.norm_topk_prob = config.norm_topk_prob
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
def forward(self, hidden_states):
batch_size, seq_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(-1, self.config.hidden_size)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
scores = router_logits.sigmoid()
topk_indices = self.get_topk_indices(scores)
topk_weights = scores.gather(1, topk_indices)
if self.norm_topk_prob:
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor
return topk_indices, topk_weights, router_logits
@torch.no_grad()
def get_topk_indices(self, scores):
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [n, n_group]
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
)
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)
score_mask = (
group_mask.unsqueeze(-1)
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
) # [n, e]
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_indices)
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
return topk_indices, topk_weights, router_logits
)
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
return topk_indices
class DeepseekV3MoE(nn.Module):

View file

@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
LlamaLinearScalingRotaryEmbedding,
LlamaDynamicNTKScalingRotaryEmbedding,
)
from ktransformers.models.modeling_deepseek_v3 import (
DeepseekV3RotaryEmbedding
)
from ktransformers.models.modeling_deepseek import (
DeepseekV2YarnRotaryEmbedding,
DeepseekV2RotaryEmbedding,
@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
self.orig_module.mscale_all_dim,
)
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
def load(self):
# TODO support perlayer prefill
self.orig_module.__init__(
self.config,
device=self.generate_device
)
return
class DynamicNTKScalingRotaryEmbedding(
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding

View file

@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, x.shape[-1])
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
x = KTransformersOps.gptq_marlin_gemm(
x,

View file

@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule):
org_device = input_ids.device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids = input_ids.to("cpu")
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens(input_ids).to(org_device)
input_ids = input_ids.to(org_device)
if per_layer_prefill_flag:

View file

@ -8,17 +8,17 @@
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"