diff --git a/ktransformers/configs/config.yaml b/ktransformers/configs/config.yaml index 7bde376..80de09a 100644 --- a/ktransformers/configs/config.yaml +++ b/ktransformers/configs/config.yaml @@ -54,4 +54,4 @@ long_context: token_step: local_chat: - prompt_file: "./ktransformers/p.txt" \ No newline at end of file + prompt_file: "" \ No newline at end of file diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index a924a1d..f16ee7f 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -15,7 +15,7 @@ from ktransformers.server.args import ArgumentParser from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM -from ktransformers.models.modeling_deepseekv3 import DeepseekV3ForCausalLM +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -78,7 +78,7 @@ def local_chat(): else: content += line + "\n" if content == "": - if config.prompt_file == None or config.prompt_file == "": + if not config.prompt_file: content = "hi" else: content = open(config.prompt_file, "r").read() diff --git a/ktransformers/models/configuration_deepseekv3.py b/ktransformers/models/configuration_deepseek_v3.py similarity index 81% rename from ktransformers/models/configuration_deepseekv3.py rename to ktransformers/models/configuration_deepseek_v3.py index 5c599b3..6227092 100644 --- a/ktransformers/models/configuration_deepseekv3.py +++ b/ktransformers/models/configuration_deepseek_v3.py @@ -14,19 +14,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" DeepSeekV3 model configuration """ +"""DeepSeekV3 model configuration""" from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + class DeepseekV3Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + + Args: vocab_size (`int`, *optional*, defaults to 129280): Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the @@ -39,8 +45,6 @@ class DeepseekV3Config(PretrainedConfig): Dimension of the MoE representations. num_hidden_layers (`int`, *optional*, defaults to 61): Number of hidden layers in the Transformer decoder. - num_nextn_predict_layers (`int`, *optional*, defaults to 1): - Number of nextn predict layers in the DeepSeekV3 Model. num_attention_heads (`int`, *optional*, defaults to 128): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*, defaults to 128): @@ -52,38 +56,35 @@ class DeepseekV3Config(PretrainedConfig): paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `num_attention_heads`. 
n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts, None means dense model. + Number of shared experts. n_routed_experts (`int`, *optional*, defaults to 256): - Number of routed experts, None means dense model. - ep_size (``, *optional*, defaults to 1): + Number of routed experts. routed_scaling_factor (`float`, *optional*, defaults to 2.5): Scaling factor or routed experts. - kv_lora_rank (``, *optional*, defaults to 512): - q_lora_rank (``, *optional*, defaults to 1536): - qk_rope_head_dim (``, *optional*, defaults to 64): - v_head_dim (``, *optional*, defaults to 128): - qk_nope_head_dim (``, *optional*, defaults to 128): - topk_method (`str`, *optional*, defaults to `"noaux_tc"`): - Topk method used in routed gate. + kv_lora_rank (`int`, *optional*, defaults to 512): + Rank of the LoRA matrices for key and value projections. + q_lora_rank (`int`, *optional*, defaults to 1536): + Rank of the LoRA matrices for query projections. + qk_rope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the query/key heads that use rotary position embeddings. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of the value heads. + qk_nope_head_dim (`int`, *optional*, defaults to 128): + Dimension of the query/key heads that don't use rotary position embeddings. n_group (`int`, *optional*, defaults to 8): Number of groups for routed experts. topk_group (`int`, *optional*, defaults to 4): Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). num_experts_per_tok (`int`, *optional*, defaults to 8): Number of selected experts, None means dense model. - moe_layer_freq (`int`, *optional*, defaults to 1): - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. first_k_dense_replace (`int`, *optional*, defaults to 3): Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). \--k dense layers--/ norm_topk_prob (`bool`, *optional*, defaults to `True`): Whether to normalize the weights of the routed experts. - scoring_func (`str`, *optional*, defaults to `"sigmoid"`): - Method of computing expert weights. aux_loss_alpha (`float`, *optional*, defaults to 0.001): Auxiliary loss weight coefficient. Whether to compute the auxiliary loss for each individual sample. - seq_aux (``, *optional*, defaults to `True`): hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the decoder. max_position_embeddings (`int`, *optional*, defaults to 4096): @@ -119,46 +120,49 @@ class DeepseekV3Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
+ ```python >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration >>> configuration = model.config ```""" model_type = "deepseek_v3" keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `DeepseekV3Model` + base_model_tp_plan = { + "layers.*.gate_proj": "colwise", + "layers.*.up_proj": "colwise", + "layers.*.down_proj": "rowwise", + } def __init__( self, vocab_size=129280, hidden_size=7168, intermediate_size=18432, - moe_intermediate_size = 2048, + moe_intermediate_size=2048, num_hidden_layers=61, - num_nextn_predict_layers=1, num_attention_heads=128, num_key_value_heads=128, - n_shared_experts = 1, - n_routed_experts = 256, - ep_size = 1, - routed_scaling_factor = 2.5, - kv_lora_rank = 512, - q_lora_rank = 1536, - qk_rope_head_dim = 64, - v_head_dim = 128, - qk_nope_head_dim = 128, - topk_method = 'noaux_tc', - n_group = 8, - topk_group = 4, - num_experts_per_tok = 8, - moe_layer_freq = 1, - first_k_dense_replace = 3, - norm_topk_prob = True, - scoring_func = 'sigmoid', - aux_loss_alpha = 0.001, - seq_aux = True, + n_shared_experts=1, + n_routed_experts=256, + routed_scaling_factor=2.5, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + n_group=8, + topk_group=4, + num_experts_per_tok=8, + first_k_dense_replace=3, + norm_topk_prob=True, + aux_loss_alpha=0.001, hidden_act="silu", max_position_embeddings=4096, initializer_range=0.02, @@ -173,7 +177,6 @@ class DeepseekV3Config(PretrainedConfig): rope_scaling=None, attention_bias=False, attention_dropout=0.0, - mlp_bias=False, **kwargs, ): self.vocab_size = vocab_size @@ -182,27 +185,24 @@ class DeepseekV3Config(PretrainedConfig): self.intermediate_size = intermediate_size self.moe_intermediate_size = moe_intermediate_size self.num_hidden_layers = num_hidden_layers - self.num_nextn_predict_layers = num_nextn_predict_layers self.num_attention_heads = num_attention_heads self.n_shared_experts = n_shared_experts self.n_routed_experts = n_routed_experts - self.ep_size = ep_size self.routed_scaling_factor = routed_scaling_factor self.kv_lora_rank = kv_lora_rank self.q_lora_rank = q_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.v_head_dim = v_head_dim self.qk_nope_head_dim = qk_nope_head_dim - self.topk_method = topk_method + self.q_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_rope_head_dim self.n_group = n_group self.topk_group = topk_group self.num_experts_per_tok = num_experts_per_tok - self.moe_layer_freq = moe_layer_freq self.first_k_dense_replace = first_k_dense_replace self.norm_topk_prob = norm_topk_prob - self.scoring_func = scoring_func self.aux_loss_alpha = aux_loss_alpha - self.seq_aux = seq_aux + # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -217,7 +217,11 @@ class DeepseekV3Config(PretrainedConfig): self.rope_scaling = rope_scaling self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias + # Validate the correctness of rotary position embeddings parameters + # BC: if there is a 'type' field, copy it it to 'rope_type'. 
+ if self.rope_scaling is not None and "type" in self.rope_scaling: + self.rope_scaling["rope_type"] = self.rope_scaling["type"] + rope_config_validation(self) super().__init__( pad_token_id=pad_token_id, diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py index c85c7bb..e402506 100644 --- a/ktransformers/models/custom_cache.py +++ b/ktransformers/models/custom_cache.py @@ -135,3 +135,7 @@ class StaticCache(transformers.StaticCache): # In-place ops prevent breaking the static address self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() + + def get_max_cache_shape(self) -> Tuple[int, int, int, int]: + """Returns the maximum shape of the cache.""" + return self.max_cache_len \ No newline at end of file diff --git a/ktransformers/models/modeling_deepseekv3.py b/ktransformers/models/modeling_deepseek_v3.py similarity index 77% rename from ktransformers/models/modeling_deepseekv3.py rename to ktransformers/models/modeling_deepseek_v3.py index d8a888c..8eb9b9c 100644 --- a/ktransformers/models/modeling_deepseekv3.py +++ b/ktransformers/models/modeling_deepseek_v3.py @@ -1,15 +1,13 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/deepseekv3/modular_deepseekv3.py. +# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the -# modular_deepseekv3.py file directly. One of our CI enforces this. +# modular_deepseek_v3.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn.functional as F from torch import nn @@ -30,7 +28,7 @@ from transformers.utils import ( replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg -from .configuration_deepseekv3 import DeepseekV3Config +from .configuration_deepseek_v3 import DeepseekV3Config logger = logging.get_logger(__name__) @@ -119,15 +117,15 @@ class DeepseekV3RotaryEmbedding(nn.Module): class DeepseekV3MLP(nn.Module): - def __init__(self, config): + def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.moe_intermediate_size - # TODO rm hard coding - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)# config.mlp_bias) + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): @@ -135,70 +133,46 @@ class DeepseekV3MLP(nn.Module): return down_proj -class MoEGate(nn.Module): +class DeepseekV3TopkRouter(nn.Module): def 
__init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.scoring_func = config.scoring_func - self.seq_aux = config.seq_aux - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group - # topk selection algorithm - self.norm_topk_prob = config.norm_topk_prob - self.gating_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) - if self.topk_method == "noaux_tc": - self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) - self.reset_parameters() - - def reset_parameters(self) -> None: - import torch.nn.init as init - - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts))) def forward(self, hidden_states): - bsz, seq_len, h = hidden_states.shape - ### compute gating score - hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) - if self.scoring_func == "sigmoid": - scores = logits.sigmoid() - else: - raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") + batch_size, seq_length = hidden_states.shape[:-1] + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - ### select top-k experts - if self.topk_method == "noaux_tc": - # assert not self.training - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) - .reshape(bsz * seq_len, -1) - ) # [n, e] - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] - _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) - topk_weight = scores.gather(1, topk_idx) - else: - raise NotImplementedError(f"insupportable TopK function for MoE gating: {self.topk_method}") - - ### norm gate to sum 1 - if self.top_k > 1 and self.norm_topk_prob: - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 - topk_weight = topk_weight / denominator - topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor - - return topk_idx, topk_weight + scores = router_logits.sigmoid() + scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) + group_scores = ( + scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(batch_size * seq_length, 
self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) # [n, e] + scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False) + topk_weights = scores.gather(1, topk_indices) + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor # must multiply the scaling factor + return topk_indices, topk_weights, router_logits class DeepseekV3MoE(nn.Module): @@ -209,116 +183,75 @@ class DeepseekV3MoE(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.num_experts_per_tok = config.num_experts_per_tok - - if hasattr(config, "ep_size") and config.ep_size > 1: - assert config.ep_size == dist.get_world_size() - self.ep_size = config.ep_size - self.experts_per_rank = config.n_routed_experts // config.ep_size - self.ep_rank = dist.get_rank() - self.experts = nn.ModuleList( - [ - ( - DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) - if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank - else None - ) - for i in range(config.n_routed_experts) - ] - ) - else: - self.ep_size = 1 - self.experts_per_rank = config.n_routed_experts - self.ep_rank = 0 - self.experts = nn.ModuleList( - [ - DeepseekV3MLP(config) - for i in range(config.n_routed_experts) - ] - ) - self.gate = MoEGate(config) - if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV3MLP(config=config) + self.experts = nn.ModuleList( + [ + DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size) + for _ in range(config.n_routed_experts) + ] + ) + self.gate = DeepseekV3TopkRouter(config) + self.shared_experts = DeepseekV3MLP(config=config, intermediate_size=config.moe_intermediate_size) def forward(self, hidden_states): - identity = hidden_states + residuals = hidden_states orig_shape = hidden_states.shape - topk_idx, topk_weight = self.gate(hidden_states) + topk_indices, topk_weights, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - if not self.training: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y + hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states, router_logits - @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - idxs = topk_ids.view(-1).argsort() - sorted_tokens = x[idxs // topk_ids.shape[1]] - sorted_tokens_shape = sorted_tokens.shape - if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) - tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0]) - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) - output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).cpu().numpy().tolist() - gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] - ) - input_split_sizes = 
tokens_per_ep_rank.cpu().numpy().tolist() - dist.all_to_all( - list(gathered_tokens.split(output_splits)), - list(sorted_tokens.split(input_split_sizes)), - ) - tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum( - dim=0 - ) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) - s = 0 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): - gatherd_idxs[s : s + k] = i % self.experts_per_rank - s += k - gatherd_idxs = gatherd_idxs.argsort() - sorted_tokens = gathered_tokens[gatherd_idxs] - tokens_per_expert = tokens_per_expert_post_gather - tokens_per_expert = tokens_per_expert.cpu().numpy() + def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): + final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype) + expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts)) + expert_mask = expert_mask.permute(2, 0, 1) - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens - if num_tokens == 0: - continue - expert = self.experts[i + self.ep_rank * self.experts_per_rank] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out) - start_idx = end_idx + for expert_idx in range(len(self.experts)): + expert = self.experts[expert_idx] + mask = expert_mask[expert_idx] + token_indices, weight_indices = torch.where(mask) - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - if self.ep_size > 1: - new_x = torch.empty_like(outs) - new_x[gatherd_idxs] = outs - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) - dist.all_to_all( - list(gathered_tokens.split(input_split_sizes)), - list(new_x.split(output_splits)), - ) - outs = gathered_tokens + if token_indices.numel() > 0: + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, weighted_output) + return final_hidden_states.type(hidden_states.dtype) - new_x = torch.empty_like(outs) - new_x[idxs] = outs - final_out = ( - new_x.view(*topk_ids.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) - ) - return final_out + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -359,150 +292,94 @@ def eager_attention_forward( return attn_output, attn_weights -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class DeepseekV3Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None): + def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." 
- ) - + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - - self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim self.is_causal = True - - if self.q_lora_rank is None: - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( - self.hidden_size, - config.kv_lora_rank + config.qk_rope_head_dim, + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, bias=config.attention_bias, ) - self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) self.kv_b_proj = nn.Linear( - config.kv_lora_rank, + self.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, - self.hidden_size, + config.hidden_size, bias=config.attention_bias, ) - self.rotary_emb = DeepseekV3RotaryEmbedding( - config=self.config, - ) + self.scaling = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + # TODO apply in DeepSeekV3Model to share accrose layers + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs# : Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, self.num_heads, -1) - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))).view(hidden_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = 
self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(hidden_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + k_rot = k_rot.view(*input_shape, 1, self.qk_rope_head_dim).transpose(1, 2) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + cos, sin = position_embeddings + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(-1, self.num_heads, -1, -1) - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) - if self.q_head_dim != self.v_head_dim: + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) if past_key_value is not None: @@ -518,8 +395,11 @@ class DeepseekV3Attention(nn.Module): 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: - pass - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + raise NotImplementedError( + f"Attention implementation {self.config._attn_implementation} is not supported. " + "Please use 'eager' or 'sdpa'." 
+ ) + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -531,9 +411,12 @@ class DeepseekV3Attention(nn.Module): scaling=self.scaling, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) - attn_output = self.o_proj(attn_output) + if self.config._attn_implementation == "flash_attention_2" and self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) return attn_output, attn_weights @@ -544,15 +427,11 @@ class DeepseekV3DecoderLayer(nn.Module): self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - self.mlp = ( - DeepseekV3MoE(config) - if ( - config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0 - ) - else DeepseekV3MLP(config) - ) + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -563,6 +442,7 @@ class DeepseekV3DecoderLayer(nn.Module): position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -590,16 +470,24 @@ class DeepseekV3DecoderLayer(nn.Module): residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = (torch.zeros((1,), device=hidden_states.device, dtype=torch.int64),) + hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) + if output_router_logits: + outputs += (router_logits,) return outputs -DEEPSEEKV3_START_DOCSTRING = r""" +DEEPSEEK_V3_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) @@ -618,7 +506,7 @@ DEEPSEEKV3_START_DOCSTRING = r""" @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3PreTrainedModel(PreTrainedModel): config_class = DeepseekV3Config @@ -646,7 +534,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): module.weight.data[module.padding_idx].zero_() -DEEPSEEKV3_INPUTS_DOCSTRING = r""" +DEEPSEEK_V3_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide @@ -723,7 +611,7 @@ DEEPSEEKV3_INPUTS_DOCSTRING = r""" @add_start_docstrings( "The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.", - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3Model(DeepseekV3PreTrainedModel): """ @@ -733,7 +621,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): config: DeepseekV3Config """ - def __init__(self, config: DeepseekV3Config): + def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -745,6 +633,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) self.gradient_checkpointing = False + self._register_load_state_dict_pre_hook(self.load_hook) # Initialize weights and apply final processing self.post_init() @@ -755,7 +644,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -983,6 +872,49 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): return causal_mask + def load_hook(self, state_dict, prefix, *args): + """ + Weights have to be permuted for correct rope formulation. We can't do this in the weights + as every other framework already uses the `Llama` original function (which is copyrighted btw). + And I am not even sure it's better.... anyways end of my rant + """ + + def permute_for_rope(input_tensor): + """ + When you go from the complex ROPE formulation to sin and cos one, you need + to permute the query and key weights (to avoid doing it on the fly) + """ + n_heads, dim1, dim2 = input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2] + input_tensor = input_tensor.reshape(n_heads * dim1, dim2) + input_tensor = input_tensor.view(n_heads, dim1 // 2, 2, dim2) + input_tensor = input_tensor.transpose(1, 2).reshape(n_heads, dim1, dim2) + return input_tensor + + def permute_layer_for_rope(key, num_heads, head_dim, rope_dim): + weight = state_dict[key] + weight = weight.view(num_heads, head_dim, -1) + weight_rot = weight[:, -rope_dim:] + weight_rot = permute_for_rope(weight_rot) + weight[:, -rope_dim:] = weight_rot + weight = weight.view(-1, weight.shape[-1]) + state_dict[key] = weight + + for k in state_dict: + if "q_b_proj." in k: + permute_layer_for_rope( + k, + num_heads=self.config.num_attention_heads, + head_dim=self.config.q_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + if "kv_a_proj_with_mqa." in k: + permute_layer_for_rope( + k, + num_heads=1, + head_dim=self.config.kv_lora_rank + self.config.qk_rope_head_dim, + rope_dim=self.config.qk_rope_head_dim, + ) + # class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
@@ -1019,7 +951,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): return self.model @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, @@ -1058,8 +990,8 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): ```python >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseekv3/DeepseekV3-2-7b-hf") + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1125,7 +1057,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, - DEEPSEEKV3_START_DOCSTRING, + DEEPSEEK_V3_START_DOCSTRING, ) class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def __init__(self, config): @@ -1143,7 +1075,7 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): def set_input_embeddings(self, value): self.model.embed_tokens = value - @add_start_docstrings_to_model_forward(DEEPSEEKV3_INPUTS_DOCSTRING) + @add_start_docstrings_to_model_forward(DEEPSEEK_V3_INPUTS_DOCSTRING) def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1213,4 +1145,12 @@ class DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel): past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) \ No newline at end of file + ) + + +__all__ = [ + "DeepseekV3PreTrainedModel", + "DeepseekV3Model", + "DeepseekV3ForCausalLM", + "DeepseekV3ForSequenceClassification", +] \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index b3b1802..f98bfff 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -13,7 +13,8 @@ from ktransformers.models.configuration_deepseek import DeepseekV2Config from ktransformers.models.configuration_llama import LlamaConfig from ktransformers.models.modeling_llama import LlamaRotaryEmbedding from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb -from ktransformers.models.modeling_deepseekv3 import DeepseekV3Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3Attention +from ktransformers.models.modeling_deepseek_v3 import apply_rotary_pos_emb as apply_rotary_pos_emb_v3 from typing import Optional, Tuple from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader @@ -95,7 +96,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(q_pe, position_ids) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin) + q_pe, k_pe = 
apply_rotary_pos_emb_v3(q_pe, k_pe, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index ddfcda9..03a1488 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -519,7 +519,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE -from ktransformers.models.modeling_deepseekv3 import DeepseekV3MoE +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock @@ -734,9 +734,10 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): identity = hidden_states orig_shape = hidden_states.shape sequence_length = orig_shape[1] - topk_idx, topk_weight= self.gate(hidden_states) + topk_idx, topk_weight, router_logits= self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + # only for generate phase if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0]) if self.config.n_shared_experts is not None: @@ -744,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0) y += y_ y.resize_(*orig_shape) - return y + return y, router_logits if self.config.n_shared_experts is not None: y_ = self.shared_experts(identity).squeeze(0) @@ -767,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE): ) if self.config.n_shared_experts is not None: y += y_ - return y + return y, router_logits @torch.no_grad() def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor: diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index 91a3872..dcf45cb 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -16,7 +16,7 @@ from cpuinfer_ext.moe import MOEConfig, MOE import ctypes from ktransformers.operators.base_operator import BaseInjectedModule from ktransformers.util.custom_gguf import GGUFLoader -from ktransformers.models.modeling_deepseekv3 import MoEGate +from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter from ktransformers.util.utils import InferenceState from ktransformers.server.config.config import Config from transformers.activations import ACT2FN @@ -118,11 +118,10 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): else: raise ValueError("Invalid weight type") self.orig_module.weight = self.orig_module.weight.to(device) - if self.topk_method == "noaux_tc": - self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) + self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device) def unload(self): if self.weight is not None: self.weight = None - if self.topk_method == "noaux_tc": + if self.e_score_correction_bias is not None: self.e_score_correction_bias = None diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml index 3fd86d9..7135933 100644 --- 
a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml @@ -47,7 +47,7 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$" - class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: @@ -55,7 +55,7 @@ prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp$" - class: ktransformers.models.modeling_deepseekv3.DeepseekV3MoE + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE replace: class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function kwargs: @@ -64,7 +64,7 @@ - match: name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseekv3.MoEGate + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter replace: class: ktransformers.operators.gate.KMoEGate kwargs: @@ -72,7 +72,7 @@ prefill_device: "cuda:0" - match: name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$" - class: ktransformers.models.modeling_deepseekv3.MoEGate + class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter replace: class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function kwargs: diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py index 27b788f..cf5c0ef 100644 --- a/ktransformers/server/config/config.py +++ b/ktransformers/server/config/config.py @@ -102,7 +102,7 @@ class Config(metaclass=Singleton): self.total_context = self.model.get("total_context", 2**18) self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1) self.max_chunk_size = self.model.get("max_chunk_size", 2048) - self.max_new_tokens = self.model.get("max_new_tokens", 500) + self.max_new_tokens = self.model.get("max_new_tokens", 2000) self.json_mode = self.model.get("json_mode", False) self.healing = self.model.get("healing", False) self.ban_strings: Optional[list] = self.model.get("ban_strings", None) diff --git a/ktransformers/util/modeling_rope_utils.py b/ktransformers/util/modeling_rope_utils.py index 2598a52..4fec4bc 100644 --- a/ktransformers/util/modeling_rope_utils.py +++ b/ktransformers/util/modeling_rope_utils.py @@ -58,7 +58,8 @@ def _compute_default_rope_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) attention_factor = 1.0 # Unused in this type of RoPE @@ -143,14 +144,15 @@ def _compute_dynamic_ntk_parameters( elif config is not None: base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings factor = config.rope_scaling["factor"] attention_factor = 1.0 # Unused in this type of RoPE # seq_len: default to max_position_embeddings, e.g. 
at init time - seq_len = seq_len if seq_len is not None else max_position_embeddings + seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) @@ -185,15 +187,33 @@ def _compute_yarn_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = config.qk_rope_head_dim - - max_position_embeddings = config.max_position_embeddings + head_dim = getattr(config, "qk_rope_head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) factor = config.rope_scaling["factor"] + attention_factor = config.rope_scaling.get("attention_factor") + mscale = config.rope_scaling.get("mscale") + mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + + # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if "original_max_position_embeddings" in config.rope_scaling: + original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] + factor = config.max_position_embeddings / original_max_position_embeddings + else: + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 # Sets the attention factor as suggested in the paper - attention_factor = config.rope_scaling.get("attention_factor") if attention_factor is None: - attention_factor = 0.1 * math.log(factor) + 1.0 + if mscale and mscale_all_dim: + attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) @@ -211,7 +231,7 @@ def _compute_yarn_parameters( high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) - def linear_ramp_mask(min, max, dim): + def linear_ramp_factor(min, max, dim): if min == max: max += 0.001 # Prevent singularity @@ -219,16 +239,20 @@ def _compute_yarn_parameters( ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. 
pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device) - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) return inv_freq, attention_factor @@ -244,7 +268,7 @@ def _compute_longrope_parameters( device (`torch.device`): The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. + The current sequence length. rope_kwargs (`Dict`, *optional*): BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. Returns: @@ -261,7 +285,8 @@ def _compute_longrope_parameters( base = config.rope_theta partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) long_factor = config.rope_scaling["long_factor"] short_factor = config.rope_scaling["short_factor"] factor = config.rope_scaling.get("factor") @@ -271,22 +296,20 @@ def _compute_longrope_parameters( # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. 
if hasattr(config, "original_max_position_embeddings"): - max_position_embeddings = config.original_max_position_embeddings - expanded_max_position_embeddings = config.max_position_embeddings - factor = expanded_max_position_embeddings / max_position_embeddings + original_max_position_embeddings = config.original_max_position_embeddings + factor = config.max_position_embeddings / config.original_max_position_embeddings else: - max_position_embeddings = config.max_position_embeddings - expanded_max_position_embeddings = max_position_embeddings * factor + original_max_position_embeddings = config.max_position_embeddings # Sets the attention factor as suggested in the paper if attention_factor is None: if factor <= 1.0: attention_factor = 1.0 else: - attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)) + attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) # Compute the inverse frequencies -- scaled based on the target sequence length - if expanded_max_position_embeddings > max_position_embeddings: + if seq_len and seq_len > original_max_position_embeddings: ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) else: ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) @@ -325,19 +348,18 @@ def _compute_llama3_parameters( low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in inv_freq: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - new_freqs.append((1 - smooth) * freq / factor + smooth * freq) - inv_freq = torch.tensor(new_freqs, dtype=inv_freq.dtype, device=inv_freq.device) - return inv_freq, attention_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters @@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = { } -def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None): +def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: Optional[set] = None, + ignore_keys: Optional[set] = None, +): """Compare the received keys in `config.rope_scaling` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's gracefully handle it - if "rope_type" not in received_keys and "type" in received_keys: + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: received_keys -= {"type"} - 
@@ -353,12 +375,22 @@ ROPE_INIT_FUNCTIONS = {
 }


-def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
-    # BC: "rope_type" was originally "type" -- let's gracefully handle it
-    if "rope_type" not in received_keys and "type" in received_keys:
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
         received_keys -= {"type"}
-        received_keys.add("rope_type")
+        required_keys.add("rope_type")
+
+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys

     missing_keys = required_keys - received_keys
     if missing_keys:
@@ -372,47 +404,54 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
         logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")


-def _validate_default_rope_parameters(config: PretrainedConfig):
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


-def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
         logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


-def _validate_yarn_parameters(config: PretrainedConfig):
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+    }
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -437,17 +476,18 @@ def _validate_yarn_parameters(config: PretrainedConfig):
         )


-def _validate_longrope_parameters(config: PretrainedConfig):
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "short_factor", "long_factor"}
     # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
     optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys, optional_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    dim = int((config.hidden_size // config.num_attention_heads) * partial_rotary_factor)
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)

     short_factor = rope_scaling.get("short_factor")
     if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
@@ -479,18 +519,19 @@ def _validate_longrope_parameters(config: PretrainedConfig):
             logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

         attention_factor = rope_scaling.get("attention_factor")
-        if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
-            logger.warning(
-                f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
-            )
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )


-def _validate_llama3_parameters(config: PretrainedConfig):
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
     received_keys = set(rope_scaling.keys())
-    _check_received_keys(rope_type, received_keys, required_keys)
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

     factor = rope_scaling["factor"]
     if factor is None or not isinstance(factor, float) or factor < 1.0:
@@ -502,7 +543,7 @@ def _validate_llama3_parameters(config: PretrainedConfig):
         logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
     if high_freq_factor is None or not isinstance(high_freq_factor, float):
         logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
-    if high_freq_factor < low_freq_factor:
+    if high_freq_factor <= low_freq_factor:
         logger.warning(
             "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
             f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
         )
@@ -532,7 +573,7 @@ ROPE_VALIDATION_FUNCTIONS = {
 }


-def rope_config_validation(config: PretrainedConfig):
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
     """
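All of the validators above now forward an optional `ignore_keys` set to `_check_received_keys`, which subtracts those keys from the received set before checking for missing or unrecognized entries. A simplified standalone sketch of that filtering (the function and the example keys are illustrative, not the library internals verbatim):

```python
# Simplified sketch of ignore_keys filtering; not the patched helper itself.
from typing import Optional

def check_keys(rope_type: str, received: set, required: set,
               optional: Optional[set] = None, ignore_keys: Optional[set] = None):
    warnings = []
    if ignore_keys is not None:
        received = received - ignore_keys  # model-specific keys are silently dropped
    missing = required - received
    if missing:
        warnings.append(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing}")
    unused = received - (required | (optional or set()))
    if unused:
        warnings.append(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused}")
    return warnings

# Without ignore_keys the extra key is reported as unrecognized; with it, no warning.
print(check_keys("yarn", {"rope_type", "factor", "my_custom_key"}, {"rope_type", "factor"}))
print(check_keys("yarn", {"rope_type", "factor", "my_custom_key"}, {"rope_type", "factor"},
                 ignore_keys={"my_custom_key"}))
```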
@@ -544,8 +585,8 @@ def rope_config_validation(config: PretrainedConfig):
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
     validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
     if validation_fn is not None:
-        validation_fn(config)
+        validation_fn(config, ignore_keys=ignore_keys)
     else:
         logger.warning(
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
-        )
+        )
\ No newline at end of file
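For reference, a hypothetical usage sketch of the extended `rope_config_validation(config, ignore_keys=...)` entry point from a custom config class. `MyYarnConfig` and its fields are invented for illustration, and the import path assumes the patched module above is importable as `transformers.modeling_rope_utils` (adjust it to wherever this file actually lives in your install):

```python
# Hypothetical config class -- only rope_config_validation and its new ignore_keys
# argument come from the diff above.
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation

class MyYarnConfig(PretrainedConfig):
    model_type = "my_yarn_model"

    def __init__(self, max_position_embeddings=4096, rope_theta=10000.0,
                 rope_scaling=None, **kwargs):
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        if self.rope_scaling is not None:
            # "mscale"/"mscale_all_dim" are now accepted yarn keys; anything truly
            # model-specific can be excluded from the key check via ignore_keys.
            rope_config_validation(self, ignore_keys={"my_private_key"})
        super().__init__(**kwargs)

cfg = MyYarnConfig(rope_scaling={"rope_type": "yarn", "factor": 4.0, "mscale": 1.0,
                                 "my_private_key": 123})
```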