update rope calculation; update modeling.py; update gate for moe

This commit is contained in:
Azure 2025-02-01 07:32:21 +00:00
parent 5a50b34627
commit f873558a89
11 changed files with 402 additions and 412 deletions

View file

@ -54,4 +54,4 @@ long_context:
token_step:
local_chat:
prompt_file: "./ktransformers/p.txt"
prompt_file: ""

View file

@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser
from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
@ -78,7 +78,7 @@ def local_chat():
else:
content += line + "\n"
if content == "":
if config.prompt_file == None or config.prompt_file == "":
if not config.prompt_file:
content = "hi"
else:
content = open(config.prompt_file, "r").read()

View file

@ -14,19 +14,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" DeepSeekV3 model configuration """
"""DeepSeekV3 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 61):
Number of hidden layers in the Transformer decoder.
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
Number of nextn predict layers in the DeepSeekV3 Model.
num_attention_heads (`int`, *optional*, defaults to 128):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 128):
@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig):
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
n_shared_experts (`int`, *optional*, defaults to 1):
Number of shared experts, None means dense model.
Number of shared experts.
n_routed_experts (`int`, *optional*, defaults to 256):
Number of routed experts, None means dense model.
ep_size (`<fill_type>`, *optional*, defaults to 1): <fill_docstring>
Number of routed experts.
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
Scaling factor or routed experts.
kv_lora_rank (`<fill_type>`, *optional*, defaults to 512): <fill_docstring>
q_lora_rank (`<fill_type>`, *optional*, defaults to 1536): <fill_docstring>
qk_rope_head_dim (`<fill_type>`, *optional*, defaults to 64): <fill_docstring>
v_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
qk_nope_head_dim (`<fill_type>`, *optional*, defaults to 128): <fill_docstring>
topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
Topk method used in routed gate.
kv_lora_rank (`int`, *optional*, defaults to 512):
Rank of the LoRA matrices for key and value projections.
q_lora_rank (`int`, *optional*, defaults to 1536):
Rank of the LoRA matrices for query projections.
qk_rope_head_dim (`int`, *optional*, defaults to 64):
Dimension of the query/key heads that use rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 128):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 128):
Dimension of the query/key heads that don't use rotary position embeddings.
n_group (`int`, *optional*, defaults to 8):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to 4):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 3):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
Whether to compute the auxiliary loss for each individual sample.
seq_aux (`<fill_type>`, *optional*, defaults to `True`): <fill_docstring>
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `DeepseekV3Model`
base_model_tp_plan = {
"layers.*.gate_proj": "colwise",
"layers.*.up_proj": "colwise",
"layers.*.down_proj": "rowwise",
}
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size = 2048,
moe_intermediate_size=2048,
num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts = 1,
n_routed_experts = 256,
ep_size = 1,
routed_scaling_factor = 2.5,
kv_lora_rank = 512,
q_lora_rank = 1536,
qk_rope_head_dim = 64,
v_head_dim = 128,
qk_nope_head_dim = 128,
topk_method = 'noaux_tc',
n_group = 8,
topk_group = 4,
num_experts_per_tok = 8,
moe_layer_freq = 1,
first_k_dense_replace = 3,
norm_topk_prob = True,
scoring_func = 'sigmoid',
aux_loss_alpha = 0.001,
seq_aux = True,
n_shared_experts=1,
n_routed_experts=256,
routed_scaling_factor=2.5,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
n_group=8,
topk_group=4,
num_experts_per_tok=8,
first_k_dense_replace=3,
norm_topk_prob=True,
aux_loss_alpha=0.001,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig):
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig):
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.head_dim = qk_rope_head_dim
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig):
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
pad_token_id=pad_token_id,

View file

@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
"""Returns the maximum shape of the cache."""
return self.max_cache_len

View file

@ -1,15 +1,13 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py.
# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_deepseekv3.py file directly. One of our CI enforces this.
# modular_deepseek_v3.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
import math
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
@ -30,7 +28,7 @@ from transformers.utils import (
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from .configuration_deepseekv3 import DeepseekV3Config
from .configuration_deepseek_v3 import DeepseekV3Config
logger = logging.get_logger(__name__)
@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module):
class DeepseekV3MLP(nn.Module):
def __init__(self, config):
def __init__(self, config, hidden_size=None, intermediate_size=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.moe_intermediate_size
# TODO rm hard coding
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias)
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module):
return down_proj
class MoEGate(nn.Module):
class DeepseekV3TopkRouter(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.scoring_func = config.scoring_func
self.seq_aux = config.seq_aux
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
# topk selection algorithm
self.norm_topk_prob = config.norm_topk_prob
self.gating_dim = config.hidden_size
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
if self.topk_method == "noaux_tc":
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))
self.reset_parameters()
def reset_parameters(self) -> None:
import torch.nn.init as init
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)
if self.scoring_func == "sigmoid":
scores = logits.sigmoid()
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
batch_size, seq_length = hidden_states.shape[:-1]
hidden_states = hidden_states.view(-1, self.config.hidden_size)
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
### select top-k experts
if self.topk_method == "noaux_tc":
# assert not self.training
scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
scores = router_logits.sigmoid()
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [n, n_group]
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
.reshape(bsz * seq_len, -1)
.expand(batch_size * seq_length, self.n_group, self.n_routed_experts // self.n_group)
.reshape(-1, self.n_routed_experts)
) # [n, e]
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
topk_weight = scores.gather(1, topk_idx)
else:
raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}")
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
return topk_idx, topk_weight
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_indices)
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
topk_weights /= denominator
topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor
return topk_indices, topk_weights, router_logits
class DeepseekV3MoE(nn.Module):
@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
if hasattr(config, "ep_size") and config.ep_size > 1:
assert config.ep_size == dist.get_world_size()
self.ep_size = config.ep_size
self.experts_per_rank = config.n_routed_experts // config.ep_size
self.ep_rank = dist.get_rank()
self.experts = nn.ModuleList(
[
(
DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)
if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank
else None
)
for i in range(config.n_routed_experts)
for _ in range(config.n_routed_experts)
]
)
else:
self.ep_size = 1
self.experts_per_rank = config.n_routed_experts
self.ep_rank = 0
self.experts = nn.ModuleList(
[
DeepseekV3MLP(config)
for i in range(config.n_routed_experts)
]
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV3MLP(config=config)
self.gate = DeepseekV3TopkRouter(config)
self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size)
def forward(self, hidden_states):
identity = hidden_states
residuals = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight = self.gate(hidden_states)
topk_indices, topk_weights, router_logits = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
if not self.training:
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
return y
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
return hidden_states, router_logits
@torch.no_grad()
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
sorted_tokens_shape = sorted_tokens.shape
if self.ep_size > 1:
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist()
gathered_tokens = sorted_tokens.new_empty(
tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
)
input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
dist.all_to_all(
list(gathered_tokens.split(output_splits)),
list(sorted_tokens.split(input_split_sizes)),
)
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(
dim=0
)
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
s = 0
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
gatherd_idxs[s : s + k] = i % self.experts_per_rank
s += k
gatherd_idxs = gatherd_idxs.argsort()
sorted_tokens = gathered_tokens[gatherd_idxs]
tokens_per_expert = tokens_per_expert_post_gather
tokens_per_expert = tokens_per_expert.cpu().numpy()
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
expert_mask = expert_mask.permute(2, 0, 1)
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
for expert_idx in range(len(self.experts)):
expert = self.experts[expert_idx]
mask = expert_mask[expert_idx]
token_indices, weight_indices = torch.where(mask)
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
if self.ep_size > 1:
new_x = torch.empty_like(outs)
new_x[gatherd_idxs] = outs
gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
dist.all_to_all(
list(gathered_tokens.split(input_split_sizes)),
list(new_x.split(output_splits)),
)
outs = gathered_tokens
if token_indices.numel() > 0:
expert_weights = topk_weights[token_indices, weight_indices]
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weighted_output = expert_output * expert_weights.unsqueeze(-1)
final_hidden_states.index_add_(0, token_indices, weighted_output)
return final_hidden_states.type(hidden_states.dtype)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -359,150 +292,94 @@ def eager_attention_forward(
return attn_output, attn_weights
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
class DeepseekV3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):
def __init__(self, config: DeepseekV3Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim
self.is_causal = True
if self.q_lora_rank is None:
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)
else:
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
config.kv_lora_rank + config.qk_rope_head_dim,
config.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)
self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
config.kv_lora_rank,
self.kv_lora_rank,
self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
config.hidden_size,
bias=config.attention_bias,
)
self.rotary_emb = DeepseekV3RotaryEmbedding(
config=self.config,
)
self.scaling = self.q_head_dim ** (-0.5)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scaling = self.scaling * mscale * mscale
# TODO apply in DeepSeekV3Model to share accrose layers
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs# : Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, self.num_heads, -1)
if self.q_lora_rank is None:
q = self.q_proj(hidden_states)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2)
q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2)
k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
cos, sin = position_embeddings
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
k_rot = k_rot.expand(-1, self.num_heads, -1, -1)
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
query_states = torch.cat((q_pass, q_rot), dim=-1)
key_states = torch.cat((k_pass, k_rot), dim=-1)
if self.q_head_dim != self.v_head_dim:
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
if past_key_value is not None:
@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module):
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
pass
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
raise NotImplementedError(
f"Attention implementation {self.config._attn_implementation} is not supported. "
"Please use 'eager' or 'sdpa'."
)
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module):
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module):
self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
self.mlp = (
DeepseekV3MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV3MLP(config)
)
if layer_idx >= config.first_k_dense_replace:
self.mlp = DeepseekV3MoE(config)
else:
self.mlp = DeepseekV3MLP(config)
self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module):
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple):
hidden_states, router_logits = hidden_states
else:
router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if output_router_logits:
outputs += (router_logits,)
return outputs
DEEPSEEKV3_START_DOCSTRING = r"""
DEEPSEEK_V3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r"""
@add_start_docstrings(
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
DEEPSEEKV3_START_DOCSTRING,
DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3PreTrainedModel(PreTrainedModel):
config_class = DeepseekV3Config
@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
DEEPSEEKV3_INPUTS_DOCSTRING = r"""
DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r"""
@add_start_docstrings(
"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.",
DEEPSEEKV3_START_DOCSTRING,
DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3Model(DeepseekV3PreTrainedModel):
"""
@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
config: DeepseekV3Config
"""
def __init__(self, config: DeepseekV3Config):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = DeepseekV3RotaryEmbedding(config=config)
self.gradient_checkpointing = False
self._register_load_state_dict_pre_hook(self.load_hook)
# Initialize weights and apply final processing
self.post_init()
@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
return causal_mask
def load_hook(self, state_dict, prefix, *args):
"""
Weights have to be permuted for correct rope formulation. We can't do this in the weights
as every other framework already uses the `Llama` original function (which is copyrighted btw).
And I am not even sure it's better.... anyways end of my rant
"""
def permute_for_rope(input_tensor):
"""
When you go from the complex ROPE formulation to sin and cos one, you need
to permute the query and key weights (to avoid doing it on the fly)
"""
n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2]
input_tensor = input_tensor.reshape(n_heads * dim1, dim2)
input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2)
input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2)
return input_tensor
def permute_layer_for_rope(key, num_heads, head_dim, rope_dim):
weight = state_dict[key]
weight = weight.view(num_heads, head_dim, -1)
weight_rot = weight[:, -rope_dim:]
weight_rot = permute_for_rope(weight_rot)
weight[:, -rope_dim:] = weight_rot
weight = weight.view(-1, weight.shape[-1])
state_dict[key] = weight
for k in state_dict:
if "q_b_proj." in k:
permute_layer_for_rope(
k,
num_heads=self.config.num_attention_heads,
head_dim=self.config.q_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
if "kv_a_proj_with_mqa." in k:
permute_layer_for_rope(
k,
num_heads=1,
head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim,
rope_dim=self.config.qk_rope_head_dim,
)
# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
return self.model
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
```python
>>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf")
>>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
DEEPSEEKV3_START_DOCSTRING,
DEEPSEEK_V3_START_DOCSTRING,
)
class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
def __init__(self, config):
@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING)
@add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@ -1214,3 +1146,11 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
__all__ = [
"DeepseekV3PreTrainedModel",
"DeepseekV3Model",
"DeepseekV3ForCausalLM",
"DeepseekV3ForSequenceClassification",
]

View file

@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention
from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)
q_pe, k_pe = apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models

View file

@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight= self.gate(hidden_states)
topk_idx, topk_weight, router_logits= self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])
if self.config.n_shared_experts is not None:
@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
return y, router_logits
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity).squeeze(0)
@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
)
if self.config.n_shared_experts is not None:
y += y_
return y
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:

View file

@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.models.modeling_deepseekv3 import MoEGate
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
from ktransformers.util.utils import InferenceState
from ktransformers.server.config.config import Config
from transformers.activations import ACT2FN
@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device)
if self.topk_method == "noaux_tc":
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
def unload(self):
if self.weight is not None:
self.weight = None
if self.topk_method == "noaux_tc":
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None

View file

@ -47,7 +47,7 @@
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
@ -55,7 +55,7 @@
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
@ -64,7 +64,7 @@
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseekv3.MoEGate
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
@ -72,7 +72,7 @@
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseekv3.MoEGate
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:

View file

@ -102,7 +102,7 @@ class Config(metaclass=Singleton):
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.max_new_tokens = self.model.get("max_new_tokens", 500)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)

View file

@ -58,7 +58,8 @@ def _compute_default_rope_parameters(
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE
@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters(
elif config is not None:
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
max_position_embeddings = config.max_position_embeddings
factor = config.rope_scaling["factor"]
attention_factor = 1.0 # Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time
seq_len = seq_len if seq_len is not None else max_position_embeddings
seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
# Compute the inverse frequencies
base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
@ -185,15 +187,33 @@ def _compute_yarn_parameters(
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = config.qk_rope_head_dim
max_position_embeddings = config.max_position_embeddings
head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
factor = config.rope_scaling["factor"]
attention_factor = config.rope_scaling.get("attention_factor")
mscale = config.rope_scaling.get("mscale")
mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
# NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if "original_max_position_embeddings" in config.rope_scaling:
original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"]
factor = config.max_position_embeddings / original_max_position_embeddings
else:
original_max_position_embeddings = config.max_position_embeddings
def get_mscale(scale, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
# Sets the attention factor as suggested in the paper
attention_factor = config.rope_scaling.get("attention_factor")
if attention_factor is None:
attention_factor = 0.1 * math.log(factor) + 1.0
if mscale and mscale_all_dim:
attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
else:
attention_factor = get_mscale(factor)
# Optional config options
# beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@ -211,7 +231,7 @@ def _compute_yarn_parameters(
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1)
def linear_ramp_mask(min, max, dim):
def linear_ramp_factor(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity
@ -219,16 +239,20 @@ def _compute_yarn_parameters(
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+ inv_freq_extrapolation * inv_freq_extrapolation_factor
)
return inv_freq, attention_factor
@ -244,7 +268,7 @@ def _compute_longrope_parameters(
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
The current sequence length.
rope_kwargs (`Dict`, *optional*):
BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
Returns:
@ -261,7 +285,8 @@ def _compute_longrope_parameters(
base = config.rope_theta
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
long_factor = config.rope_scaling["long_factor"]
short_factor = config.rope_scaling["short_factor"]
factor = config.rope_scaling.get("factor")
@ -271,22 +296,20 @@ def _compute_longrope_parameters(
# `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
# values to compute the default attention scaling factor, instead of using `factor`.
if hasattr(config, "original_max_position_embeddings"):
max_position_embeddings = config.original_max_position_embeddings
expanded_max_position_embeddings = config.max_position_embeddings
factor = expanded_max_position_embeddings / max_position_embeddings
original_max_position_embeddings = config.original_max_position_embeddings
factor = config.max_position_embeddings / config.original_max_position_embeddings
else:
max_position_embeddings = config.max_position_embeddings
expanded_max_position_embeddings = max_position_embeddings * factor
original_max_position_embeddings = config.max_position_embeddings
# Sets the attention factor as suggested in the paper
if attention_factor is None:
if factor <= 1.0:
attention_factor = 1.0
else:
attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))
attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
# Compute the inverse frequencies -- scaled based on the target sequence length
if expanded_max_position_embeddings > max_position_embeddings:
if seq_len and seq_len > original_max_position_embeddings:
ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
else:
ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
@ -325,19 +348,18 @@ def _compute_llama3_parameters(
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in inv_freq:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device)
return inv_freq, attention_factor
wavelen = 2 * math.pi / inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
# otherwise: interpolate between the two, using a smooth factor
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
return inv_freq_llama, attention_factor
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
}
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
def _check_received_keys(
rope_type: str,
received_keys: set,
required_keys: set,
optional_keys: Optional[set] = None,
ignore_keys: Optional[set] = None,
):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's gracefully handle it
if "rope_type" not in received_keys and "type" in received_keys:
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "type" in received_keys:
received_keys -= {"type"}
received_keys.add("rope_type")
required_keys.add("rope_type")
# Some models need to store model-specific keys, and we don't want to throw warning at them
if ignore_keys is not None:
received_keys -= ignore_keys
missing_keys = required_keys - received_keys
if missing_keys:
@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig):
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig):
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
optional_keys = {
"attention_factor",
"beta_fast",
"beta_slow",
"original_max_position_embeddings",
"mscale",
"mscale_all_dim",
}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
)
def _validate_longrope_parameters(config: PretrainedConfig):
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)
short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
if attention_factor is not None:
if not isinstance(attention_factor, float) or attention_factor < 0.0:
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig):
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
if high_freq_factor <= low_freq_factor:
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
}
def rope_config_validation(config: PretrainedConfig):
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
"""
@ -544,7 +585,7 @@ def rope_config_validation(config: PretrainedConfig):
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config)
validation_fn(config, ignore_keys=ignore_keys)
else:
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"