mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
update rope calculation; update modeling.py; update gate for moe
This commit is contained in:
parent
5a50b34627
commit
f873558a89
11 changed files with 402 additions and 412 deletions
|
@ -54,4 +54,4 @@ long_context:
|
|||
token_step:
|
||||
|
||||
local_chat:
|
||||
prompt_file: "./ktransformers/p.txt"
|
||||
prompt_file: ""
|
|
@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
|
|||
|
||||
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||
|
@ -78,7 +78,7 @@ def local_chat():
|
|||
else:
|
||||
content += line + "\n"
|
||||
if content == "":
|
||||
if config.prompt_file == None or config.prompt_file == "":
|
||||
if not config.prompt_file:
|
||||
content = "hi"
|
||||
else:
|
||||
content = open(config.prompt_file, "r").read()
|
||||
|
|
|
@ -17,16 +17,22 @@
|
|||
"""DeepSeekV3 model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 129280):
|
||||
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
||||
|
@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
Dimension of the MoE representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 61):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
|
||||
Number of nextn predict layers in the DeepSeekV3 Model.
|
||||
num_attention_heads (`int`, *optional*, defaults to 128):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 128):
|
||||
|
@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
n_shared_experts (`int`, *optional*, defaults to 1):
|
||||
Number of shared experts, None means dense model.
|
||||
Number of shared experts.
|
||||
n_routed_experts (`int`, *optional*, defaults to 256):
|
||||
Number of routed experts, None means dense model.
|
||||
ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
|
||||
Number of routed experts.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
|
||||
Scaling factor or routed experts.
|
||||
kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
|
||||
q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
|
||||
qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
|
||||
v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
|
||||
qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
|
||||
topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
|
||||
Topk method used in routed gate.
|
||||
kv_lora_rank (`int`, *optional*, defaults to 512):
|
||||
Rank of the LoRA matrices for key and value projections.
|
||||
q_lora_rank (`int`, *optional*, defaults to 1536):
|
||||
Rank of the LoRA matrices for query projections.
|
||||
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
||||
Dimension of the query/key heads that use rotary position embeddings.
|
||||
v_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of the value heads.
|
||||
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of the query/key heads that don't use rotary position embeddings.
|
||||
n_group (`int`, *optional*, defaults to 8):
|
||||
Number of groups for routed experts.
|
||||
topk_group (`int`, *optional*, defaults to 4):
|
||||
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
||||
Number of selected experts, None means dense model.
|
||||
moe_layer_freq (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 3):
|
||||
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||
\--k dense layers--/
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the weights of the routed experts.
|
||||
scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
|
||||
Method of computing expert weights.
|
||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||
Auxiliary loss weight coefficient.
|
||||
Whether to compute the auxiliary loss for each individual sample.
|
||||
seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||
|
@ -119,16 +120,25 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
```python
|
||||
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
||||
|
||||
>>> # Initializing a Deepseek-V3 style configuration
|
||||
>>> configuration = DeepseekV3Config()
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "deepseek_v3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `DeepseekV3Model`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.gate_proj": "colwise",
|
||||
"layers.*.up_proj": "colwise",
|
||||
"layers.*.down_proj": "rowwise",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -137,28 +147,22 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
intermediate_size=18432,
|
||||
moe_intermediate_size=2048,
|
||||
num_hidden_layers=61,
|
||||
num_nextn_predict_layers=1,
|
||||
num_attention_heads=128,
|
||||
num_key_value_heads=128,
|
||||
n_shared_experts=1,
|
||||
n_routed_experts=256,
|
||||
ep_size = 1,
|
||||
routed_scaling_factor=2.5,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
topk_method = 'noaux_tc',
|
||||
n_group=8,
|
||||
topk_group=4,
|
||||
num_experts_per_tok=8,
|
||||
moe_layer_freq = 1,
|
||||
first_k_dense_replace=3,
|
||||
norm_topk_prob=True,
|
||||
scoring_func = 'sigmoid',
|
||||
aux_loss_alpha=0.001,
|
||||
seq_aux = True,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
|
@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
mlp_bias=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
|
@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.ep_size = ep_size
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.topk_method = topk_method
|
||||
self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.head_dim = qk_rope_head_dim
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.scoring_func = scoring_func
|
||||
self.aux_loss_alpha = aux_loss_alpha
|
||||
self.seq_aux = seq_aux
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
|
|||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
|
@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
|
|||
# In-place ops prevent breaking the static address
|
||||
self.key_cache[layer_idx].zero_()
|
||||
self.value_cache[layer_idx].zero_()
|
||||
|
||||
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
|
||||
"""Returns the maximum shape of the cache."""
|
||||
return self.max_cache_len
|
|
@ -1,15 +1,13 @@
|
|||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
|
||||
# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_deepseekv3.py file directly. One of our CI enforces this.
|
||||
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import math
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
@ -30,7 +28,7 @@ from transformers.utils import (
|
|||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
from .configuration_deepseekv3 import DeepseekV3Config
|
||||
from .configuration_deepseek_v3 import DeepseekV3Config
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):
|
|||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.moe_intermediate_size
|
||||
# TODO rm hard coding
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias)
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
|
|||
return down_proj
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
class DeepseekV3TopkRouter(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.scoring_func = config.scoring_func
|
||||
self.seq_aux = config.seq_aux
|
||||
self.topk_method = config.topk_method
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
|
||||
# topk selection algorithm
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
self.gating_dim = config.hidden_size
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
||||
if self.topk_method == "noaux_tc":
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
import torch.nn.init as init
|
||||
|
||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
bsz, seq_len, h = hidden_states.shape
|
||||
### compute gating score
|
||||
hidden_states = hidden_states.view(-1, h)
|
||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
||||
if self.scoring_func == "sigmoid":
|
||||
scores = logits.sigmoid()
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
||||
batch_size, seq_length = hidden_states.shape[:-1]
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
|
||||
### select top-k experts
|
||||
if self.topk_method == "noaux_tc":
|
||||
# assert not self.training
|
||||
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
||||
scores = router_logits.sigmoid()
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
) # [n, n_group]
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(bsz * seq_len, -1)
|
||||
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
) # [n, e]
|
||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weight = scores.gather(1, topk_idx)
|
||||
else:
|
||||
raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
|
||||
|
||||
### norm gate to sum 1
|
||||
if self.top_k > 1 and self.norm_topk_prob:
|
||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weight = topk_weight / denominator
|
||||
topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
|
||||
|
||||
return topk_idx, topk_weight
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
|
||||
return topk_indices, topk_weights, router_logits
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
|
@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
|
||||
if hasattr(config, "ep_size") and config.ep_size > 1:
|
||||
assert config.ep_size == dist.get_world_size()
|
||||
self.ep_size = config.ep_size
|
||||
self.experts_per_rank = config.n_routed_experts // config.ep_size
|
||||
self.ep_rank = dist.get_rank()
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
(
|
||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
|
||||
else None
|
||||
)
|
||||
for i in range(config.n_routed_experts)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.ep_size = 1
|
||||
self.experts_per_rank = config.n_routed_experts
|
||||
self.ep_rank = 0
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3MLP(config)
|
||||
for i in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = MoEGate(config)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV3MLP(config=config)
|
||||
self.gate = DeepseekV3TopkRouter(config)
|
||||
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
identity = hidden_states
|
||||
residuals = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_idx, topk_weight = self.gate(hidden_states)
|
||||
topk_indices, topk_weights, router_logits = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
if not self.training:
|
||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y = y + self.shared_experts(identity)
|
||||
return y
|
||||
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||
return hidden_states, router_logits
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
sorted_tokens_shape = sorted_tokens.shape
|
||||
if self.ep_size > 1:
|
||||
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
||||
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
|
||||
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
|
||||
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
|
||||
gathered_tokens = sorted_tokens.new_empty(
|
||||
tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
|
||||
)
|
||||
input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(output_splits)),
|
||||
list(sorted_tokens.split(input_split_sizes)),
|
||||
)
|
||||
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
|
||||
dim=0
|
||||
)
|
||||
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
|
||||
s = 0
|
||||
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
|
||||
gatherd_idxs[s : s + k] = i % self.experts_per_rank
|
||||
s += k
|
||||
gatherd_idxs = gatherd_idxs.argsort()
|
||||
sorted_tokens = gathered_tokens[gatherd_idxs]
|
||||
tokens_per_expert = tokens_per_expert_post_gather
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
||||
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
||||
expert_mask = expert_mask.permute(2, 0, 1)
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
for expert_idx in range(len(self.experts)):
|
||||
expert = self.experts[expert_idx]
|
||||
mask = expert_mask[expert_idx]
|
||||
token_indices, weight_indices = torch.where(mask)
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
if self.ep_size > 1:
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[gatherd_idxs] = outs
|
||||
gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
|
||||
dist.all_to_all(
|
||||
list(gathered_tokens.split(input_split_sizes)),
|
||||
list(new_x.split(output_splits)),
|
||||
)
|
||||
outs = gathered_tokens
|
||||
if token_indices.numel() > 0:
|
||||
expert_weights = topk_weights[token_indices, weight_indices]
|
||||
expert_input = hidden_states[token_indices]
|
||||
expert_output = expert(expert_input)
|
||||
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
||||
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
||||
return final_hidden_states.type(hidden_states.dtype)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
@ -359,150 +292,94 @@ def eager_attention_forward(
|
|||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
|
||||
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
if layer_idx is None:
|
||||
logger.warning_once(
|
||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||
"when creating this class."
|
||||
)
|
||||
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.v_head_dim = config.v_head_dim
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
||||
|
||||
self.is_causal = True
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
|
||||
else:
|
||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||
|
||||
self.kv_a_proj_with_mqa = nn.Linear(
|
||||
self.hidden_size,
|
||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
config.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
|
||||
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
|
||||
self.kv_b_proj = nn.Linear(
|
||||
config.kv_lora_rank,
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.rotary_emb = DeepseekV3RotaryEmbedding(
|
||||
config=self.config,
|
||||
)
|
||||
self.scaling = self.q_head_dim ** (-0.5)
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
# TODO apply in DeepSeekV3Model to share accrose layers
|
||||
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs# : Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, self.num_heads, -1)
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
q = self.q_proj(hidden_states)
|
||||
else:
|
||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
|
||||
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
kv = (
|
||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
if self.layer_idx is None:
|
||||
raise ValueError(
|
||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||
"with a layer index."
|
||||
)
|
||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
|
||||
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
||||
k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||
|
||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
||||
cos, sin = position_embeddings
|
||||
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
|
||||
k_rot = k_rot.expand(-1, self.num_heads, -1, -1)
|
||||
|
||||
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
||||
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
||||
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
||||
query_states = torch.cat((q_pass, q_rot), dim=-1)
|
||||
key_states = torch.cat((k_pass, k_rot), dim=-1)
|
||||
|
||||
if self.q_head_dim != self.v_head_dim:
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
|
||||
|
||||
if past_key_value is not None:
|
||||
|
@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
|
|||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
pass
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
raise NotImplementedError(
|
||||
f"Attention implementation {self.config._attn_implementation} is not supported. "
|
||||
"Please use 'eager' or 'sdpa'."
|
||||
)
|
||||
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
|
@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
|
|||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
|
@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||
|
||||
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = (
|
||||
DeepseekV3MoE(config)
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0
|
||||
)
|
||||
else DeepseekV3MLP(config)
|
||||
)
|
||||
if layer_idx >= config.first_k_dense_replace:
|
||||
self.mlp = DeepseekV3MoE(config)
|
||||
else:
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
|
||||
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
|
@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
|
@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states, router_logits = hidden_states
|
||||
else:
|
||||
router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
if output_router_logits:
|
||||
outputs += (router_logits,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
DEEPSEEKV3_START_DOCSTRING = r"""
|
||||
DEEPSEEK_V3_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""
|
|||
|
||||
@add_start_docstrings(
|
||||
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
||||
DEEPSEEKV3_START_DOCSTRING,
|
||||
DEEPSEEK_V3_START_DOCSTRING,
|
||||
)
|
||||
class DeepseekV3PreTrainedModel(PreTrainedModel):
|
||||
config_class = DeepseekV3Config
|
||||
|
@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
|
|||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
DEEPSEEKV3_INPUTS_DOCSTRING = r"""
|
||||
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||
|
@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""
|
|||
|
||||
@add_start_docstrings(
|
||||
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
||||
DEEPSEEKV3_START_DOCSTRING,
|
||||
DEEPSEEK_V3_START_DOCSTRING,
|
||||
)
|
||||
class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||
"""
|
||||
|
@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||
config: DeepseekV3Config
|
||||
"""
|
||||
|
||||
def __init__(self, config: DeepseekV3Config):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
|
@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
|||
|
||||
return causal_mask
|
||||
|
||||
def load_hook(self, state_dict, prefix, *args):
|
||||
"""
|
||||
Weights have to be permuted for correct rope formulation. We can't do this in the weights
|
||||
as every other framework already uses the `Llama` original function (which is copyrighted btw).
|
||||
And I am not even sure it's better.... anyways end of my rant
|
||||
"""
|
||||
|
||||
def permute_for_rope(input_tensor):
|
||||
"""
|
||||
When you go from the complex ROPE formulation to sin and cos one, you need
|
||||
to permute the query and key weights (to avoid doing it on the fly)
|
||||
"""
|
||||
n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
|
||||
input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
|
||||
input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
|
||||
input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
|
||||
return input_tensor
|
||||
|
||||
def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
|
||||
weight = state_dict[key]
|
||||
weight = weight.view(num_heads, head_dim, -1)
|
||||
weight_rot = weight[:, -rope_dim:]
|
||||
weight_rot = permute_for_rope(weight_rot)
|
||||
weight[:, -rope_dim:] = weight_rot
|
||||
weight = weight.view(-1, weight.shape[-1])
|
||||
state_dict[key] = weight
|
||||
|
||||
for k in state_dict:
|
||||
if "q_b_proj." in k:
|
||||
permute_layer_for_rope(
|
||||
k,
|
||||
num_heads=self.config.num_attention_heads,
|
||||
head_dim=self.config.q_head_dim,
|
||||
rope_dim=self.config.qk_rope_head_dim,
|
||||
)
|
||||
if "kv_a_proj_with_mqa." in k:
|
||||
permute_layer_for_rope(
|
||||
k,
|
||||
num_heads=1,
|
||||
head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
|
||||
rope_dim=self.config.qk_rope_head_dim,
|
||||
)
|
||||
|
||||
|
||||
# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
|||
return self.model
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
|
@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
|||
```python
|
||||
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
|
||||
|
||||
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
|
||||
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
|||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||
each row of the batch).
|
||||
""",
|
||||
DEEPSEEKV3_START_DOCSTRING,
|
||||
DEEPSEEK_V3_START_DOCSTRING,
|
||||
)
|
||||
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
|
@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
|||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
||||
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
@ -1214,3 +1146,11 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
|||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DeepseekV3PreTrainedModel",
|
||||
"DeepseekV3Model",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV3ForSequenceClassification",
|
||||
]
|
|
@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
|
|||
from ktransformers.models.configuration_llama import LlamaConfig
|
||||
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
|
||||
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
|
||||
from typing import Optional, Tuple
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
|
@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
|
|||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||
|
||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
|
||||
q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||
|
|
|
@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
|
|||
|
||||
|
||||
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
|
||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
|
@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
|||
identity = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
topk_idx, topk_weight= self.gate(hidden_states)
|
||||
topk_idx, topk_weight, router_logits= self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
# only for generate phase
|
||||
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
||||
if self.config.n_shared_experts is not None:
|
||||
|
@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
|||
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
|
||||
y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
return y, router_logits
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
y_ = self.shared_experts(identity).squeeze(0)
|
||||
|
@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
|||
)
|
||||
if self.config.n_shared_experts is not None:
|
||||
y += y_
|
||||
return y
|
||||
return y, router_logits
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||
|
|
|
@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
|
|||
import ctypes
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_gguf import GGUFLoader
|
||||
from ktransformers.models.modeling_deepseekv3 import MoEGate
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
|
||||
from ktransformers.util.utils import InferenceState
|
||||
from ktransformers.server.config.config import Config
|
||||
from transformers.activations import ACT2FN
|
||||
|
@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
|
|||
else:
|
||||
raise ValueError("Invalid weight type")
|
||||
self.orig_module.weight = self.orig_module.weight.to(device)
|
||||
if self.topk_method == "noaux_tc":
|
||||
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
|
||||
|
||||
def unload(self):
|
||||
if self.weight is not None:
|
||||
self.weight = None
|
||||
if self.topk_method == "noaux_tc":
|
||||
if self.e_score_correction_bias is not None:
|
||||
self.e_score_correction_bias = None
|
||||
|
|
|
@ -47,7 +47,7 @@
|
|||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
|
@ -55,7 +55,7 @@
|
|||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
|
||||
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||
kwargs:
|
||||
|
@ -64,7 +64,7 @@
|
|||
|
||||
- match:
|
||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseekv3.MoEGate
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate
|
||||
kwargs:
|
||||
|
@ -72,7 +72,7 @@
|
|||
prefill_device: "cuda:0"
|
||||
- match:
|
||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
|
||||
class: ktransformers.models.modeling_deepseekv3.MoEGate
|
||||
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
|
||||
replace:
|
||||
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
|
||||
kwargs:
|
||||
|
|
|
@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
|
|||
self.total_context = self.model.get("total_context", 2**18)
|
||||
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
||||
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
|
||||
self.max_new_tokens = self.model.get("max_new_tokens", 500)
|
||||
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
|
||||
self.json_mode = self.model.get("json_mode", False)
|
||||
self.healing = self.model.get("healing", False)
|
||||
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
|
||||
|
|
|
@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
|
|||
elif config is not None:
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
|
@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
|
|||
elif config is not None:
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
factor = config.rope_scaling["factor"]
|
||||
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
# seq_len: default to max_position_embeddings, e.g. at init time
|
||||
seq_len = seq_len if seq_len is not None else max_position_embeddings
|
||||
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
||||
|
||||
# Compute the inverse frequencies
|
||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||
|
@ -185,15 +187,33 @@ def _compute_yarn_parameters(
|
|||
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = config.qk_rope_head_dim
|
||||
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
factor = config.rope_scaling["factor"]
|
||||
attention_factor = config.rope_scaling.get("attention_factor")
|
||||
mscale = config.rope_scaling.get("mscale")
|
||||
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
|
||||
|
||||
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
|
||||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||
if "original_max_position_embeddings" in config.rope_scaling:
|
||||
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
|
||||
factor = config.max_position_embeddings / original_max_position_embeddings
|
||||
else:
|
||||
original_max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
def get_mscale(scale, mscale=1):
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
# Sets the attention factor as suggested in the paper
|
||||
attention_factor = config.rope_scaling.get("attention_factor")
|
||||
if attention_factor is None:
|
||||
attention_factor = 0.1 * math.log(factor) + 1.0
|
||||
if mscale and mscale_all_dim:
|
||||
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
|
||||
else:
|
||||
attention_factor = get_mscale(factor)
|
||||
|
||||
# Optional config options
|
||||
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
||||
|
@ -211,7 +231,7 @@ def _compute_yarn_parameters(
|
|||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_mask(min, max, dim):
|
||||
def linear_ramp_factor(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
|
||||
|
@ -219,16 +239,20 @@ def _compute_yarn_parameters(
|
|||
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
|
||||
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
||||
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||
|
||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
|
||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
||||
|
||||
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||
inv_freq = (
|
||||
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||
)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
|
||||
|
@ -244,7 +268,7 @@ def _compute_longrope_parameters(
|
|||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
seq_len (`int`, *optional*):
|
||||
The current sequence length. Unused for this type of RoPE.
|
||||
The current sequence length.
|
||||
rope_kwargs (`Dict`, *optional*):
|
||||
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
||||
Returns:
|
||||
|
@ -261,7 +285,8 @@ def _compute_longrope_parameters(
|
|||
|
||||
base = config.rope_theta
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
long_factor = config.rope_scaling["long_factor"]
|
||||
short_factor = config.rope_scaling["short_factor"]
|
||||
factor = config.rope_scaling.get("factor")
|
||||
|
@ -271,22 +296,20 @@ def _compute_longrope_parameters(
|
|||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||
if hasattr(config, "original_max_position_embeddings"):
|
||||
max_position_embeddings = config.original_max_position_embeddings
|
||||
expanded_max_position_embeddings = config.max_position_embeddings
|
||||
factor = expanded_max_position_embeddings / max_position_embeddings
|
||||
original_max_position_embeddings = config.original_max_position_embeddings
|
||||
factor = config.max_position_embeddings / config.original_max_position_embeddings
|
||||
else:
|
||||
max_position_embeddings = config.max_position_embeddings
|
||||
expanded_max_position_embeddings = max_position_embeddings * factor
|
||||
original_max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
# Sets the attention factor as suggested in the paper
|
||||
if attention_factor is None:
|
||||
if factor <= 1.0:
|
||||
attention_factor = 1.0
|
||||
else:
|
||||
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
||||
attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
|
||||
|
||||
# Compute the inverse frequencies -- scaled based on the target sequence length
|
||||
if expanded_max_position_embeddings > max_position_embeddings:
|
||||
if seq_len and seq_len > original_max_position_embeddings:
|
||||
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
||||
else:
|
||||
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
||||
|
@ -325,19 +348,18 @@ def _compute_llama3_parameters(
|
|||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
new_freqs = []
|
||||
for freq in inv_freq:
|
||||
wavelen = 2 * math.pi / freq
|
||||
if wavelen < high_freq_wavelen:
|
||||
new_freqs.append(freq)
|
||||
elif wavelen > low_freq_wavelen:
|
||||
new_freqs.append(freq / factor)
|
||||
else:
|
||||
assert low_freq_wavelen != high_freq_wavelen
|
||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
|
||||
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
wavelen = 2 * math.pi / inv_freq
|
||||
# wavelen < high_freq_wavelen: do nothing
|
||||
# wavelen > low_freq_wavelen: divide by factor
|
||||
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
||||
# otherwise: interpolate between the two, using a smooth factor
|
||||
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
||||
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
||||
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||
|
||||
return inv_freq_llama, attention_factor
|
||||
|
||||
|
||||
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
||||
|
@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
|
|||
}
|
||||
|
||||
|
||||
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
||||
def _check_received_keys(
|
||||
rope_type: str,
|
||||
received_keys: set,
|
||||
required_keys: set,
|
||||
optional_keys: Optional[set] = None,
|
||||
ignore_keys: Optional[set] = None,
|
||||
):
|
||||
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
||||
# BC: "rope_type" was originally "type" -- let's gracefully handle it
|
||||
if "rope_type" not in received_keys and "type" in received_keys:
|
||||
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
||||
if "type" in received_keys:
|
||||
received_keys -= {"type"}
|
||||
received_keys.add("rope_type")
|
||||
required_keys.add("rope_type")
|
||||
|
||||
# Some models need to store model-specific keys, and we don't want to throw warning at them
|
||||
if ignore_keys is not None:
|
||||
received_keys -= ignore_keys
|
||||
|
||||
missing_keys = required_keys - received_keys
|
||||
if missing_keys:
|
||||
|
@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
|
|||
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||
|
||||
|
||||
def _validate_default_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
|
||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
||||
optional_keys = {
|
||||
"attention_factor",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"original_max_position_embeddings",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
|
@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
|||
)
|
||||
|
||||
|
||||
def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||
|
||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
dim = int(head_dim * partial_rotary_factor)
|
||||
|
||||
short_factor = rope_scaling.get("short_factor")
|
||||
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
||||
|
@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
|||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
attention_factor = rope_scaling.get("attention_factor")
|
||||
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
||||
if attention_factor is not None:
|
||||
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
||||
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
|
@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
|
|||
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
||||
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||
if high_freq_factor < low_freq_factor:
|
||||
if high_freq_factor <= low_freq_factor:
|
||||
logger.warning(
|
||||
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
||||
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
||||
|
@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
|
|||
}
|
||||
|
||||
|
||||
def rope_config_validation(config: PretrainedConfig):
|
||||
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||
"""
|
||||
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
||||
"""
|
||||
|
@ -544,7 +585,7 @@ def rope_config_validation(config: PretrainedConfig):
|
|||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
||||
if validation_fn is not None:
|
||||
validation_fn(config)
|
||||
validation_fn(config, ignore_keys=ignore_keys)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue