mirror of https://github.com/kvcache-ai/ktransformers.git

commit f748cd29f0 (parent f873558a89)

fix rope; update moegate

5 changed files with 54 additions and 21 deletions
@@ -142,37 +142,42 @@ class DeepseekV3TopkRouter(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_group = config.n_group
         self.topk_group = config.topk_group
+        self.norm_topk_prob = config.norm_topk_prob
 
         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
         self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
 
     def forward(self, hidden_states):
-        batch_size, seq_length = hidden_states.shape[:-1]
         hidden_states = hidden_states.view(-1, self.config.hidden_size)
         router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
-
         scores = router_logits.sigmoid()
+        topk_indices = self.get_topk_indices(scores)
+        topk_weights = scores.gather(1, topk_indices)
+        if self.norm_topk_prob:
+            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+            topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor
+        return topk_indices, topk_weights, router_logits
+
+    @torch.no_grad()
+    def get_topk_indices(self, scores):
         scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
         group_scores = (
             scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
             .topk(2, dim=-1)[0]
             .sum(dim=-1)
-        ) # [n, n_group]
-        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
-        group_mask = torch.zeros_like(group_scores) # [n, n_group]
-        group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+        )
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
+        group_mask = torch.zeros_like(group_scores)
+        group_mask.scatter_(1, group_idx, 1)
         score_mask = (
             group_mask.unsqueeze(-1)
-            .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
+            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
             .reshape(-1, self.n_routed_experts)
-        ) # [n, e]
-        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
-        _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
-        topk_weights = scores.gather(1, topk_indices)
-        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
-        topk_weights /= denominator
-        topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
-        return topk_indices, topk_weights, router_logits
+        )
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
+        return topk_indices
 
 
 class DeepseekV3MoE(nn.Module):
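For orientation: the gate now picks experts with sigmoid-scored, group-limited top-k routing and only afterwards gathers and normalizes the expert weights, with the index selection factored out into get_topk_indices under @torch.no_grad(). Below is a minimal, self-contained sketch of that path on random data; all sizes are hypothetical and chosen only so that the expert count divides evenly into the groups.

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes for illustration only.
    tokens, hidden, n_experts, n_group, topk_group, top_k = 4, 16, 8, 4, 2, 3
    weight = torch.randn(n_experts, hidden)          # router weight
    bias = torch.zeros(n_experts)                    # stand-in for e_score_correction_bias
    hidden_states = torch.randn(tokens, hidden)

    scores = F.linear(hidden_states.float(), weight.float()).sigmoid()

    # Score each group by the sum of its two best experts, keep the topk_group best groups.
    scores_for_choice = scores + bias.unsqueeze(0)
    group_scores = scores_for_choice.view(-1, n_group, n_experts // n_group).topk(2, dim=-1)[0].sum(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, n_experts // n_group).reshape(-1, n_experts)

    # Pick experts only inside the surviving groups; the weights come from the raw sigmoid scores.
    topk_indices = torch.topk(scores_for_choice.masked_fill(~score_mask.bool(), 0.0), k=top_k, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_indices)
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
    print(topk_indices.shape, topk_weights.shape)    # torch.Size([4, 3]) torch.Size([4, 3])

Note that the new .expand(-1, ...) no longer depends on batch_size * seq_length, which is why the shape unpacking at the top of forward could be dropped.
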
@@ -12,6 +12,9 @@ from ktransformers.models.modeling_llama import (
     LlamaLinearScalingRotaryEmbedding,
     LlamaDynamicNTKScalingRotaryEmbedding,
 )
+from ktransformers.models.modeling_deepseek_v3 import (
+    DeepseekV3RotaryEmbedding
+)
 from ktransformers.models.modeling_deepseek import (
     DeepseekV2YarnRotaryEmbedding,
     DeepseekV2RotaryEmbedding,
@@ -134,6 +137,31 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
             self.orig_module.mscale_all_dim,
         )
 
+class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module,
+        # device: str = "cuda",
+        generate_device: str = "cuda",
+        prefill_device: str = "cuda",
+        **kwargs,
+    ):
+        BaseInjectedModule.__init__(
+            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+        )
+        self.generate_device = generate_device
+        self.prefill_device = prefill_device
+
+    def load(self):
+        # TODO support perlayer prefill
+        self.orig_module.__init__(
+            self.config,
+            device=self.generate_device
+        )
+        return
 
 class DynamicNTKScalingRotaryEmbedding(
     BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
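The new DeepSeekV3YarnRotaryEmbedding operator wraps the stock DeepseekV3RotaryEmbedding; its load() simply re-runs the wrapped module's __init__ with device=self.generate_device, presumably so the rotary buffers are created on the device used for generation. For context only, the sketch below shows the conventional half-rotation way the cos/sin produced by such a module are applied to query/key tensors; it is an illustration of the operation, not this repository's exact code path (the DeepSeek models apply it only to the rotary slice of each head).

    import torch

    def rotate_half(x):
        # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin):
        # cos/sin come from the rotary embedding module and broadcast over the head dimension.
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
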
@@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
         x = x.to(self.device)
         orig_shape = list(x.shape)
         orig_dtype = x.dtype
-        x = x.reshape(-1, x.shape[-1])
+        x = x.reshape(-1, orig_shape[-1])
         marlin_s = self.marlin_s.to(x.dtype)
         x = KTransformersOps.gptq_marlin_gemm(
             x,
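In KLinearMarlin.forward the flatten now reuses the already-captured orig_shape[-1] rather than re-reading x.shape[-1]; either way the input is collapsed to 2-D before the Marlin GEMM and the leading dimensions are restored afterwards. A minimal sketch of that flatten/restore pattern, with a plain matmul standing in for KTransformersOps.gptq_marlin_gemm and hypothetical shapes:

    import torch

    def linear_any_rank(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Collapse the leading dimensions so the kernel only ever sees [tokens, in_features].
        orig_shape = list(x.shape)
        x2d = x.reshape(-1, orig_shape[-1])
        out = x2d @ w.t()                    # stand-in for the quantized GEMM kernel
        orig_shape[-1] = w.shape[0]          # swap in the output feature dimension
        return out.reshape(orig_shape)

    y = linear_any_rank(torch.randn(2, 5, 8), torch.randn(16, 8))
    print(y.shape)   # torch.Size([2, 5, 16])
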
@@ -643,7 +643,7 @@ class KDeepseekV2Model(BaseInjectedModule):
             org_device = input_ids.device
             # TODO move to embed_tokens's device, not hard code to cpu
             input_ids = input_ids.to("cpu")
-            inputs_embeds = self.embed_tokens(input_ids)
+            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
             input_ids = input_ids.to(org_device)
 
         if per_layer_prefill_flag:
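The embedding lookup is still forced onto the CPU (where embed_tokens is kept), but the resulting inputs_embeds are now moved back to the device the input_ids arrived on, so the layers that follow receive tensors on the device they expect. A minimal sketch of the pattern with hypothetical sizes; it falls back to CPU when no GPU is present:

    import torch
    import torch.nn as nn

    embed_tokens = nn.Embedding(1000, 64)                  # kept on CPU to save GPU memory
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = torch.randint(0, 1000, (1, 8), device=device)
    org_device = input_ids.device
    inputs_embeds = embed_tokens(input_ids.to("cpu")).to(org_device)   # look up on CPU, hand back on org_device
    input_ids = input_ids.to(org_device)
    print(inputs_embeds.device)
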
@@ -8,17 +8,17 @@
 
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
-    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\."
-    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
+    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
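The rule file tracks the two code changes above: the match class is now the DeepseekV3RotaryEmbedding that actually appears in modeling_deepseek_v3, and the replacement is the new DeepSeekV3YarnRotaryEmbedding operator, with layers 0-29 kept on cuda:0 and layers 30-69 on cuda:1. A quick, illustrative check of the two layer-name regexes (the module suffix here is made up):

    import re

    pat_gpu0 = r"^model\.layers\.(0|[1-9]|[12][0-9])\."
    pat_gpu1 = r"^model\.layers\.([3456][0-9])\."

    assert re.match(pat_gpu0, "model.layers.7.self_attn.rotary_emb")
    assert re.match(pat_gpu1, "model.layers.42.self_attn.rotary_emb")
    assert not re.match(pat_gpu0, "model.layers.42.self_attn.rotary_emb")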