mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
update rope calculation; update modeling.py; update gate for moe
This commit is contained in:
parent
5a50b34627
commit
f873558a89
11 changed files with 402 additions and 412 deletions
|
@ -54,4 +54,4 @@ long_context:
|
||||||
token_step:
|
token_step:
|
||||||
|
|
||||||
local_chat:
|
local_chat:
|
||||||
prompt_file: "./ktransformers/p.txt"
|
prompt_file: ""
|
|
@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
|
||||||
|
|
||||||
|
|
||||||
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
|
||||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
|
||||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
|
||||||
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
from ktransformers.models.modeling_llama import LlamaForCausalLM
|
||||||
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
|
||||||
|
@ -78,7 +78,7 @@ def local_chat():
|
||||||
else:
|
else:
|
||||||
content += line + "\n"
|
content += line + "\n"
|
||||||
if content == "":
|
if content == "":
|
||||||
if config.prompt_file == None or config.prompt_file == "":
|
if not config.prompt_file:
|
||||||
content = "hi"
|
content = "hi"
|
||||||
else:
|
else:
|
||||||
content = open(config.prompt_file, "r").read()
|
content = open(config.prompt_file, "r").read()
|
||||||
|
|
|
@ -14,19 +14,25 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DeepSeekV3 model configuration """
|
"""DeepSeekV3 model configuration"""
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.modeling_rope_utils import rope_config_validation
|
||||||
|
|
||||||
|
|
||||||
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Config(PretrainedConfig):
|
class DeepseekV3Config(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
||||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_size (`int`, *optional*, defaults to 129280):
|
vocab_size (`int`, *optional*, defaults to 129280):
|
||||||
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
|
||||||
|
@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
Dimension of the MoE representations.
|
Dimension of the MoE representations.
|
||||||
num_hidden_layers (`int`, *optional*, defaults to 61):
|
num_hidden_layers (`int`, *optional*, defaults to 61):
|
||||||
Number of hidden layers in the Transformer decoder.
|
Number of hidden layers in the Transformer decoder.
|
||||||
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
|
|
||||||
Number of nextn predict layers in the DeepSeekV3 Model.
|
|
||||||
num_attention_heads (`int`, *optional*, defaults to 128):
|
num_attention_heads (`int`, *optional*, defaults to 128):
|
||||||
Number of attention heads for each attention layer in the Transformer decoder.
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
num_key_value_heads (`int`, *optional*, defaults to 128):
|
num_key_value_heads (`int`, *optional*, defaults to 128):
|
||||||
|
@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
`num_attention_heads`.
|
`num_attention_heads`.
|
||||||
n_shared_experts (`int`, *optional*, defaults to 1):
|
n_shared_experts (`int`, *optional*, defaults to 1):
|
||||||
Number of shared experts, None means dense model.
|
Number of shared experts.
|
||||||
n_routed_experts (`int`, *optional*, defaults to 256):
|
n_routed_experts (`int`, *optional*, defaults to 256):
|
||||||
Number of routed experts, None means dense model.
|
Number of routed experts.
|
||||||
ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
|
|
||||||
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
|
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
|
||||||
Scaling factor or routed experts.
|
Scaling factor or routed experts.
|
||||||
kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
|
kv_lora_rank (`int`, *optional*, defaults to 512):
|
||||||
q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
|
Rank of the LoRA matrices for key and value projections.
|
||||||
qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
|
q_lora_rank (`int`, *optional*, defaults to 1536):
|
||||||
v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
|
Rank of the LoRA matrices for query projections.
|
||||||
qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
|
qk_rope_head_dim (`int`, *optional*, defaults to 64):
|
||||||
topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
|
Dimension of the query/key heads that use rotary position embeddings.
|
||||||
Topk method used in routed gate.
|
v_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of the value heads.
|
||||||
|
qk_nope_head_dim (`int`, *optional*, defaults to 128):
|
||||||
|
Dimension of the query/key heads that don't use rotary position embeddings.
|
||||||
n_group (`int`, *optional*, defaults to 8):
|
n_group (`int`, *optional*, defaults to 8):
|
||||||
Number of groups for routed experts.
|
Number of groups for routed experts.
|
||||||
topk_group (`int`, *optional*, defaults to 4):
|
topk_group (`int`, *optional*, defaults to 4):
|
||||||
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||||
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
||||||
Number of selected experts, None means dense model.
|
Number of selected experts, None means dense model.
|
||||||
moe_layer_freq (`int`, *optional*, defaults to 1):
|
|
||||||
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
|
|
||||||
first_k_dense_replace (`int`, *optional*, defaults to 3):
|
first_k_dense_replace (`int`, *optional*, defaults to 3):
|
||||||
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||||
\--k dense layers--/
|
\--k dense layers--/
|
||||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to normalize the weights of the routed experts.
|
Whether to normalize the weights of the routed experts.
|
||||||
scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
|
|
||||||
Method of computing expert weights.
|
|
||||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||||
Auxiliary loss weight coefficient.
|
Auxiliary loss weight coefficient.
|
||||||
Whether to compute the auxiliary loss for each individual sample.
|
Whether to compute the auxiliary loss for each individual sample.
|
||||||
seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
|
|
||||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
The non-linear activation function (function or string) in the decoder.
|
The non-linear activation function (function or string) in the decoder.
|
||||||
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||||
|
@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
The dropout ratio for the attention probabilities.
|
The dropout ratio for the attention probabilities.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
||||||
|
|
||||||
>>> # Initializing a Deepseek-V3 style configuration
|
>>> # Initializing a Deepseek-V3 style configuration
|
||||||
>>> configuration = DeepseekV3Config()
|
>>> configuration = DeepseekV3Config()
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
>>> configuration = model.config
|
>>> configuration = model.config
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
model_type = "deepseek_v3"
|
model_type = "deepseek_v3"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
|
# Default tensor parallel plan for base model `DeepseekV3Model`
|
||||||
|
base_model_tp_plan = {
|
||||||
|
"layers.*.gate_proj": "colwise",
|
||||||
|
"layers.*.up_proj": "colwise",
|
||||||
|
"layers.*.down_proj": "rowwise",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size=129280,
|
vocab_size=129280,
|
||||||
hidden_size=7168,
|
hidden_size=7168,
|
||||||
intermediate_size=18432,
|
intermediate_size=18432,
|
||||||
moe_intermediate_size = 2048,
|
moe_intermediate_size=2048,
|
||||||
num_hidden_layers=61,
|
num_hidden_layers=61,
|
||||||
num_nextn_predict_layers=1,
|
|
||||||
num_attention_heads=128,
|
num_attention_heads=128,
|
||||||
num_key_value_heads=128,
|
num_key_value_heads=128,
|
||||||
n_shared_experts = 1,
|
n_shared_experts=1,
|
||||||
n_routed_experts = 256,
|
n_routed_experts=256,
|
||||||
ep_size = 1,
|
routed_scaling_factor=2.5,
|
||||||
routed_scaling_factor = 2.5,
|
kv_lora_rank=512,
|
||||||
kv_lora_rank = 512,
|
q_lora_rank=1536,
|
||||||
q_lora_rank = 1536,
|
qk_rope_head_dim=64,
|
||||||
qk_rope_head_dim = 64,
|
v_head_dim=128,
|
||||||
v_head_dim = 128,
|
qk_nope_head_dim=128,
|
||||||
qk_nope_head_dim = 128,
|
n_group=8,
|
||||||
topk_method = 'noaux_tc',
|
topk_group=4,
|
||||||
n_group = 8,
|
num_experts_per_tok=8,
|
||||||
topk_group = 4,
|
first_k_dense_replace=3,
|
||||||
num_experts_per_tok = 8,
|
norm_topk_prob=True,
|
||||||
moe_layer_freq = 1,
|
aux_loss_alpha=0.001,
|
||||||
first_k_dense_replace = 3,
|
|
||||||
norm_topk_prob = True,
|
|
||||||
scoring_func = 'sigmoid',
|
|
||||||
aux_loss_alpha = 0.001,
|
|
||||||
seq_aux = True,
|
|
||||||
hidden_act="silu",
|
hidden_act="silu",
|
||||||
max_position_embeddings=4096,
|
max_position_embeddings=4096,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
rope_scaling=None,
|
rope_scaling=None,
|
||||||
attention_bias=False,
|
attention_bias=False,
|
||||||
attention_dropout=0.0,
|
attention_dropout=0.0,
|
||||||
mlp_bias=False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
self.moe_intermediate_size = moe_intermediate_size
|
self.moe_intermediate_size = moe_intermediate_size
|
||||||
self.num_hidden_layers = num_hidden_layers
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
|
||||||
self.num_attention_heads = num_attention_heads
|
self.num_attention_heads = num_attention_heads
|
||||||
self.n_shared_experts = n_shared_experts
|
self.n_shared_experts = n_shared_experts
|
||||||
self.n_routed_experts = n_routed_experts
|
self.n_routed_experts = n_routed_experts
|
||||||
self.ep_size = ep_size
|
|
||||||
self.routed_scaling_factor = routed_scaling_factor
|
self.routed_scaling_factor = routed_scaling_factor
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
self.q_lora_rank = q_lora_rank
|
self.q_lora_rank = q_lora_rank
|
||||||
self.qk_rope_head_dim = qk_rope_head_dim
|
self.qk_rope_head_dim = qk_rope_head_dim
|
||||||
self.v_head_dim = v_head_dim
|
self.v_head_dim = v_head_dim
|
||||||
self.qk_nope_head_dim = qk_nope_head_dim
|
self.qk_nope_head_dim = qk_nope_head_dim
|
||||||
self.topk_method = topk_method
|
self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
|
self.head_dim = qk_rope_head_dim
|
||||||
self.n_group = n_group
|
self.n_group = n_group
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.num_experts_per_tok = num_experts_per_tok
|
self.num_experts_per_tok = num_experts_per_tok
|
||||||
self.moe_layer_freq = moe_layer_freq
|
|
||||||
self.first_k_dense_replace = first_k_dense_replace
|
self.first_k_dense_replace = first_k_dense_replace
|
||||||
self.norm_topk_prob = norm_topk_prob
|
self.norm_topk_prob = norm_topk_prob
|
||||||
self.scoring_func = scoring_func
|
|
||||||
self.aux_loss_alpha = aux_loss_alpha
|
self.aux_loss_alpha = aux_loss_alpha
|
||||||
self.seq_aux = seq_aux
|
|
||||||
# for backward compatibility
|
# for backward compatibility
|
||||||
if num_key_value_heads is None:
|
if num_key_value_heads is None:
|
||||||
num_key_value_heads = num_attention_heads
|
num_key_value_heads = num_attention_heads
|
||||||
|
@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
|
||||||
self.rope_scaling = rope_scaling
|
self.rope_scaling = rope_scaling
|
||||||
self.attention_bias = attention_bias
|
self.attention_bias = attention_bias
|
||||||
self.attention_dropout = attention_dropout
|
self.attention_dropout = attention_dropout
|
||||||
self.mlp_bias = mlp_bias
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
|
# BC: if there is a 'type' field, copy it it to 'rope_type'.
|
||||||
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
|
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||||
|
rope_config_validation(self)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
|
@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
|
||||||
# In-place ops prevent breaking the static address
|
# In-place ops prevent breaking the static address
|
||||||
self.key_cache[layer_idx].zero_()
|
self.key_cache[layer_idx].zero_()
|
||||||
self.value_cache[layer_idx].zero_()
|
self.value_cache[layer_idx].zero_()
|
||||||
|
|
||||||
|
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
|
||||||
|
"""Returns the maximum shape of the cache."""
|
||||||
|
return self.max_cache_len
|
|
@ -1,15 +1,13 @@
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
|
# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
|
||||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
# the file from the modular. If any change should be done, please apply the change to the
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
# modular_deepseekv3.py file directly. One of our CI enforces this.
|
# modular_deepseek_v3.py file directly. One of our CI enforces this.
|
||||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
import math
|
import math
|
||||||
from typing import Callable, List, Optional, Tuple, Union
|
from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
@ -30,7 +28,7 @@ from transformers.utils import (
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
from .configuration_deepseekv3 import DeepseekV3Config
|
from .configuration_deepseek_v3 import DeepseekV3Config
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3MLP(nn.Module):
|
class DeepseekV3MLP(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||||
self.intermediate_size = config.moe_intermediate_size
|
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||||
# TODO rm hard coding
|
|
||||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias)
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
self.act_fn = ACT2FN[config.hidden_act]
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
|
||||||
return down_proj
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
class MoEGate(nn.Module):
|
class DeepseekV3TopkRouter(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
self.n_routed_experts = config.n_routed_experts
|
self.n_routed_experts = config.n_routed_experts
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
self.scoring_func = config.scoring_func
|
|
||||||
self.seq_aux = config.seq_aux
|
|
||||||
self.topk_method = config.topk_method
|
|
||||||
self.n_group = config.n_group
|
self.n_group = config.n_group
|
||||||
self.topk_group = config.topk_group
|
self.topk_group = config.topk_group
|
||||||
|
|
||||||
# topk selection algorithm
|
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||||
self.norm_topk_prob = config.norm_topk_prob
|
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
||||||
self.gating_dim = config.hidden_size
|
|
||||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
|
|
||||||
if self.topk_method == "noaux_tc":
|
|
||||||
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
|
||||||
import torch.nn.init as init
|
|
||||||
|
|
||||||
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
bsz, seq_len, h = hidden_states.shape
|
batch_size, seq_length = hidden_states.shape[:-1]
|
||||||
### compute gating score
|
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||||
hidden_states = hidden_states.view(-1, h)
|
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||||
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
|
|
||||||
if self.scoring_func == "sigmoid":
|
|
||||||
scores = logits.sigmoid()
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
|
|
||||||
|
|
||||||
### select top-k experts
|
scores = router_logits.sigmoid()
|
||||||
if self.topk_method == "noaux_tc":
|
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||||
# assert not self.training
|
group_scores = (
|
||||||
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
|
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||||
group_scores = (
|
.topk(2, dim=-1)[0]
|
||||||
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
|
.sum(dim=-1)
|
||||||
) # [n, n_group]
|
) # [n, n_group]
|
||||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
|
||||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||||
score_mask = (
|
score_mask = (
|
||||||
group_mask.unsqueeze(-1)
|
group_mask.unsqueeze(-1)
|
||||||
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
|
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
|
||||||
.reshape(bsz * seq_len, -1)
|
.reshape(-1, self.n_routed_experts)
|
||||||
) # [n, e]
|
) # [n, e]
|
||||||
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||||
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
|
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
|
||||||
topk_weight = scores.gather(1, topk_idx)
|
topk_weights = scores.gather(1, topk_indices)
|
||||||
else:
|
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
|
topk_weights /= denominator
|
||||||
|
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
|
||||||
### norm gate to sum 1
|
return topk_indices, topk_weights, router_logits
|
||||||
if self.top_k > 1 and self.norm_topk_prob:
|
|
||||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
|
||||||
topk_weight = topk_weight / denominator
|
|
||||||
topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
|
|
||||||
|
|
||||||
return topk_idx, topk_weight
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3MoE(nn.Module):
|
class DeepseekV3MoE(nn.Module):
|
||||||
|
@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.num_experts_per_tok = config.num_experts_per_tok
|
self.experts = nn.ModuleList(
|
||||||
|
[
|
||||||
if hasattr(config, "ep_size") and config.ep_size > 1:
|
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
||||||
assert config.ep_size == dist.get_world_size()
|
for _ in range(config.n_routed_experts)
|
||||||
self.ep_size = config.ep_size
|
]
|
||||||
self.experts_per_rank = config.n_routed_experts // config.ep_size
|
)
|
||||||
self.ep_rank = dist.get_rank()
|
self.gate = DeepseekV3TopkRouter(config)
|
||||||
self.experts = nn.ModuleList(
|
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)
|
||||||
[
|
|
||||||
(
|
|
||||||
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
|
|
||||||
if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
for i in range(config.n_routed_experts)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.ep_size = 1
|
|
||||||
self.experts_per_rank = config.n_routed_experts
|
|
||||||
self.ep_rank = 0
|
|
||||||
self.experts = nn.ModuleList(
|
|
||||||
[
|
|
||||||
DeepseekV3MLP(config)
|
|
||||||
for i in range(config.n_routed_experts)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.gate = MoEGate(config)
|
|
||||||
if config.n_shared_experts is not None:
|
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
|
||||||
self.shared_experts = DeepseekV3MLP(config=config)
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
identity = hidden_states
|
residuals = hidden_states
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
topk_idx, topk_weight = self.gate(hidden_states)
|
topk_indices, topk_weights, router_logits = self.gate(hidden_states)
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
if not self.training:
|
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||||
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||||
if self.config.n_shared_experts is not None:
|
return hidden_states, router_logits
|
||||||
y = y + self.shared_experts(identity)
|
|
||||||
return y
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||||
def moe_infer(self, x, topk_ids, topk_weight):
|
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
||||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
||||||
cnts.scatter_(1, topk_ids, 1)
|
expert_mask = expert_mask.permute(2, 0, 1)
|
||||||
tokens_per_expert = cnts.sum(dim=0)
|
|
||||||
idxs = topk_ids.view(-1).argsort()
|
|
||||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
|
||||||
sorted_tokens_shape = sorted_tokens.shape
|
|
||||||
if self.ep_size > 1:
|
|
||||||
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
|
|
||||||
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
|
|
||||||
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
|
|
||||||
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
|
|
||||||
gathered_tokens = sorted_tokens.new_empty(
|
|
||||||
tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
|
|
||||||
)
|
|
||||||
input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
|
|
||||||
dist.all_to_all(
|
|
||||||
list(gathered_tokens.split(output_splits)),
|
|
||||||
list(sorted_tokens.split(input_split_sizes)),
|
|
||||||
)
|
|
||||||
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
|
|
||||||
dim=0
|
|
||||||
)
|
|
||||||
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
|
|
||||||
s = 0
|
|
||||||
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
|
|
||||||
gatherd_idxs[s : s + k] = i % self.experts_per_rank
|
|
||||||
s += k
|
|
||||||
gatherd_idxs = gatherd_idxs.argsort()
|
|
||||||
sorted_tokens = gathered_tokens[gatherd_idxs]
|
|
||||||
tokens_per_expert = tokens_per_expert_post_gather
|
|
||||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
|
||||||
|
|
||||||
outputs = []
|
for expert_idx in range(len(self.experts)):
|
||||||
start_idx = 0
|
expert = self.experts[expert_idx]
|
||||||
for i, num_tokens in enumerate(tokens_per_expert):
|
mask = expert_mask[expert_idx]
|
||||||
end_idx = start_idx + num_tokens
|
token_indices, weight_indices = torch.where(mask)
|
||||||
if num_tokens == 0:
|
|
||||||
continue
|
|
||||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
|
||||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
|
||||||
expert_out = expert(tokens_for_this_expert)
|
|
||||||
outputs.append(expert_out)
|
|
||||||
start_idx = end_idx
|
|
||||||
|
|
||||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
if token_indices.numel() > 0:
|
||||||
if self.ep_size > 1:
|
expert_weights = topk_weights[token_indices, weight_indices]
|
||||||
new_x = torch.empty_like(outs)
|
expert_input = hidden_states[token_indices]
|
||||||
new_x[gatherd_idxs] = outs
|
expert_output = expert(expert_input)
|
||||||
gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
|
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
||||||
dist.all_to_all(
|
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
||||||
list(gathered_tokens.split(input_split_sizes)),
|
return final_hidden_states.type(hidden_states.dtype)
|
||||||
list(new_x.split(output_splits)),
|
|
||||||
)
|
|
||||||
outs = gathered_tokens
|
|
||||||
|
|
||||||
new_x = torch.empty_like(outs)
|
|
||||||
new_x[idxs] = outs
|
def rotate_half(x):
|
||||||
final_out = (
|
"""Rotates half the hidden dims of the input."""
|
||||||
new_x.view(*topk_ids.shape, -1)
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
.type(topk_weight.dtype)
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
.sum(dim=1)
|
|
||||||
.type(new_x.dtype)
|
|
||||||
)
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
return final_out
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (`torch.Tensor`): The query tensor.
|
||||||
|
k (`torch.Tensor`): The key tensor.
|
||||||
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
|
position_ids (`torch.Tensor`, *optional*):
|
||||||
|
Deprecated and unused.
|
||||||
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||||
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||||
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||||
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||||
|
Returns:
|
||||||
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
|
"""
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
@ -359,150 +292,94 @@ def eager_attention_forward(
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
def yarn_get_mscale(scale=1, mscale=1):
|
||||||
def rotate_half(x):
|
if scale <= 1:
|
||||||
"""Rotates half the hidden dims of the input."""
|
return 1.0
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
||||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q (`torch.Tensor`): The query tensor.
|
|
||||||
k (`torch.Tensor`): The key tensor.
|
|
||||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
|
||||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
|
||||||
position_ids (`torch.Tensor`, *optional*):
|
|
||||||
Deprecated and unused.
|
|
||||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
|
||||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
|
||||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
|
||||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
|
||||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
|
||||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
|
||||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
|
||||||
Returns:
|
|
||||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
|
||||||
"""
|
|
||||||
cos = cos.unsqueeze(unsqueeze_dim)
|
|
||||||
sin = sin.unsqueeze(unsqueeze_dim)
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Attention(nn.Module):
|
class DeepseekV3Attention(nn.Module):
|
||||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
|
def __init__(self, config: DeepseekV3Config, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
if layer_idx is None:
|
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||||
logger.warning_once(
|
|
||||||
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
|
||||||
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
|
||||||
"when creating this class."
|
|
||||||
)
|
|
||||||
|
|
||||||
self.attention_dropout = config.attention_dropout
|
self.attention_dropout = config.attention_dropout
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
|
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
self.q_lora_rank = config.q_lora_rank
|
self.q_lora_rank = config.q_lora_rank
|
||||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||||
self.kv_lora_rank = config.kv_lora_rank
|
self.kv_lora_rank = config.kv_lora_rank
|
||||||
self.v_head_dim = config.v_head_dim
|
self.v_head_dim = config.v_head_dim
|
||||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||||
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
|
self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
|
||||||
|
|
||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
||||||
if self.q_lora_rank is None:
|
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
||||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
|
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
||||||
else:
|
|
||||||
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
|
|
||||||
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
|
|
||||||
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
|
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = nn.Linear(
|
self.kv_a_proj_with_mqa = nn.Linear(
|
||||||
self.hidden_size,
|
config.hidden_size,
|
||||||
config.kv_lora_rank + config.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
bias=config.attention_bias,
|
bias=config.attention_bias,
|
||||||
)
|
)
|
||||||
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
|
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
|
||||||
self.kv_b_proj = nn.Linear(
|
self.kv_b_proj = nn.Linear(
|
||||||
config.kv_lora_rank,
|
self.kv_lora_rank,
|
||||||
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = nn.Linear(
|
self.o_proj = nn.Linear(
|
||||||
self.num_heads * self.v_head_dim,
|
self.num_heads * self.v_head_dim,
|
||||||
self.hidden_size,
|
config.hidden_size,
|
||||||
bias=config.attention_bias,
|
bias=config.attention_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.rotary_emb = DeepseekV3RotaryEmbedding(
|
self.scaling = self.q_head_dim ** (-0.5)
|
||||||
config=self.config,
|
if self.config.rope_scaling is not None:
|
||||||
)
|
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||||
|
scaling_factor = self.config.rope_scaling["factor"]
|
||||||
|
if mscale_all_dim:
|
||||||
|
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||||
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
# TODO apply in DeepSeekV3Model to share accrose layers
|
||||||
|
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor],
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs# : Unpack[FlashAttentionKwargs],
|
**kwargs# : Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
hidden_shape = (*input_shape, self.num_heads, -1)
|
||||||
|
|
||||||
if self.q_lora_rank is None:
|
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
|
||||||
q = self.q_proj(hidden_states)
|
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
else:
|
|
||||||
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
|
|
||||||
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
|
||||||
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
|
||||||
|
|
||||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||||
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
|
|
||||||
kv = (
|
|
||||||
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
|
|
||||||
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
|
||||||
kv_seq_len = value_states.shape[-2]
|
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
if past_key_value is not None:
|
|
||||||
if self.layer_idx is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
|
||||||
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
|
||||||
"with a layer index."
|
|
||||||
)
|
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
|
|
||||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
|
||||||
|
|
||||||
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
cos, sin = position_embeddings
|
||||||
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
|
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
|
||||||
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
|
k_rot = k_rot.expand(-1, self.num_heads, -1, -1)
|
||||||
|
|
||||||
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
|
query_states = torch.cat((q_pass, q_rot), dim=-1)
|
||||||
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
|
key_states = torch.cat((k_pass, k_rot), dim=-1)
|
||||||
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
|
|
||||||
|
|
||||||
if self.q_head_dim != self.v_head_dim:
|
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||||
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
|
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
|
@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
|
||||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pass
|
raise NotImplementedError(
|
||||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
f"Attention implementation {self.config._attn_implementation} is not supported. "
|
||||||
|
"Please use 'eager' or 'sdpa'."
|
||||||
|
)
|
||||||
|
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
|
||||||
attn_output, attn_weights = attention_interface(
|
attn_output, attn_weights = attention_interface(
|
||||||
self,
|
self,
|
||||||
|
@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
|
||||||
scaling=self.scaling,
|
scaling=self.scaling,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
|
||||||
|
attn_output = attn_output[:, :, :, : self.v_head_dim]
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
|
||||||
|
|
||||||
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
|
||||||
|
|
||||||
self.mlp = (
|
if layer_idx >= config.first_k_dense_replace:
|
||||||
DeepseekV3MoE(config)
|
self.mlp = DeepseekV3MoE(config)
|
||||||
if (
|
else:
|
||||||
config.n_routed_experts is not None
|
self.mlp = DeepseekV3MLP(config)
|
||||||
and layer_idx >= config.first_k_dense_replace
|
|
||||||
and layer_idx % config.moe_layer_freq == 0
|
|
||||||
)
|
|
||||||
else DeepseekV3MLP(config)
|
|
||||||
)
|
|
||||||
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_value: Optional[Cache] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
|
output_router_logits: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
|
@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
|
if isinstance(hidden_states, tuple):
|
||||||
|
hidden_states, router_logits = hidden_states
|
||||||
|
else:
|
||||||
|
router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
|
||||||
|
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs += (self_attn_weights,)
|
outputs += (self_attn_weights,)
|
||||||
|
if output_router_logits:
|
||||||
|
outputs += (router_logits,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
DEEPSEEKV3_START_DOCSTRING = r"""
|
DEEPSEEK_V3_START_DOCSTRING = r"""
|
||||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||||
etc.)
|
etc.)
|
||||||
|
@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
||||||
DEEPSEEKV3_START_DOCSTRING,
|
DEEPSEEK_V3_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class DeepseekV3PreTrainedModel(PreTrainedModel):
|
class DeepseekV3PreTrainedModel(PreTrainedModel):
|
||||||
config_class = DeepseekV3Config
|
config_class = DeepseekV3Config
|
||||||
|
@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
|
|
||||||
DEEPSEEKV3_INPUTS_DOCSTRING = r"""
|
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
|
||||||
DEEPSEEKV3_START_DOCSTRING,
|
DEEPSEEK_V3_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||||
config: DeepseekV3Config
|
config: DeepseekV3Config
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: DeepseekV3Config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||||
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
|
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.embed_tokens = value
|
self.embed_tokens = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
def load_hook(self, state_dict, prefix, *args):
|
||||||
|
"""
|
||||||
|
Weights have to be permuted for correct rope formulation. We can't do this in the weights
|
||||||
|
as every other framework already uses the `Llama` original function (which is copyrighted btw).
|
||||||
|
And I am not even sure it's better.... anyways end of my rant
|
||||||
|
"""
|
||||||
|
|
||||||
|
def permute_for_rope(input_tensor):
|
||||||
|
"""
|
||||||
|
When you go from the complex ROPE formulation to sin and cos one, you need
|
||||||
|
to permute the query and key weights (to avoid doing it on the fly)
|
||||||
|
"""
|
||||||
|
n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
|
||||||
|
input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
|
||||||
|
input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
|
||||||
|
input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
|
||||||
|
weight = state_dict[key]
|
||||||
|
weight = weight.view(num_heads, head_dim, -1)
|
||||||
|
weight_rot = weight[:, -rope_dim:]
|
||||||
|
weight_rot = permute_for_rope(weight_rot)
|
||||||
|
weight[:, -rope_dim:] = weight_rot
|
||||||
|
weight = weight.view(-1, weight.shape[-1])
|
||||||
|
state_dict[key] = weight
|
||||||
|
|
||||||
|
for k in state_dict:
|
||||||
|
if "q_b_proj." in k:
|
||||||
|
permute_layer_for_rope(
|
||||||
|
k,
|
||||||
|
num_heads=self.config.num_attention_heads,
|
||||||
|
head_dim=self.config.q_head_dim,
|
||||||
|
rope_dim=self.config.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
if "kv_a_proj_with_mqa." in k:
|
||||||
|
permute_layer_for_rope(
|
||||||
|
k,
|
||||||
|
num_heads=1,
|
||||||
|
head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
|
||||||
|
rope_dim=self.config.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||||
|
|
||||||
|
@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
||||||
```python
|
```python
|
||||||
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
|
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
|
||||||
|
|
||||||
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
|
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
|
||||||
|
|
||||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
|
||||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
||||||
each row of the batch).
|
each row of the batch).
|
||||||
""",
|
""",
|
||||||
DEEPSEEKV3_START_DOCSTRING,
|
DEEPSEEK_V3_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.model.embed_tokens = value
|
self.model.embed_tokens = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
@ -1214,3 +1146,11 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=transformer_outputs.hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DeepseekV3PreTrainedModel",
|
||||||
|
"DeepseekV3Model",
|
||||||
|
"DeepseekV3ForCausalLM",
|
||||||
|
"DeepseekV3ForSequenceClassification",
|
||||||
|
]
|
|
@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
|
||||||
from ktransformers.models.configuration_llama import LlamaConfig
|
from ktransformers.models.configuration_llama import LlamaConfig
|
||||||
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
|
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
|
||||||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
|
||||||
|
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||||
from ktransformers.util.custom_gguf import GGUFLoader
|
from ktransformers.util.custom_gguf import GGUFLoader
|
||||||
|
@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(q_pe, position_ids)
|
cos, sin = self.rotary_emb(q_pe, position_ids)
|
||||||
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
|
q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
||||||
|
|
|
@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
|
||||||
|
|
||||||
|
|
||||||
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
|
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
|
||||||
from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
|
||||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||||
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
|
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
|
@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
||||||
identity = hidden_states
|
identity = hidden_states
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
sequence_length = orig_shape[1]
|
sequence_length = orig_shape[1]
|
||||||
topk_idx, topk_weight= self.gate(hidden_states)
|
topk_idx, topk_weight, router_logits= self.gate(hidden_states)
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
|
||||||
|
# only for generate phase
|
||||||
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
|
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
|
||||||
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
|
@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
||||||
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
|
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
|
||||||
y += y_
|
y += y_
|
||||||
y.resize_(*orig_shape)
|
y.resize_(*orig_shape)
|
||||||
return y
|
return y, router_logits
|
||||||
|
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
y_ = self.shared_experts(identity).squeeze(0)
|
y_ = self.shared_experts(identity).squeeze(0)
|
||||||
|
@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
|
||||||
)
|
)
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
y += y_
|
y += y_
|
||||||
return y
|
return y, router_logits
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
|
||||||
import ctypes
|
import ctypes
|
||||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||||
from ktransformers.util.custom_gguf import GGUFLoader
|
from ktransformers.util.custom_gguf import GGUFLoader
|
||||||
from ktransformers.models.modeling_deepseekv3 import MoEGate
|
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
|
||||||
from ktransformers.util.utils import InferenceState
|
from ktransformers.util.utils import InferenceState
|
||||||
from ktransformers.server.config.config import Config
|
from ktransformers.server.config.config import Config
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid weight type")
|
raise ValueError("Invalid weight type")
|
||||||
self.orig_module.weight = self.orig_module.weight.to(device)
|
self.orig_module.weight = self.orig_module.weight.to(device)
|
||||||
if self.topk_method == "noaux_tc":
|
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
|
||||||
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
|
|
||||||
|
|
||||||
def unload(self):
|
def unload(self):
|
||||||
if self.weight is not None:
|
if self.weight is not None:
|
||||||
self.weight = None
|
self.weight = None
|
||||||
if self.topk_method == "noaux_tc":
|
if self.e_score_correction_bias is not None:
|
||||||
self.e_score_correction_bias = None
|
self.e_score_correction_bias = None
|
||||||
|
|
|
@ -47,7 +47,7 @@
|
||||||
|
|
||||||
- match:
|
- match:
|
||||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
|
||||||
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
|
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||||
replace:
|
replace:
|
||||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||||
kwargs:
|
kwargs:
|
||||||
|
@ -55,7 +55,7 @@
|
||||||
prefill_device: "cuda:0"
|
prefill_device: "cuda:0"
|
||||||
- match:
|
- match:
|
||||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
|
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
|
||||||
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
|
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
|
||||||
replace:
|
replace:
|
||||||
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
|
||||||
kwargs:
|
kwargs:
|
||||||
|
@ -64,7 +64,7 @@
|
||||||
|
|
||||||
- match:
|
- match:
|
||||||
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
|
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
|
||||||
class: ktransformers.models.modeling_deepseekv3.MoEGate
|
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
|
||||||
replace:
|
replace:
|
||||||
class: ktransformers.operators.gate.KMoEGate
|
class: ktransformers.operators.gate.KMoEGate
|
||||||
kwargs:
|
kwargs:
|
||||||
|
@ -72,7 +72,7 @@
|
||||||
prefill_device: "cuda:0"
|
prefill_device: "cuda:0"
|
||||||
- match:
|
- match:
|
||||||
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
|
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
|
||||||
class: ktransformers.models.modeling_deepseekv3.MoEGate
|
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
|
||||||
replace:
|
replace:
|
||||||
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
|
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
|
||||||
kwargs:
|
kwargs:
|
||||||
|
|
|
@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
|
||||||
self.total_context = self.model.get("total_context", 2**18)
|
self.total_context = self.model.get("total_context", 2**18)
|
||||||
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
|
||||||
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
|
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
|
||||||
self.max_new_tokens = self.model.get("max_new_tokens", 500)
|
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
|
||||||
self.json_mode = self.model.get("json_mode", False)
|
self.json_mode = self.model.get("json_mode", False)
|
||||||
self.healing = self.model.get("healing", False)
|
self.healing = self.model.get("healing", False)
|
||||||
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
|
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
|
||||||
|
|
|
@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
|
||||||
elif config is not None:
|
elif config is not None:
|
||||||
base = config.rope_theta
|
base = config.rope_theta
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
|
|
||||||
attention_factor = 1.0 # Unused in this type of RoPE
|
attention_factor = 1.0 # Unused in this type of RoPE
|
||||||
|
|
||||||
|
@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
|
||||||
elif config is not None:
|
elif config is not None:
|
||||||
base = config.rope_theta
|
base = config.rope_theta
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
max_position_embeddings = config.max_position_embeddings
|
max_position_embeddings = config.max_position_embeddings
|
||||||
factor = config.rope_scaling["factor"]
|
factor = config.rope_scaling["factor"]
|
||||||
|
|
||||||
attention_factor = 1.0 # Unused in this type of RoPE
|
attention_factor = 1.0 # Unused in this type of RoPE
|
||||||
|
|
||||||
# seq_len: default to max_position_embeddings, e.g. at init time
|
# seq_len: default to max_position_embeddings, e.g. at init time
|
||||||
seq_len = seq_len if seq_len is not None else max_position_embeddings
|
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
|
||||||
|
|
||||||
# Compute the inverse frequencies
|
# Compute the inverse frequencies
|
||||||
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
|
||||||
|
@ -185,15 +187,33 @@ def _compute_yarn_parameters(
|
||||||
|
|
||||||
base = config.rope_theta
|
base = config.rope_theta
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
dim = config.qk_rope_head_dim
|
head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
max_position_embeddings = config.max_position_embeddings
|
|
||||||
factor = config.rope_scaling["factor"]
|
factor = config.rope_scaling["factor"]
|
||||||
|
attention_factor = config.rope_scaling.get("attention_factor")
|
||||||
|
mscale = config.rope_scaling.get("mscale")
|
||||||
|
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
|
||||||
|
|
||||||
|
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
|
||||||
|
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||||
|
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||||
|
if "original_max_position_embeddings" in config.rope_scaling:
|
||||||
|
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
|
||||||
|
factor = config.max_position_embeddings / original_max_position_embeddings
|
||||||
|
else:
|
||||||
|
original_max_position_embeddings = config.max_position_embeddings
|
||||||
|
|
||||||
|
def get_mscale(scale, mscale=1):
|
||||||
|
if scale <= 1:
|
||||||
|
return 1.0
|
||||||
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
# Sets the attention factor as suggested in the paper
|
# Sets the attention factor as suggested in the paper
|
||||||
attention_factor = config.rope_scaling.get("attention_factor")
|
|
||||||
if attention_factor is None:
|
if attention_factor is None:
|
||||||
attention_factor = 0.1 * math.log(factor) + 1.0
|
if mscale and mscale_all_dim:
|
||||||
|
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
|
||||||
|
else:
|
||||||
|
attention_factor = get_mscale(factor)
|
||||||
|
|
||||||
# Optional config options
|
# Optional config options
|
||||||
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
|
||||||
|
@ -211,7 +231,7 @@ def _compute_yarn_parameters(
|
||||||
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
|
||||||
return max(low, 0), min(high, dim - 1)
|
return max(low, 0), min(high, dim - 1)
|
||||||
|
|
||||||
def linear_ramp_mask(min, max, dim):
|
def linear_ramp_factor(min, max, dim):
|
||||||
if min == max:
|
if min == max:
|
||||||
max += 0.001 # Prevent singularity
|
max += 0.001 # Prevent singularity
|
||||||
|
|
||||||
|
@ -219,16 +239,20 @@ def _compute_yarn_parameters(
|
||||||
ramp_func = torch.clamp(linear_func, 0, 1)
|
ramp_func = torch.clamp(linear_func, 0, 1)
|
||||||
return ramp_func
|
return ramp_func
|
||||||
|
|
||||||
|
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
|
||||||
|
# to expand the possible context length. In other words, interpolation = apply scaling factor.
|
||||||
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||||
|
|
||||||
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
|
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
|
||||||
|
|
||||||
# Get n-dimensional rotational scaling corrected for extrapolation
|
# Get n-dimensional rotational scaling corrected for extrapolation
|
||||||
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
|
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
|
||||||
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
inv_freq = (
|
||||||
|
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
|
||||||
|
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
|
||||||
|
)
|
||||||
return inv_freq, attention_factor
|
return inv_freq, attention_factor
|
||||||
|
|
||||||
|
|
||||||
|
@ -244,7 +268,7 @@ def _compute_longrope_parameters(
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
The device to use for initialization of the inverse frequencies.
|
The device to use for initialization of the inverse frequencies.
|
||||||
seq_len (`int`, *optional*):
|
seq_len (`int`, *optional*):
|
||||||
The current sequence length. Unused for this type of RoPE.
|
The current sequence length.
|
||||||
rope_kwargs (`Dict`, *optional*):
|
rope_kwargs (`Dict`, *optional*):
|
||||||
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -261,7 +285,8 @@ def _compute_longrope_parameters(
|
||||||
|
|
||||||
base = config.rope_theta
|
base = config.rope_theta
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
long_factor = config.rope_scaling["long_factor"]
|
long_factor = config.rope_scaling["long_factor"]
|
||||||
short_factor = config.rope_scaling["short_factor"]
|
short_factor = config.rope_scaling["short_factor"]
|
||||||
factor = config.rope_scaling.get("factor")
|
factor = config.rope_scaling.get("factor")
|
||||||
|
@ -271,22 +296,20 @@ def _compute_longrope_parameters(
|
||||||
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
|
||||||
# values to compute the default attention scaling factor, instead of using `factor`.
|
# values to compute the default attention scaling factor, instead of using `factor`.
|
||||||
if hasattr(config, "original_max_position_embeddings"):
|
if hasattr(config, "original_max_position_embeddings"):
|
||||||
max_position_embeddings = config.original_max_position_embeddings
|
original_max_position_embeddings = config.original_max_position_embeddings
|
||||||
expanded_max_position_embeddings = config.max_position_embeddings
|
factor = config.max_position_embeddings / config.original_max_position_embeddings
|
||||||
factor = expanded_max_position_embeddings / max_position_embeddings
|
|
||||||
else:
|
else:
|
||||||
max_position_embeddings = config.max_position_embeddings
|
original_max_position_embeddings = config.max_position_embeddings
|
||||||
expanded_max_position_embeddings = max_position_embeddings * factor
|
|
||||||
|
|
||||||
# Sets the attention factor as suggested in the paper
|
# Sets the attention factor as suggested in the paper
|
||||||
if attention_factor is None:
|
if attention_factor is None:
|
||||||
if factor <= 1.0:
|
if factor <= 1.0:
|
||||||
attention_factor = 1.0
|
attention_factor = 1.0
|
||||||
else:
|
else:
|
||||||
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
|
attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
|
||||||
|
|
||||||
# Compute the inverse frequencies -- scaled based on the target sequence length
|
# Compute the inverse frequencies -- scaled based on the target sequence length
|
||||||
if expanded_max_position_embeddings > max_position_embeddings:
|
if seq_len and seq_len > original_max_position_embeddings:
|
||||||
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
|
||||||
else:
|
else:
|
||||||
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
|
||||||
|
@ -325,19 +348,18 @@ def _compute_llama3_parameters(
|
||||||
|
|
||||||
low_freq_wavelen = old_context_len / low_freq_factor
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
high_freq_wavelen = old_context_len / high_freq_factor
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
new_freqs = []
|
|
||||||
for freq in inv_freq:
|
wavelen = 2 * math.pi / inv_freq
|
||||||
wavelen = 2 * math.pi / freq
|
# wavelen < high_freq_wavelen: do nothing
|
||||||
if wavelen < high_freq_wavelen:
|
# wavelen > low_freq_wavelen: divide by factor
|
||||||
new_freqs.append(freq)
|
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
||||||
elif wavelen > low_freq_wavelen:
|
# otherwise: interpolate between the two, using a smooth factor
|
||||||
new_freqs.append(freq / factor)
|
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
else:
|
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
||||||
assert low_freq_wavelen != high_freq_wavelen
|
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
||||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||||
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
|
|
||||||
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
|
return inv_freq_llama, attention_factor
|
||||||
return inv_freq, attention_factor
|
|
||||||
|
|
||||||
|
|
||||||
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
|
||||||
|
@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
def _check_received_keys(
|
||||||
|
rope_type: str,
|
||||||
|
received_keys: set,
|
||||||
|
required_keys: set,
|
||||||
|
optional_keys: Optional[set] = None,
|
||||||
|
ignore_keys: Optional[set] = None,
|
||||||
|
):
|
||||||
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
||||||
# BC: "rope_type" was originally "type" -- let's gracefully handle it
|
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
||||||
if "rope_type" not in received_keys and "type" in received_keys:
|
if "type" in received_keys:
|
||||||
received_keys -= {"type"}
|
received_keys -= {"type"}
|
||||||
received_keys.add("rope_type")
|
required_keys.add("rope_type")
|
||||||
|
|
||||||
|
# Some models need to store model-specific keys, and we don't want to throw warning at them
|
||||||
|
if ignore_keys is not None:
|
||||||
|
received_keys -= ignore_keys
|
||||||
|
|
||||||
missing_keys = required_keys - received_keys
|
missing_keys = required_keys - received_keys
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
|
@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
|
||||||
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_default_rope_parameters(config: PretrainedConfig):
|
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type"}
|
required_keys = {"rope_type"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
|
||||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||||
optional_keys = {"original_max_position_embeddings"}
|
optional_keys = {"original_max_position_embeddings"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
optional_keys = {
|
||||||
|
"attention_factor",
|
||||||
|
"beta_fast",
|
||||||
|
"beta_slow",
|
||||||
|
"original_max_position_embeddings",
|
||||||
|
"mscale",
|
||||||
|
"mscale_all_dim",
|
||||||
|
}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
|
@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_longrope_parameters(config: PretrainedConfig):
|
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
dim = int(head_dim * partial_rotary_factor)
|
||||||
|
|
||||||
short_factor = rope_scaling.get("short_factor")
|
short_factor = rope_scaling.get("short_factor")
|
||||||
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
||||||
|
@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
attention_factor = rope_scaling.get("attention_factor")
|
attention_factor = rope_scaling.get("attention_factor")
|
||||||
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
if attention_factor is not None:
|
||||||
logger.warning(
|
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
||||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
logger.warning(
|
||||||
)
|
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
|
@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
|
||||||
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||||
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
||||||
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||||
if high_freq_factor < low_freq_factor:
|
if high_freq_factor <= low_freq_factor:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
||||||
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
||||||
|
@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def rope_config_validation(config: PretrainedConfig):
|
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
"""
|
"""
|
||||||
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
||||||
"""
|
"""
|
||||||
|
@ -544,7 +585,7 @@ def rope_config_validation(config: PretrainedConfig):
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||||
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
||||||
if validation_fn is not None:
|
if validation_fn is not None:
|
||||||
validation_fn(config)
|
validation_fn(config, ignore_keys=ignore_keys)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue