mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 20:49:55 +00:00
fix rope; update moegate
This commit is contained in:
parent
f873558a89
commit
f748cd29f0
5 changed files with 54 additions and 21 deletions
|
@ -142,37 +142,42 @@ class DeepseekV3TopkRouter(nn.Module):
|
|||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
|
||||
scores = router_logits.sigmoid()
|
||||
topk_indices = self.get_topk_indices(scores)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if self.norm_topk_prob:
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor
|
||||
return topk_indices, topk_weights, router_logits
|
||||
|
||||
@torch.no_grad()
|
||||
def get_topk_indices(self, scores):
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
)
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
|
||||
group_mask = torch.zeros_like(group_scores)
|
||||
group_mask.scatter_(1, group_idx, 1)
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
) # [n, e]
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
|
||||
return topk_indices, topk_weights, router_logits
|
||||
)
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
|
||||
return topk_indices
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
|
|
|
@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
|
|||
LlamaLinearScalingRotaryEmbedding,
|
||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||
)
|
||||
from ktransformers.models.modeling_deepseek_v3 import (
|
||||
DeepseekV3RotaryEmbedding
|
||||
)
|
||||
from ktransformers.models.modeling_deepseek import (
|
||||
DeepseekV2YarnRotaryEmbedding,
|
||||
DeepseekV2RotaryEmbedding,
|
||||
|
@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
|
|||
self.orig_module.mscale_all_dim,
|
||||
)
|
||||
|
||||
class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
# TODO support perlayer prefill
|
||||
self.orig_module.__init__(
|
||||
self.config,
|
||||
device=self.generate_device
|
||||
)
|
||||
return
|
||||
|
||||
class DynamicNTKScalingRotaryEmbedding(
|
||||
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
|
||||
|
|
|
@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
|
|||
x = x.to(self.device)
|
||||
orig_shape = list(x.shape)
|
||||
orig_dtype = x.dtype
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
x = x.reshape(-1, orig_shape[-1])
|
||||
marlin_s = self.marlin_s.to(x.dtype)
|
||||
x = KTransformersOps.gptq_marlin_gemm(
|
||||
x,
|
||||
|
|
|
@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule):
|
|||
org_device = input_ids.device
|
||||
# TODO move to embed_tokens's device, not hard code to cpu
|
||||
input_ids = input_ids.to("cpu")
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.embed_tokens(input_ids).to(org_device)
|
||||
input_ids = input_ids.to(org_device)
|
||||
|
||||
if per_layer_prefill_flag:
|
||||
|
|
|
@ -8,17 +8,17 @@
|
|||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda:0"
|
||||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\."
|
||||
class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.YarnRotaryEmbedding
|
||||
class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda:1"
|
||||
prefill_device: "cuda:1"
|
||||
|
|
Loading…
Add table
Reference in a new issue