update rope calculation; update modeling.py; update gate for moe

Author: Azure
Date: 2025-02-01 07:32:21 +00:00
parent 5a50b34627
commit f873558a89
11 changed files with 402 additions and 412 deletions

View file

@@ -54,4 +54,4 @@ long_context:
   token_step:
 local_chat:
-  prompt_file: "./ktransformers/p.txt"
+  prompt_file: ""

View file

@@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM

@@ -78,7 +78,7 @@ def local_chat():
            else:
                content += line + "\n"
        if content == "":
-            if config.prompt_file == None or config.prompt_file == "":
+            if not config.prompt_file:
                content = "hi"
            else:
                content = open(config.prompt_file, "r").read()

View file

@@ -14,19 +14,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" DeepSeekV3 model configuration """
+"""DeepSeekV3 model configuration"""
from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation

DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class DeepseekV3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V3.
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the

@@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
-        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
-            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
@@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
-            Number of shared experts, None means dense model.
+            Number of shared experts.
        n_routed_experts (`int`, *optional*, defaults to 256):
-            Number of routed experts, None means dense model.
-        ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
+            Number of routed experts.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor or routed experts.
-        kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
-        q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
-        qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
-        v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
-        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
-            Topk method used in routed gate.
+        kv_lora_rank (`int`, *optional*, defaults to 512):
+            Rank of the LoRA matrices for key and value projections.
+        q_lora_rank (`int`, *optional*, defaults to 1536):
+            Rank of the LoRA matrices for query projections.
+        qk_rope_head_dim (`int`, *optional*, defaults to 64):
+            Dimension of the query/key heads that use rotary position embeddings.
+        v_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the value heads.
+        qk_nope_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of the query/key heads that don't use rotary position embeddings.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
-        moe_layer_freq (`int`, *optional*, defaults to 1):
-            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
-        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
-            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
            Whether to compute the auxiliary loss for each individual sample.
-        seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
@@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import DeepseekV3Model, DeepseekV3Config
    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]
+    # Default tensor parallel plan for base model `DeepseekV3Model`
+    base_model_tp_plan = {
+        "layers.*.gate_proj": "colwise",
+        "layers.*.up_proj": "colwise",
+        "layers.*.down_proj": "rowwise",
+    }

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
-        moe_intermediate_size = 2048,
+        moe_intermediate_size=2048,
        num_hidden_layers=61,
-        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
-        n_shared_experts = 1,
-        n_routed_experts = 256,
-        ep_size = 1,
-        routed_scaling_factor = 2.5,
-        kv_lora_rank = 512,
-        q_lora_rank = 1536,
-        qk_rope_head_dim = 64,
-        v_head_dim = 128,
-        qk_nope_head_dim = 128,
-        topk_method = 'noaux_tc',
-        n_group = 8,
-        topk_group = 4,
-        num_experts_per_tok = 8,
-        moe_layer_freq = 1,
-        first_k_dense_replace = 3,
-        norm_topk_prob = True,
-        scoring_func = 'sigmoid',
-        aux_loss_alpha = 0.001,
-        seq_aux = True,
+        n_shared_experts=1,
+        n_routed_experts=256,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        aux_loss_alpha=0.001,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
@@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
-        mlp_bias=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size

@@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
-        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
-        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
-        self.topk_method = topk_method
+        self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.head_dim = qk_rope_head_dim
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
-        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
-        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
-        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

@@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
-        self.mlp_bias = mlp_bias
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, copy it it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
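The constructor now normalizes the legacy `rope_scaling["type"]` key and runs transformers' `rope_config_validation` before calling `super().__init__`. A minimal sketch of that backward-compatibility shim on a plain dictionary (the values below are illustrative, not taken from this commit):

```python
# Illustrative rope_scaling dict; real DeepSeek-V3 checkpoints ship their own values.
rope_scaling = {"type": "yarn", "factor": 40.0, "mscale_all_dim": 1.0}

# Same shim as in the config above: newer transformers code reads "rope_type",
# while older configs only carry "type", so the value is copied across.
if rope_scaling is not None and "type" in rope_scaling:
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling["rope_type"])  # yarn
```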

View file

@@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
        # In-place ops prevent breaking the static address
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()
+
+    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
+        """Returns the maximum shape of the cache."""
+        return self.max_cache_len

View file

@@ -1,15 +1,13 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
+# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
-# modular_deepseekv3.py file directly. One of our CI enforces this.
+# modular_deepseek_v3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from typing import Callable, List, Optional, Tuple, Union

-import numpy as np
import torch
-import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

@@ -30,7 +28,7 @@ from transformers.utils import (
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg

-from .configuration_deepseekv3 import DeepseekV3Config
+from .configuration_deepseek_v3 import DeepseekV3Config


logger = logging.get_logger(__name__)

@@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):
class DeepseekV3MLP(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
-        self.hidden_size = config.hidden_size
-        self.intermediate_size = config.moe_intermediate_size
-        # TODO rm hard coding
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias)
+        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
+        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
@@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
        return down_proj


-class MoEGate(nn.Module):
+class DeepseekV3TopkRouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
-        self.scoring_func = config.scoring_func
-        self.seq_aux = config.seq_aux
-        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group
-        # topk selection algorithm
-        self.norm_topk_prob = config.norm_topk_prob
-        self.gating_dim = config.hidden_size
-        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
-        if self.topk_method == "noaux_tc":
-            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        import torch.nn.init as init
-
-        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
+        self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))

    def forward(self, hidden_states):
-        bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
-        hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
-        if self.scoring_func == "sigmoid":
-            scores = logits.sigmoid()
-        else:
-            raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
-
-        ### select top-k experts
-        if self.topk_method == "noaux_tc":
-            # assert not self.training
-            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
-            group_scores = (
-                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
-            )  # [n, n_group]
-            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
-            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-            score_mask = (
-                group_mask.unsqueeze(-1)
-                .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
-                .reshape(bsz * seq_len, -1)
-            )  # [n, e]
-            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
-            topk_weight = scores.gather(1, topk_idx)
-        else:
-            raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
-
-        ### norm gate to sum 1
-        if self.top_k > 1 and self.norm_topk_prob:
-            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
-            topk_weight = topk_weight / denominator
-        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor
-
-        return topk_idx, topk_weight
+        batch_size, seq_length = hidden_states.shape[:-1]
+        hidden_states = hidden_states.view(-1, self.config.hidden_size)
+        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+
+        scores = router_logits.sigmoid()
+        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
+        group_scores = (
+            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
+            .topk(2, dim=-1)[0]
+            .sum(dim=-1)
+        )  # [n, n_group]
+        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
+        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+        score_mask = (
+            group_mask.unsqueeze(-1)
+            .expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
+            .reshape(-1, self.n_routed_experts)
+        )  # [n, e]
+        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+        _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
+        topk_weights = scores.gather(1, topk_indices)
+        denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
+        topk_weights /= denominator
+        topk_weights = topk_weights * self.routed_scaling_factor  # must multiply the scaling factor
+        return topk_indices, topk_weights, router_logits
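The rewritten router drops the configurable `scoring_func`/`topk_method` branches and hard-wires sigmoid scoring with bias-corrected, group-limited top-k selection, returning the raw `router_logits` as a third value. A self-contained sketch of that selection path on random tensors (toy sizes, not the module itself):

```python
import torch
import torch.nn.functional as F

# Toy sizes; the config defaults in this commit use 256 experts, 8 groups, top-8 routing.
n_routed_experts, n_group, topk_group, top_k = 16, 4, 2, 4
hidden_size, num_tokens, routed_scaling_factor = 32, 3, 2.5

torch.manual_seed(0)
weight = torch.randn(n_routed_experts, hidden_size)
e_score_correction_bias = torch.randn(n_routed_experts)
hidden_states = torch.randn(num_tokens, hidden_size)

router_logits = F.linear(hidden_states.float(), weight.float())
scores = router_logits.sigmoid()
# The correction bias only influences which experts are picked, not the returned weights.
scores_for_choice = scores + e_score_correction_bias.unsqueeze(0)

# Rank groups by the sum of their two best experts; keep only the best topk_group groups.
group_scores = scores_for_choice.view(num_tokens, n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, -1, n_routed_experts // n_group).reshape(num_tokens, -1)

# Pick top_k experts inside the surviving groups, then renormalize and scale the weights.
masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
_, topk_indices = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights * routed_scaling_factor
print(topk_indices.shape, topk_weights.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```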
class DeepseekV3MoE(nn.Module):

@@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
-        self.num_experts_per_tok = config.num_experts_per_tok
-
-        if hasattr(config, "ep_size") and config.ep_size > 1:
-            assert config.ep_size == dist.get_world_size()
-            self.ep_size = config.ep_size
-            self.experts_per_rank = config.n_routed_experts // config.ep_size
-            self.ep_rank = dist.get_rank()
-            self.experts = nn.ModuleList(
-                [
-                    (
-                        DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
-                        if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
-                        else None
-                    )
-                    for i in range(config.n_routed_experts)
-                ]
-            )
-        else:
-            self.ep_size = 1
-            self.experts_per_rank = config.n_routed_experts
-            self.ep_rank = 0
-            self.experts = nn.ModuleList(
-                [
-                    DeepseekV3MLP(config)
-                    for i in range(config.n_routed_experts)
-                ]
-            )
-        self.gate = MoEGate(config)
-        if config.n_shared_experts is not None:
-            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV3MLP(config=config)
+        self.experts = nn.ModuleList(
+            [
+                DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
+                for _ in range(config.n_routed_experts)
+            ]
+        )
+        self.gate = DeepseekV3TopkRouter(config)
+        self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)

    def forward(self, hidden_states):
-        identity = hidden_states
+        residuals = hidden_states
        orig_shape = hidden_states.shape
-        topk_idx, topk_weight = self.gate(hidden_states)
+        topk_indices, topk_weights, router_logits = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        if not self.training:
-            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
-            if self.config.n_shared_experts is not None:
-                y = y + self.shared_experts(identity)
-        return y
-
-    @torch.no_grad()
-    def moe_infer(self, x, topk_ids, topk_weight):
-        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
-        cnts.scatter_(1, topk_ids, 1)
-        tokens_per_expert = cnts.sum(dim=0)
-        idxs = topk_ids.view(-1).argsort()
-        sorted_tokens = x[idxs // topk_ids.shape[1]]
-        sorted_tokens_shape = sorted_tokens.shape
-        if self.ep_size > 1:
-            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
-            tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
-            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
-            output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
-            gathered_tokens = sorted_tokens.new_empty(
-                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
-            )
-            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
-            dist.all_to_all(
-                list(gathered_tokens.split(output_splits)),
-                list(sorted_tokens.split(input_split_sizes)),
-            )
-            tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
-                dim=0
-            )
-            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
-            s = 0
-            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
-                gatherd_idxs[s : s + k] = i % self.experts_per_rank
-                s += k
-            gatherd_idxs = gatherd_idxs.argsort()
-            sorted_tokens = gathered_tokens[gatherd_idxs]
-            tokens_per_expert = tokens_per_expert_post_gather
-        tokens_per_expert = tokens_per_expert.cpu().numpy()
-
-        outputs = []
-        start_idx = 0
-        for i, num_tokens in enumerate(tokens_per_expert):
-            end_idx = start_idx + num_tokens
-            if num_tokens == 0:
-                continue
-            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
-            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
-            expert_out = expert(tokens_for_this_expert)
-            outputs.append(expert_out)
-            start_idx = end_idx
-
-        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
-        if self.ep_size > 1:
-            new_x = torch.empty_like(outs)
-            new_x[gatherd_idxs] = outs
-            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
-            dist.all_to_all(
-                list(gathered_tokens.split(input_split_sizes)),
-                list(new_x.split(output_splits)),
-            )
-            outs = gathered_tokens
-
-        new_x = torch.empty_like(outs)
-        new_x[idxs] = outs
-        final_out = (
-            new_x.view(*topk_ids.shape, -1)
-            .type(topk_weight.dtype)
-            .mul_(topk_weight.unsqueeze(dim=-1))
-            .sum(dim=1)
-            .type(new_x.dtype)
-        )
-        return final_out
+        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
+        hidden_states = hidden_states + self.shared_experts(residuals)
+        return hidden_states, router_logits
+
+    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
+        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
+        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
+        expert_mask = expert_mask.permute(2, 0, 1)
+
+        for expert_idx in range(len(self.experts)):
+            expert = self.experts[expert_idx]
+            mask = expert_mask[expert_idx]
+            token_indices, weight_indices = torch.where(mask)
+
+            if token_indices.numel() > 0:
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
+                final_hidden_states.index_add_(0, token_indices, weighted_output)
+        return final_hidden_states.type(hidden_states.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -359,150 +292,94 @@ def eager_attention_forward(
    return attn_output, attn_weights


-# Copied from transformers.models.llama.modeling_llama.rotate_half
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+def yarn_get_mscale(scale=1, mscale=1):
+    if scale <= 1:
+        return 1.0
+    return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekV3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
+    def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
-        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
-        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
        self.is_causal = True

-        if self.q_lora_rank is None:
-            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
-        else:
-            self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
-            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
-            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
+        self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
+        self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
+        self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)

        self.kv_a_proj_with_mqa = nn.Linear(
-            self.hidden_size,
-            config.kv_lora_rank + config.qk_rope_head_dim,
+            config.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
-        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
+        self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
-            config.kv_lora_rank,
+            self.kv_lora_rank,
            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
-            self.hidden_size,
+            config.hidden_size,
            bias=config.attention_bias,
        )

-        self.rotary_emb = DeepseekV3RotaryEmbedding(
-            config=self.config,
-        )
+        self.scaling = self.q_head_dim ** (-0.5)
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scaling = self.scaling * mscale * mscale
+        # TODO apply in DeepSeekV3Model to share accrose layers
+        self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
-        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs# : Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        if self.q_lora_rank is None:
-            q = self.q_proj(hidden_states)
-        else:
-            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
-        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, self.num_heads, -1)
+
+        q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
+        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
-        compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
-        kv = (
-            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
-            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            .transpose(1, 2)
-        )
-
-        k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        kv_seq_len = value_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
-
-        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
-        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
-
-        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
-        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
-        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
-
-        if self.q_head_dim != self.v_head_dim:
+        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+
+        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
+        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
+        cos, sin = position_embeddings
+        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+        k_rot = k_rot.expand(-1, self.num_heads, -1, -1)
+
+        query_states = torch.cat((q_pass, q_rot), dim=-1)
+        key_states = torch.cat((k_pass, k_rot), dim=-1)
+
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])

        if past_key_value is not None:

@@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        else:
-            pass
-        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+            raise NotImplementedError(
+                f"Attention implementation {self.config._attn_implementation} is not supported. "
+                "Please use 'eager' or 'sdpa'."
+            )
+        # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,

@@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
            scaling=self.scaling,
            **kwargs,
        )
-        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
-        attn_output = self.o_proj(attn_output)
+        if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]

+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
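The attention module no longer rebuilds rotary embeddings from `kv_seq_len` on every call; it consumes `position_embeddings` computed once in the model and folds the YaRN `mscale` correction into a constant softmax scaling. A sketch of how that scaling comes out for an illustrative `rope_scaling` dict (values are placeholders, not read from a checkpoint):

```python
import math

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Illustrative numbers matching the config defaults in this commit.
qk_nope_head_dim, qk_rope_head_dim = 128, 64
rope_scaling = {"factor": 40.0, "mscale_all_dim": 1.0}  # placeholder YaRN settings

q_head_dim = qk_nope_head_dim + qk_rope_head_dim
scaling = q_head_dim ** -0.5
mscale_all_dim = rope_scaling.get("mscale_all_dim", 0)
if mscale_all_dim:
    mscale = yarn_get_mscale(rope_scaling["factor"], mscale_all_dim)
    scaling = scaling * mscale * mscale  # applied once instead of rescaling cos/sin per call

print(round(scaling, 6))
```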
@@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)

-        self.mlp = (
-            DeepseekV3MoE(config)
-            if (
-                config.n_routed_experts is not None
-                and layer_idx >= config.first_k_dense_replace
-                and layer_idx % config.moe_layer_freq == 0
-            )
-            else DeepseekV3MLP(config)
-        )
+        if layer_idx >= config.first_k_dense_replace:
+            self.mlp = DeepseekV3MoE(config)
+        else:
+            self.mlp = DeepseekV3MLP(config)
        self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC

@@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
+        if isinstance(hidden_states, tuple):
+            hidden_states, router_logits = hidden_states
+        else:
+            router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
+        if output_router_logits:
+            outputs += (router_logits,)

        return outputs


-DEEPSEEKV3_START_DOCSTRING = r"""
+DEEPSEEK_V3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""
@add_start_docstrings( @add_start_docstrings(
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
DEEPSEEKV3_START_DOCSTRING, DEEPSEEK_V3_START_DOCSTRING,
) )
class DeepseekV3PreTrainedModel(PreTrainedModel): class DeepseekV3PreTrainedModel(PreTrainedModel):
config_class = DeepseekV3Config config_class = DeepseekV3Config
@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
DEEPSEEKV3_INPUTS_DOCSTRING = r""" DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""
@add_start_docstrings( @add_start_docstrings(
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
DEEPSEEKV3_START_DOCSTRING, DEEPSEEK_V3_START_DOCSTRING,
) )
class DeepseekV3Model(DeepseekV3PreTrainedModel): class DeepseekV3Model(DeepseekV3PreTrainedModel):
""" """
@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
config: DeepseekV3Config config: DeepseekV3Config
""" """
def __init__(self, config: DeepseekV3Config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self._register_load_state_dict_pre_hook(self.load_hook)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
def set_input_embeddings(self, value): def set_input_embeddings(self, value):
self.embed_tokens = value self.embed_tokens = value
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
def forward( def forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
@@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):

        return causal_mask

+    def load_hook(self, state_dict, prefix, *args):
+        """
+        Weights have to be permuted for correct rope formulation. We can't do this in the weights
+        as every other framework already uses the `Llama` original function (which is copyrighted btw).
+        And I am not even sure it's better.... anyways end of my rant
+        """
+
+        def permute_for_rope(input_tensor):
+            """
+            When you go from the complex ROPE formulation to sin and cos one, you need
+            to permute the query and key weights (to avoid doing it on the fly)
+            """
+            n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
+            input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
+            input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
+            input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
+            return input_tensor
+
+        def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
+            weight = state_dict[key]
+            weight = weight.view(num_heads, head_dim, -1)
+            weight_rot = weight[:, -rope_dim:]
+            weight_rot = permute_for_rope(weight_rot)
+            weight[:, -rope_dim:] = weight_rot
+            weight = weight.view(-1, weight.shape[-1])
+            state_dict[key] = weight
+
+        for k in state_dict:
+            if "q_b_proj." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=self.config.num_attention_heads,
+                    head_dim=self.config.q_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+            if "kv_a_proj_with_mqa." in k:
+                permute_layer_for_rope(
+                    k,
+                    num_heads=1,
+                    head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
+                    rope_dim=self.config.qk_rope_head_dim,
+                )
+

# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
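The new `load_hook` rewrites the rope slice of the `q_b_proj` and `kv_a_proj_with_mqa` weights from the interleaved (complex) RoPE layout to the half-split sin/cos layout that `apply_rotary_pos_emb` expects. A tiny numeric sketch of what `permute_for_rope` does to a toy weight slice:

```python
import torch

# Toy (n_heads=1, dim1=4, dim2=1) slice: interleaved rows [r0, r1, r2, r3]
# become half-split [r0, r2, r1, r3], matching the sin/cos RoPE formulation.
w = torch.arange(4.0).reshape(1, 4, 1)

n_heads, dim1, dim2 = w.shape
w = w.reshape(n_heads * dim1, dim2)
w = w.view(n_heads, dim1 // 2, 2, dim2)
w = w.transpose(1, 2).reshape(n_heads, dim1, dim2)

print(w.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]
```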
@@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
        return self.model

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,

@@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
        ```python
        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM

-        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
-        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
+        >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
+        >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

@@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
-    DEEPSEEKV3_START_DOCSTRING,
+    DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def __init__(self, config):

@@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

-    @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,

@@ -1214,3 +1146,11 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+
+
+__all__ = [
+    "DeepseekV3PreTrainedModel",
+    "DeepseekV3Model",
+    "DeepseekV3ForCausalLM",
+    "DeepseekV3ForSequenceClassification",
+]

View file

@@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
+from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader

@@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(q_pe, position_ids)

-        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
+        q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models

View file

@@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):

from ktransformers.models.modeling_deepseek import DeepseekV2MoE
-from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock

@@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
        identity = hidden_states
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
-        topk_idx, topk_weight= self.gate(hidden_states)
+        topk_idx, topk_weight, router_logits= self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+        # only for generate phase
        if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
            if self.config.n_shared_experts is not None:

@@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
            y += y_
            y.resize_(*orig_shape)
-            return y
+            return y, router_logits

        if self.config.n_shared_experts is not None:
            y_ = self.shared_experts(identity).squeeze(0)

@@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
            )
        if self.config.n_shared_experts is not None:
            y += y_
-        return y
+        return y, router_logits

    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
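Because the gate now returns `router_logits` as a third value, `KDeepseekV3MoE.forward` returns a `(hidden_states, router_logits)` tuple just like the stock `DeepseekV3MoE`, and callers branch on the return type. A minimal sketch of that calling convention with toy tensors (the helper name below is hypothetical, not from the commit):

```python
import torch

def unpack_mlp_output(mlp_output):
    # Mirrors the isinstance check added to DeepseekV3DecoderLayer: MoE layers return
    # (hidden_states, router_logits), dense MLP layers return a bare tensor.
    if isinstance(mlp_output, tuple):
        return mlp_output
    return mlp_output, None

dense_out = torch.randn(1, 4, 8)                       # dense DeepseekV3MLP-style output
moe_out = (torch.randn(1, 4, 8), torch.randn(4, 16))   # MoE-style (hidden_states, router_logits)
print(unpack_mlp_output(dense_out)[1], unpack_mlp_output(moe_out)[1].shape)
```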

View file

@@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseekv3 import MoEGate
+from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN

@@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = self.orig_module.weight.to(device)
-        if self.topk_method == "noaux_tc":
-            self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
+        self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)

    def unload(self):
        if self.weight is not None:
            self.weight = None
-        if self.topk_method == "noaux_tc":
+        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None

View file

@@ -47,7 +47,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
     kwargs:
@@ -55,7 +55,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
-    class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
   replace:
     class: ktransformers.operators.experts.KDeepseekV3MoE  # mlp module with custom forward function
     kwargs:
@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate
     kwargs:
@@ -72,7 +72,7 @@
       prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseekv3.MoEGate
+    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
   replace:
     class: ktransformers.operators.gate.KMoEGate  # mlp module with custom forward function
     kwargs:
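The two `name` regexes split the decoder layers into two device groups: indices 0-29 match the first pattern and 30-69 the second. A quick illustrative check in Python:

    import re

    first = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.mlp\.gate$")
    second = re.compile(r"^model\.layers\.([3456][0-9])\.mlp\.gate$")

    assert first.match("model.layers.29.mlp.gate") and not second.match("model.layers.29.mlp.gate")
    assert second.match("model.layers.30.mlp.gate") and not first.match("model.layers.30.mlp.gate")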


@@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
         self.total_context = self.model.get("total_context", 2**18)
         self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
         self.max_chunk_size = self.model.get("max_chunk_size", 2048)
-        self.max_new_tokens = self.model.get("max_new_tokens", 500)
+        self.max_new_tokens = self.model.get("max_new_tokens", 2000)
         self.json_mode = self.model.get("json_mode", False)
         self.healing = self.model.get("healing", False)
         self.ban_strings: Optional[list] = self.model.get("ban_strings", None)


@@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
     attention_factor = 1.0  # Unused in this type of RoPE
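The `getattr` fallback means a config that declares an explicit `head_dim` keeps it, instead of having the dimension silently re-derived. A tiny illustration with made-up numbers:

    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=2048, num_attention_heads=8, head_dim=128)
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
    assert head_dim == 128  # explicit value wins; the old formula would have produced 256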
@@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        dim = int(head_dim * partial_rotary_factor)
         max_position_embeddings = config.max_position_embeddings
         factor = config.rope_scaling["factor"]
     attention_factor = 1.0  # Unused in this type of RoPE
     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
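With the added `seq_len > max_position_embeddings` guard, dynamic NTK only rescales the base once the running context actually exceeds the pretrained window. A small numeric sketch (all values made up):

    import math

    base, dim, factor, max_pos = 10000.0, 128, 4.0, 4096

    def ntk_base(seq_len):
        seq_len = seq_len if seq_len is not None and seq_len > max_pos else max_pos
        return base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))

    print(ntk_base(1024))  # 10000.0 -- short sequences leave the base untouched
    print(ntk_base(8192))  # ~51000 -- the base grows once the context exceeds max_pos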
@@ -185,15 +187,33 @@ def _compute_yarn_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = config.qk_rope_head_dim
-    max_position_embeddings = config.max_position_embeddings
+    head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if "original_max_position_embeddings" in config.rope_scaling:
+        original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
+        factor = config.max_position_embeddings / original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-        attention_factor = 0.1 * math.log(factor) + 1.0
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -211,7 +231,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)
-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity
@@ -219,16 +239,20 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
-    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
     return inv_freq, attention_factor
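For a DeepSeek-V3-style config the new path derives `factor` from the context-extension ratio and the attention factor from the two mscale fields; when `mscale == mscale_all_dim` the scaling cancels out to 1.0. A worked example (the numbers resemble DeepSeek-V3's published `rope_scaling`, but treat them as illustrative):

    import math

    def get_mscale(scale, mscale=1):
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    max_position_embeddings = 163840
    rope_scaling = {"mscale": 1.0, "mscale_all_dim": 1.0, "original_max_position_embeddings": 4096}

    factor = max_position_embeddings / rope_scaling["original_max_position_embeddings"]  # 40.0
    attention_factor = float(
        get_mscale(factor, rope_scaling["mscale"]) / get_mscale(factor, rope_scaling["mscale_all_dim"])
    )
    print(factor, attention_factor)  # 40.0 1.0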
@@ -244,7 +268,7 @@ def _compute_longrope_parameters(
         device (`torch.device`):
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length. Unused for this type of RoPE.
+            The current sequence length.
         rope_kwargs (`Dict`, *optional*):
             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
     Returns:
@@ -261,7 +285,8 @@ def _compute_longrope_parameters(
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     long_factor = config.rope_scaling["long_factor"]
     short_factor = config.rope_scaling["short_factor"]
     factor = config.rope_scaling.get("factor")
@@ -271,22 +296,20 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-        max_position_embeddings = config.original_max_position_embeddings
-        expanded_max_position_embeddings = config.max_position_embeddings
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-        max_position_embeddings = config.max_position_embeddings
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if expanded_max_position_embeddings > max_position_embeddings:
+    if seq_len and seq_len > original_max_position_embeddings:
         ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
         ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
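The rescale factors are now chosen from the runtime sequence length rather than from a config-derived "expanded" length, so prefill below the original window keeps the short factors. A minimal sketch of that decision (the helper name and values are illustrative):

    import torch

    def pick_ext_factors(seq_len, original_max_position_embeddings, long_factor, short_factor):
        if seq_len and seq_len > original_max_position_embeddings:
            return torch.tensor(long_factor, dtype=torch.float32)
        return torch.tensor(short_factor, dtype=torch.float32)

    print(pick_ext_factors(2048, 4096, long_factor=[2.0, 2.0], short_factor=[1.0, 1.0]))  # short
    print(pick_ext_factors(8192, 4096, long_factor=[2.0, 2.0], short_factor=[1.0, 1.0]))  # long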
@@ -325,19 +348,18 @@ def _compute_llama3_parameters(
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
-    new_freqs = []
-    for freq in inv_freq:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
-    inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
-    return inv_freq, attention_factor
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+    return inv_freq_llama, attention_factor

 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
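The rewritten llama3 branch is a pure vectorization of the removed per-frequency loop; both produce the same inverse frequencies. A sanity-check sketch (scaling values are illustrative, not from a real Llama 3 config):

    import math
    import torch

    factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    inv_freq = 1.0 / (500000.0 ** (torch.arange(0, 128, 2).float() / 128))

    def old_loop(inv_freq):
        new_freqs = []
        for freq in inv_freq:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / factor)
            else:
                smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
        return torch.stack(new_freqs)  # stacked (instead of torch.tensor) to avoid the copy warning

    def new_vectorized(inv_freq):
        wavelen = 2 * math.pi / inv_freq
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        return torch.where(is_medium_freq, smoothed, inv_freq_llama)

    assert torch.allclose(old_loop(inv_freq), new_vectorized(inv_freq))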
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }
-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
-        received_keys.add("rope_type")
+        required_keys.add("rope_type")
+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
     missing_keys = required_keys - received_keys
     if missing_keys:
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
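With the extended `optional_keys`, a yarn-style `rope_scaling` block of the kind DeepSeek-V3 ships no longer trips the unrecognized-keys check. An illustrative block that now passes that check (field values are examples, not authoritative):

    rope_scaling = {
        "rope_type": "yarn",
        "factor": 40.0,
        "beta_fast": 32.0,
        "beta_slow": 1.0,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
        "original_max_position_embeddings": 4096,
    }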
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )
-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )
-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }
-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
@@ -544,7 +585,7 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"