From 3f9bbf1181e741faf29f61fb587d3d07b07beecf Mon Sep 17 00:00:00 2001 From: djw <1913953267@qq.com> Date: Mon, 28 Apr 2025 08:44:47 +0000 Subject: [PATCH] support qwen3, dont speak human language --- .../models/configuration_qwen2_moe.py | 177 ++ .../models/configuration_qwen3_moe.py | 233 +++ ktransformers/models/custom_cache.py | 56 + .../models/custom_modeling_qwen2_moe.py | 133 ++ .../models/custom_modeling_qwen3_moe.py | 133 ++ ktransformers/models/modeling_qwen3_moe.py | 1472 +++++++++++++++++ ktransformers/operators/RoPE.py | 28 +- ktransformers/operators/attention.py | 89 - .../operators/balance_serve_attention.py | 287 ++++ ktransformers/operators/experts.py | 227 +++ .../flashinfer_batch_prefill_wrapper.py | 324 ++++ ktransformers/operators/gate.py | 69 + ktransformers/operators/layernorm.py | 89 +- ktransformers/operators/mlp.py | 18 +- .../DeepSeek-V3-Chat-serve.yaml | 2 +- .../Moonlight-16B-A3B-serve.yaml | 2 +- .../optimize/optimize_rules/Qwen2-serve.yaml | 95 ++ .../optimize_rules/Qwen3Moe-serve.yaml | 95 ++ ktransformers/server/args.py | 4 +- .../backend/interfaces/balance_serve.py | 43 +- .../balance_serve/inference/forward_batch.py | 2 +- .../balance_serve/inference/model_runner.py | 267 ++- .../server/balance_serve/sched_rpc.py | 9 +- .../server/balance_serve/settings.py | 108 +- ktransformers/server/config/config.py | 1 + ktransformers/server/requirements.txt | 2 +- ktransformers/util/custom_gguf.py | 15 +- pyproject.toml | 2 +- requirements-local_chat.txt | 2 +- third_party/custom_flashinfer | 2 +- 30 files changed, 3696 insertions(+), 290 deletions(-) create mode 100644 ktransformers/models/configuration_qwen2_moe.py create mode 100644 ktransformers/models/configuration_qwen3_moe.py create mode 100644 ktransformers/models/custom_modeling_qwen2_moe.py create mode 100644 ktransformers/models/custom_modeling_qwen3_moe.py create mode 100644 ktransformers/models/modeling_qwen3_moe.py create mode 100644 ktransformers/operators/balance_serve_attention.py create mode 100644 ktransformers/operators/flashinfer_batch_prefill_wrapper.py create mode 100644 ktransformers/optimize/optimize_rules/Qwen2-serve.yaml create mode 100644 ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml diff --git a/ktransformers/models/configuration_qwen2_moe.py b/ktransformers/models/configuration_qwen2_moe.py new file mode 100644 index 0000000..345af76 --- /dev/null +++ b/ktransformers/models/configuration_qwen2_moe.py @@ -0,0 +1,177 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2MoE model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen2MoeConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. 
It is used to instantiate a
+    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of
+    [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen2MoeModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 16):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `16`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 1408):
+            Intermediate size of the routed expert.
+        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
+            Intermediate size of the shared expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 4):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 60):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+
+    ```python
+    >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
+
+    >>> # Initializing a Qwen2MoE style configuration
+    >>> configuration = Qwen2MoeConfig()
+
+    >>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration
+    >>> model = Qwen2MoeModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen2_moe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=2048,
+        intermediate_size=5632,
+        num_hidden_layers=24,
+        num_attention_heads=16,
+        num_key_value_heads=16,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        decoder_sparse_step=1,
+        moe_intermediate_size=1408,
+        shared_expert_intermediate_size=5632,
+        num_experts_per_tok=4,
+        num_experts=60,
+        norm_topk_prob=False,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        mlp_only_layers=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        # MoE arguments
+        self.decoder_sparse_step = decoder_sparse_step
+        self.moe_intermediate_size = moe_intermediate_size
+        self.shared_expert_intermediate_size = shared_expert_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
\ No newline at end of file
diff --git a/ktransformers/models/configuration_qwen3_moe.py b/ktransformers/models/configuration_qwen3_moe.py
new file mode 100644
index 0000000..ebbf87f
--- /dev/null
+++ b/ktransformers/models/configuration_qwen3_moe.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen3MoE model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Qwen3MoeConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
+    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of [Qwen/Qwen3-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Qwen3MoeModel`]
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 6144):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 4):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `4`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        use_sliding_window (`bool`, *optional*, defaults to `False`):
+            Whether to use sliding window attention.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+        max_window_layers (`int`, *optional*, defaults to 28):
+            The number of layers that use SWA (Sliding Window Attention).
The bottom layers use SWA while the top use full attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 768):
+            Intermediate size of the routed expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 8):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 128):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
+            Indicates which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.
+            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+    ```python
+    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig
+    >>> # Initializing a Qwen3MoE style configuration
+    >>> configuration = Qwen3MoeConfig()
+    >>> # Initializing a model from the Qwen3-15B-A2B style configuration
+    >>> model = Qwen3MoeModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "qwen3_moe"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    # Default tensor parallel plan for base model `Qwen3Moe`
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=2048,
+        intermediate_size=6144,
+        num_hidden_layers=24,
+        num_attention_heads=32,
+        num_key_value_heads=4,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        use_sliding_window=False,
+        sliding_window=4096,
+        max_window_layers=28,
+        attention_dropout=0.0,
+        decoder_sparse_step=1,
+        moe_intermediate_size=768,
+        num_experts_per_tok=8,
+        num_experts=128,
+        norm_topk_prob=False,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        mlp_only_layers=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.use_sliding_window = use_sliding_window
+        self.sliding_window = sliding_window if use_sliding_window else None
+        self.max_window_layers = max_window_layers
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        # Validate the correctness of rotary position embeddings parameters
+        # BC: if there is a 'type' field, move it to 'rope_type'.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        rope_config_validation(self)
+
+        # MoE arguments
+        self.decoder_sparse_step = decoder_sparse_step
+        self.moe_intermediate_size = moe_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
+
+        super().__init__(
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["Qwen3MoeConfig"]
\ No newline at end of file
diff --git a/ktransformers/models/custom_cache.py b/ktransformers/models/custom_cache.py
index 1030b61..e4a271e 100644
--- a/ktransformers/models/custom_cache.py
+++ b/ktransformers/models/custom_cache.py
@@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):
         return page_idx, page_offset
 
 
+class KGQACache(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        page_size: int = 256,
+        dtype=torch.bfloat16,
+        device=torch.device("cuda:0"),
+    ):
+        super().__init__()
+        self.config = config
+        self.dtype = dtype
+        self.device = device
+        self.page_size = page_size
+        self.k_caches = []
+        self.v_caches = []
+
+    def load(self, inference_context: sched_ext.InferenceContext):
+        for i in range(self.config.num_hidden_layers):
+            self.k_caches.append(
+                inference_context.k_cache[0][i]
+            )
+            self.v_caches.append(
+                inference_context.v_cache[0][i]
+            )
+
+        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]
+
+    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
+        # Translate absolute cache positions into (page index, in-page offset) pairs for
+        # the paged KV layout described by the flashinfer-style indptr/indices arrays.
+        page_offset = cache_position % self.page_size
+        page_idx_local = cache_position // self.page_size
+        # Map each token back to the request it belongs to via q_indptr.
+        query_ids = torch.zeros_like(cache_position)
+        for i in range(len(q_indptr) - 1):
+            start_idx = q_indptr[i]
+            end_idx = q_indptr[i + 1]
+            query_ids[start_idx:end_idx] = i
+        # Resolve each token's request-local page number to a global page id via kv_indices.
+        page_idx = torch.zeros_like(page_idx_local)
+        for i in range(bsz_tensors[0]):
+            query_id = query_ids[i]
+            local_block = page_idx_local[i]
+            start_block = kv_indptr[query_id]
+            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
+                page_idx[i] = kv_indices[start_block + local_block]
+
+        return page_idx, page_offset
+
+    def get_k_cache(self, layer_idx):
+        return self.k_caches[layer_idx]
+
+    def get_v_cache(self, layer_idx):
+        return self.v_caches[layer_idx]
\ No newline at end of file
diff --git a/ktransformers/models/custom_modeling_qwen2_moe.py b/ktransformers/models/custom_modeling_qwen2_moe.py
new file mode 100644
index 0000000..5740c14
--- /dev/null
+++ b/ktransformers/models/custom_modeling_qwen2_moe.py
@@ -0,0 +1,133 @@
+"""
+Date: 2024-11-06 10:05:11
+LastEditors: djw
+LastEditTime: 2024-11-13 07:50:51
+"""
+
+import math
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from typing import List, Optional, Tuple, Union
+import torch.utils.checkpoint
+from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
+from ktransformers.models.custom_cache import KGQACache
+from ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel
+from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig
+from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
+import flashinfer
+
+torch.set_grad_enabled(False)
+torch.set_default_dtype(torch.bfloat16)
+
+class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
+
+    cache: KGQACache
+    use_cuda_graph = False
+    def __init__(
+        self,
+        config: Qwen2MoeConfig,
+        cache,
+    ):
+        super().__init__(config)
+        self.model = Qwen2MoeModel(config)
+        self.config = config
+        self.cache = cache
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # One flashInferAttn wrapper slot per CUDA graph; filled lazily by init_wrapper(cuda_graph_idx).
+        self.attn = [None] * 10
+
+    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
+        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
+
+    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
+        features = []
+        for i in range(batch.batch_size):
+            tokens = batch.minibatch.tokens.contiguous()
+            feature = (
+                self.model.embed_tokens(tokens.to(torch.device('cpu')))
+                .to(torch.bfloat16)
+                .to(device=device)
+            )
+            features.append(feature)
+
+        return features
+
+    def forward(
+        self,
+        batch: ForwardBatchInput | None = None,
+        features: List[torch.Tensor] | None = None,
+        bsz_tensors: torch.Tensor | None = None,
+        num_tokens_tensors: torch.Tensor | None = None,
+        page_idx: torch.Tensor | None = None,
+        page_offset: torch.Tensor | None = None,
+        cuda_graph_idx: int | None = 0
+    ) -> ForwardBatchOutput:
+        current_stream = torch.cuda.current_stream()
+
+        forward_batch_output = ForwardBatchOutput()
+
+        hidden_states = features[0]
+        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])
+
+        with torch.cuda.stream(current_stream):
+            residual = torch.zeros_like(hidden_states)
+            for i, decode_layer in enumerate(self.model.layers):
+                if self.model.transfer_map is not None and i in self.model.transfer_map:
+                    prev_stream = torch.cuda.current_stream()
+                    cur_device = self.model.transfer_map[i]
+                    if cur_device not in self.model.stream_device_map:
+                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
+                    torch.cuda.set_device(cur_device)
+                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
+                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
+                    hidden_states = hidden_states.to(
+                        self.model.transfer_map[i], non_blocking=True
+                    )
+
+                    batch.minibatch.position_ids = (
+                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
+                        if batch.minibatch.position_ids is not None
+                        else None
+                    )
+                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
+                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
+                    position_ids=batch.minibatch.position_ids,
+                    wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
+                    page_idx=page_idx,
+                    page_offset=page_offset
+                )
+
+                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
+                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
+                hidden_states = hidden_states.squeeze(0)
+        with torch.cuda.stream(current_stream):
+            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
+            forward_batch_output.logits.append(local_logit)
+
+        return forward_batch_output
+
+    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
+        num_q_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        page_size: int,
+        causal: bool,
+        q_data_type: torch.dtype,
+        kv_data_type: torch.dtype,
+        cuda_graph_idx: int = 0
+    ):
+        minibatch = batch.minibatch
+        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
+            minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
\ No newline at end of file
diff --git a/ktransformers/models/custom_modeling_qwen3_moe.py b/ktransformers/models/custom_modeling_qwen3_moe.py
new file mode 100644
index 0000000..1cb8c46
--- /dev/null
+++ b/ktransformers/models/custom_modeling_qwen3_moe.py
@@ -0,0 +1,133 @@
+"""
+Date: 2024-11-06 10:05:11
+LastEditors: djw
+LastEditTime: 2024-11-13 07:50:51
+"""
+
+import math
+from dataclasses import dataclass
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from typing import List, Optional, Tuple, Union
+import torch.utils.checkpoint
+from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
+from ktransformers.models.custom_cache import KGQACache
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
+import flashinfer
+
+torch.set_grad_enabled(False)
+torch.set_default_dtype(torch.bfloat16)
+
+class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):
+
+    cache: KGQACache
+    use_cuda_graph = False
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        cache = None,
+    ):
+        super().__init__(config)
+        self.model = Qwen3MoeModel(config)
+        self.config = config
+        self.cache = cache
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        # One flashInferAttn wrapper slot per CUDA graph; filled lazily by init_wrapper(cuda_graph_idx).
+        self.attn = [None] * 10
+
+    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
+        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
+
+    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
+        features = []
+        for i in range(batch.batch_size):
+            tokens = batch.minibatch.tokens.contiguous()
+            feature = (
+                self.model.embed_tokens(tokens.to(torch.device('cpu')))
+                .to(torch.bfloat16)
+                .to(device=device)
+            )
+            features.append(feature)
+
+        return features
+
+    def forward(
+        self,
+        batch: ForwardBatchInput | None = None,
+        features: List[torch.Tensor] | None = None,
+        bsz_tensors: torch.Tensor | None = None,
+        num_tokens_tensors: torch.Tensor | None = None,
+        page_idx: torch.Tensor | None = None,
+        page_offset: torch.Tensor | None = None,
+        cuda_graph_idx: int | None = 0
+    ) -> ForwardBatchOutput:
+        current_stream = torch.cuda.current_stream()
+
+        forward_batch_output = ForwardBatchOutput()
+
+        hidden_states = features[0]
+        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])
+
+        with torch.cuda.stream(current_stream):
+            residual = torch.zeros_like(hidden_states)
+            for i, decode_layer in enumerate(self.model.layers):
+                if self.model.transfer_map is not None and i in self.model.transfer_map:
+                    prev_stream = torch.cuda.current_stream()
+                    cur_device = self.model.transfer_map[i]
+                    if cur_device not in self.model.stream_device_map:
+                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
+                    torch.cuda.set_device(cur_device)
+                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
+                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
+                    hidden_states = hidden_states.to(
+                        self.model.transfer_map[i], non_blocking=True
+                    )
+
+                    batch.minibatch.position_ids = (
+                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
+                        if batch.minibatch.position_ids is not None
+                        else None
+                    )
+                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
+                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
+                    position_ids=batch.minibatch.position_ids,
+                    wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
+                    page_idx=page_idx,
+                    page_offset=page_offset
+                )
+
+                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
+                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
+                hidden_states = hidden_states.squeeze(0)
+        with torch.cuda.stream(current_stream):
+            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
+            forward_batch_output.logits.append(local_logit)
+
+        return forward_batch_output
+
+    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
+        num_q_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        page_size: int,
+        causal: bool,
+        q_data_type: torch.dtype,
+        kv_data_type: torch.dtype,
+        cuda_graph_idx: int = 0
+    ):
+        minibatch = batch.minibatch
+        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
+            minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
\ No newline at end of file
diff --git a/ktransformers/models/modeling_qwen3_moe.py b/ktransformers/models/modeling_qwen3_moe.py
new file mode 100644
index 0000000..175f88c
--- /dev/null
+++ b/ktransformers/models/modeling_qwen3_moe.py
@@ -0,0 +1,1472 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_qwen3_moe.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +# from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.modeling_utils import PreTrainedModel +# from transformers.processing_utils import Unpack +from transformers.utils import ( + # LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from .configuration_qwen3_moe import Qwen3MoeConfig + +from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRotaryEmbedding + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Qwen/Qwen3-MoE-15B-A2B" +_CONFIG_FOR_DOC = "Qwen3MoeConfig" + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
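+
+    A minimal shape sketch (the sizes are illustrative assumptions for this example —
+    `head_dim=64`, 8 query heads, 2 KV heads — not values taken from any model):
+
+        >>> q = torch.randn(1, 8, 16, 64)   # [batch, num_heads, seq_len, head_dim]
+        >>> k = torch.randn(1, 2, 16, 64)   # [batch, num_kv_heads, seq_len, head_dim]
+        >>> cos = torch.randn(1, 16, 64)    # [batch, seq_len, head_dim]
+        >>> sin = torch.randn(1, 16, 64)
+        >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
+        >>> q_embed.shape, k_embed.shape
+        (torch.Size([1, 8, 16, 64]), torch.Size([1, 2, 16, 64]))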
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3MoeAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3MoeConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + + self.rotary_emb = Qwen2MoeRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.sliding_window = config.sliding_window + if not ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + self.sliding_window = None + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + # **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + # logger.warning_once( + # "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + # 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ # ) + # else: + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + # **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3MoeMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.experts = nn.ModuleList( + [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+
+class Qwen3MoeRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Qwen3MoeRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.hidden_size = hidden_size
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Qwen3MoeAttention(config, layer_idx)
+
+        if (layer_idx not in config.mlp_only_layers) and (
+            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(config)
+        else:
+            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
+
+        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        # **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, sequence_length)` where padding elements are indicated by 0.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): + hidden_states, router_logits = hidden_states + else: + router_logits = None + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +def _compute_default_rope_parameters( + config: Optional[Qwen3MoeConfig] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + **rope_kwargs, +) -> Tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + rope_kwargs (`Dict`, *optional*): + BC compatibility with the previous RoPE class instantiation, will be removed in v4.45. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
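+    As a concrete sketch (the numbers are illustrative, not from any particular checkpoint):
+    with `dim = 128` and `base = 10000.0`, the returned tensor holds `dim // 2 = 64` inverse
+    frequencies `inv_freq[i] = base ** (-2 * i / dim)`, running from `1.0` at `i = 0` down to
+    `10000.0 ** (-126 / 128)` at `i = 63`.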
+ """ + if config is not None and len(rope_kwargs) > 0: + raise ValueError( + "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in " + f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}" + ) + if len(rope_kwargs) > 0: + base = rope_kwargs["base"] + dim = rope_kwargs["dim"] + elif config is not None: + base = config.rope_theta + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + dim = int(config.head_dim * partial_rotary_factor) + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + return inv_freq, attention_factor + +class Qwen3MoeRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen3MoeConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + self.scaling_factor = 1.0 + self.dim = config.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.base = config.rope_theta + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + + inv_freq, self.attention_scaling = _compute_default_rope_parameters(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + + with torch.autocast(device_type=device_type, enabled=False): + freqs = 
(inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN3_MOE_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen3MoeConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoePreTrainedModel(PreTrainedModel): + config_class = Qwen3MoeConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3MoeDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + _supports_attention_backend = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN3_MOE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.", + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeModel(Qwen3MoePreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3MoeDecoderLayer`] + + Args: + config: Qwen3MoeConfig + """ + + def __init__(self, config: Qwen3MoeConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + # **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
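+        # Illustrative note (comments only, not executed): assuming a DynamicCache that already holds a
+        # 2-token prompt, one newly decoded token, and a 2D `attention_mask` of length 3, the code below
+        # computes `past_seen_tokens == 2` and `sequence_length == 1`, and the 4D mask built further down
+        # has shape (batch_size, 1, 1, 3) with `torch.finfo(dtype).min` at the masked positions.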
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen3MoeConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ config (`Qwen3MoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... +class KwargsForCausalLM(): ... + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
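+
+    Example (a minimal illustrative sketch with random logits, not taken from a real model):
+
+    ```python
+    >>> import torch
+    >>> gate_logits = (torch.randn(6, 4), torch.randn(6, 4))  # 2 layers, 6 tokens each, 4 experts
+    >>> loss = load_balancing_loss_func(gate_logits, num_experts=4, top_k=2)
+    >>> loss.shape  # scalar auxiliary loss
+    torch.Size([])
+    ```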
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3MoeModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + # **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + # **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +@add_start_docstrings( + """ + The Qwen3Moe Model transformer with a sequence classification head on top (linear layer). + + [`Qwen3MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3MoeModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen3Moe Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen3MoeModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen3Moe Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN3_MOE_START_DOCSTRING, +) +class Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen3MoeModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", +] \ No newline at end of file diff --git a/ktransformers/operators/RoPE.py b/ktransformers/operators/RoPE.py index 5233fc7..75d1a6e 100644 --- a/ktransformers/operators/RoPE.py +++ b/ktransformers/operators/RoPE.py @@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule): self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) # self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings \ No newline at end of file + self.max_seq_len_cached = max_position_embeddings + +class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + # device: str = "cuda", + generate_device: str = "cuda", + prefill_device: str = "cuda", + **kwargs, + ): + BaseInjectedModule.__init__( + self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs + ) + self.orig_module.__init__( + config, + ) + self.generate_device = generate_device + self.prefill_device = prefill_device + + def load(self): + self.orig_module.__init__( + self.orig_module.config + ) \ No newline at end of file diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py index 2d242f6..0f5f9ae 100644 --- a/ktransformers/operators/attention.py +++ b/ktransformers/operators/attention.py @@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule): attn_weights = None return attn_output, attn_weights, past_key_value - -class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention): - def __init__(self, - key: str, - gguf_loader : GGUFLoader, - config: PretrainedConfig, - orig_module: nn.Module, - prefill_device: str = "cuda", - generate_device: str = "cuda", - chunck_size: int = 1000, - **kwargs): - BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) - self.orig_module.__init__(orig_module.config, - orig_module.layer_idx) - self.chunck_size = chunck_size # TODO, generate chunck_size automatically. 
- - - def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: - if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): - kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) - q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) - out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) - self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, - bias=False, dtype=q_absorb.dtype, device=q_absorb.device) - self.q_absorb.weight.data = q_absorb - self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, - bias=False, dtype=out_absorb.dtype, device=out_absorb.device) - self.out_absorb.weight.data = out_absorb - #del self.orig_module.kv_b_proj - q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) - out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) - return q_absorb, out_absorb - - - - def forward(self, - hidden_states: torch.Tensor, - kv_cache: KDeepSeekV3Cache, - position_ids: torch.Tensor, - wrapper: BatchMLAPagedAttentionWrapper, - num_tokens_tensors: torch.Tensor, - page_idx: torch.Tensor, - page_offset: torch.Tensor, - ): - q_len, _ = hidden_states.size() - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states, num_tokens_tensors) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors) - q = q.view(q_len, self.num_heads, self.q_head_dim) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - compressed_kv = compressed_kv.contiguous() - compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors) - k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) - compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) - - cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0)) - q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) - q_pe = q_pe.squeeze(0) - if kv_cache is not None: - - # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices) - cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models - compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs) - compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) - k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) - - q_absorb, out_absorb = self.get_absorbed() - q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below - q_nope = torch.matmul(q_nope, q_absorb) # batched MM - q_nope = q_nope.transpose(0, 1) - # q_nope.squeeze_(1) - # q_pe.squeeze_(1) - - attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank) - attn_output = attn_output.transpose(0, 1) - attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim] - attn_output = attn_output.transpose(0, 1) - attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) - attn_output = 
self.o_proj(attn_output, num_tokens_tensors) - return attn_output diff --git a/ktransformers/operators/balance_serve_attention.py b/ktransformers/operators/balance_serve_attention.py new file mode 100644 index 0000000..4a24fc9 --- /dev/null +++ b/ktransformers/operators/balance_serve_attention.py @@ -0,0 +1,287 @@ +''' +Description : +Author : Boxin Zhang +Version : 0.2.5 +Copyright (c) 2024 by KVCache.AI, All Rights Reserved. +''' +import torch +from torch import nn +from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb +from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention +from typing import Optional, Tuple +from ktransformers.operators.base_operator import BaseInjectedModule +from ktransformers.util.custom_gguf import GGUFLoader +import logging +from transformers.configuration_utils import PretrainedConfig +from flashinfer import BatchMLAPagedAttentionWrapper +from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn +from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache +logger = logging.getLogger("attention") + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. 
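+
+    # A note on the "absorbed" MLA formulation used below (a sketch of the idea, no extra code path):
+    # rather than materializing full per-head keys/values from the compressed cache, kv_b_proj is split
+    # into q_absorb (folded into q_nope before attention) and out_absorb (folded into the output after
+    # attention), so attention itself runs in the kv_lora_rank latent space and the paged cache only
+    # stores kv_lora_rank + qk_rope_head_dim values per token.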
+ + + def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]: + if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')): + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank) + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank) + self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, + bias=False, dtype=q_absorb.dtype, device=q_absorb.device) + self.q_absorb.weight.data = q_absorb + self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, + bias=False, dtype=out_absorb.dtype, device=out_absorb.device) + self.out_absorb.weight.data = out_absorb + #del self.orig_module.kv_b_proj + q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank) + out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank) + return q_absorb, out_absorb + + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KDeepSeekV3Cache, + position_ids: torch.Tensor, + wrapper: BatchMLAPagedAttentionWrapper, + num_tokens_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + ): + q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states, num_tokens_tensors) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors) + q = q.view(q_len, self.num_heads, self.q_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + compressed_kv = compressed_kv.contiguous() + compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors) + k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim) + compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank) + + cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0)) + q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2) + q_pe = q_pe.squeeze(0) + if kv_cache is not None: + + # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices) + cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models + compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs) + compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank) + k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim) + + q_absorb, out_absorb = self.get_absorbed() + q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below + q_nope = torch.matmul(q_nope, q_absorb) # batched MM + q_nope = q_nope.transpose(0, 1) + # q_nope.squeeze_(1) + # q_pe.squeeze_(1) + + attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank) + attn_output = attn_output.transpose(0, 1) + attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim] + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim) + attn_output = 
self.o_proj(attn_output, num_tokens_tensors) + return attn_output + +class KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + + # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KGQACache, + position_ids: torch.Tensor, + wrapper: flashInferAttn, + bsz_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + ): + q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, bsz_tensors) + key_states = self.k_proj(hidden_states, bsz_tensors) + value_states = self.v_proj(hidden_states, bsz_tensors) + + + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0)) + query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) + + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view( + q_len, self.num_key_value_heads, self.head_dim + ) + value_states = value_states.view( + q_len, self.num_key_value_heads, self.head_dim + ) + + k_cache = kv_cache.get_k_cache(self.layer_idx) + v_cache = kv_cache.get_v_cache(self.layer_idx) + + + attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) + + + attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors) + + return attn_output + +class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + chunck_size: int = 1000, + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.config, + orig_module.layer_idx) + self.chunck_size = chunck_size # TODO, generate chunck_size automatically. + + + # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb + def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + + def forward(self, + hidden_states: torch.Tensor, + kv_cache: KGQACache, + position_ids: torch.Tensor, + wrapper: flashInferAttn, + bsz_tensors: torch.Tensor, + page_idx: torch.Tensor, + page_offset: torch.Tensor, + ): + q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors) + key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors) + value_states = self.v_proj(hidden_states, bsz_tensors) + + + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim) + + cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0)) + query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2) + + query_states = query_states.view(q_len, self.num_heads, self.head_dim) + key_states = key_states.view( + q_len, self.num_key_value_heads, self.head_dim + ) + value_states = value_states.view( + q_len, self.num_key_value_heads, self.head_dim + ) + + k_cache = kv_cache.get_k_cache(self.layer_idx) + v_cache = kv_cache.get_v_cache(self.layer_idx) + + + attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states) + + + attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors) + + return attn_output diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index f73c4c3..8e8f2b0 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase): from ktransformers.models.modeling_deepseek import DeepseekV2MoE from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock @@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase): self.unload() else: raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD") + +class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock): + def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): + + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + router_logits = self.gate(hidden_states, bsz_tensor) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + # only for generate phase + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, 
cuda_graph_idx) + y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + + y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) + + y += y_ + y.resize_(*orig_shape) + return y + + y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + y_ = ( + F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + ) + + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + +class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0): + + orig_shape = hidden_states.shape + sequence_length = orig_shape[1] + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + router_logits = self.gate(hidden_states, bsz_tensor) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = 
routing_weights.to(hidden_states.dtype) + + # only for generate phase + if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug + self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx) + # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + + y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0) + + # y += y_ + y.resize_(*orig_shape) + return y + + # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0) + # y_ = ( + # F.sigmoid(self.shared_expert_gate(hidden_states)) * y_ + # ) + + + if isinstance(self.experts, KExpertsBase): + y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device) + elif hidden_states.size(0) > 10: + # TODO may bugs here + y = ( + self.moe_infer(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + else: + # TODO may bugs here + y = ( + self.moe_infer_simple(hidden_states, selected_experts, routing_weights) + .view(*orig_shape) + .to(device=hidden_states.device) + ) + # y += y_ + return y + + @torch.no_grad() + def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor: + outs = torch.empty_like(x) + outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer_simple( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ) -> torch.Tensor: + """ + x: [num_tokens, hidden_size] + topk_ids, topk_weight: [num_tokens, num_selected_experts] + """ + outs = torch.zeros_like(x) + for token_idx in range(topk_ids.size(0)): + for expert_idx in range(topk_ids.size(1)): + expert = self.experts[topk_ids[token_idx, expert_idx]] + outs[token_idx] += ( + expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx] + ) + return outs + + @torch.no_grad() + # TODO may bugs here + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert.forward(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out \ No newline at end of file diff --git a/ktransformers/operators/flashinfer_batch_prefill_wrapper.py b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py new file mode 100644 index 0000000..e934654 --- /dev/null +++ b/ktransformers/operators/flashinfer_batch_prefill_wrapper.py @@ -0,0 
+1,324 @@ +import torch +import flashinfer +import gc +try: + from flash_attn import flash_attn_with_kvcache + print("found flash_attn") + +except ImportError: + print("flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.") + +from typing import Union, Optional + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +setup_seed(998244353) + +torch.set_grad_enabled(False) +torch.set_default_dtype(torch.bfloat16) +global_dtype=torch.bfloat16 +global_device=torch.device("cuda",0) +torch.cuda.set_device(0) +torch.backends.cudnn.enabled =True +torch.backends.cudnn.benchmark = True + +class flashInferAttn(): + + float_workspace_buffer = None + def __init__(self, + max_batch_token, + max_batch_size, + max_pages, + device = "cuda:0", + kv_layout: str = "NHD", + use_cuda_graph: bool = False, + ) -> None: + self.device = device + self.max_batch_token = max_batch_token + self.kv_layout = kv_layout + self.use_cuda_graph = use_cuda_graph + if flashInferAttn.float_workspace_buffer is None: + flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device) + self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device) + self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device) + self.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device) + self.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device) + self.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device) + + # TODO: custom mask + self.custom_mask_buf = None + self.qk_indptr_buf = None + self.warpper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + flashInferAttn.float_workspace_buffer, + self.kv_layout, + use_cuda_graph=self.use_cuda_graph, + qo_indptr_buf=self.qo_indptr_buf, + paged_kv_indptr_buf=self.paged_kv_indptr_buf, + paged_kv_indices_buf=self.paged_kv_indices_buf, + paged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf, + backend = "fa2", + ) + + def plan(self, + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + batch_size_tensor: torch.Tensor, + num_tokens_tensor: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + causal: bool = True, + pos_encoding_mode: str = "NONE", + q_data_type: Union[str, torch.dtype] = torch.bfloat16, + kv_data_type: Optional[Union[str, torch.dtype]] = None): + + self.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True) + self.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True) + self.page_size = page_size + self.warpper.plan( + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal = causal, + pos_encoding_mode = pos_encoding_mode, + q_data_type = q_data_type, + kv_data_type = kv_data_type + ) + + def calc_batch_indices(self, ragged_size = None): + if self.use_cuda_graph: + self.batch_indices, self.positions = flashinfer.get_batch_indices_positions( + self.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token) + else: + self.batch_indices, self.positions = flashinfer.get_batch_indices_positions( + 
self.warpper._qo_indptr_buf, flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)
+
+    def forward(self, q, k_cache, v_cache, k, v):
+        if self.use_cuda_graph:
+            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
+            return self.warpper.run(q, (k_cache, v_cache))
+        else:
+            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
+            return self.warpper.run(q, (k_cache, v_cache))
+
+
+def testCudaGraph():
+
+    # use the max batch to create the buffers
+    batch_decode = 8
+    prefill_chunk = 48
+    past_kv_0 = 4090
+    past_kv_1 = 4096
+    raged_size = prefill_chunk + batch_decode
+    num_key_value_heads = 8
+    head_dim = 128
+    num_attention_heads = 64
+    page_size = 256
+    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
+    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
+    attn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True)
+
+    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
+    num_tokens_tensor = torch.tensor([prefill_chunk + batch_decode], device=global_device, dtype=torch.int32)
+
+    k_caches = []
+    v_caches = []
+    ks = []
+    vs = []
+    qs = []
+    for layer_idx in range(3):
+        k_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
+        v_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
+        ks.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
+        vs.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
+        qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))
+
+    # warm up and capture with a small batch
+    past_kv_0 = 250
+    past_kv_1 = 256
+    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
+    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
+    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
+    q_indptr[0] = 0
+    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
+    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
+    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
+    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
+    kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
+    kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
+
+    print(q_indptr)
+    print(kv_indptr)
+    print(kv_indices)
+    print(kv_last_page_len)
+    attn.plan(q_indptr,
+              kv_indptr,
+              kv_indices,
+              kv_last_page_len,
+              batch_size_tensor,
+              num_tokens_tensor,  # total token count for this plan
+              num_attention_heads,
+              num_key_value_heads,
+              head_dim,
+              page_size,
+              causal = True,
+              pos_encoding_mode="NONE",
+              q_data_type=torch.bfloat16)
+
+    attn.calc_batch_indices(raged_size)
+    for layer_idx in range(3):
+        attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])
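+    # The eager passes above warm up kernel selection and workspace allocation
+    # outside of capture; the CUDA graph below records fixed launches, so later
+    # plan() calls may only rewrite the contents of the preallocated buffers
+    # before g.replay(), never their shapes or addresses.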
+    torch.cuda.synchronize()
+
+    outs = []
+    g = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(g):
+        for layer_idx in range(3):
+            outs.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]))
+    g.replay()
+
+    kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
+    kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
+    for layer_idx in range(3):
+        for i in range(batch_decode + 1):
+
+            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
+            o_ref_i = flash_attn_with_kvcache(
+                qi.unsqueeze(0),
+                k_caches[layer_idx],
+                v_caches[layer_idx],
+                causal=True,
+                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
+                cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
+            )
+            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
+            print(layer_idx, i)
+            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
+
+    # run another batch size using the captured CUDA graph
+    past_kv_0 = 4090
+    past_kv_1 = 4096
+    prefill_chunk = 24
+    batch_decode = 4
+    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
+    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
+    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
+    num_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)
+
+    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
+    q_indptr[0] = 0
+    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
+    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
+    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
+    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
+    kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)
+    kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)
+    attn.plan(q_indptr,
+              kv_indptr,
+              kv_indices,
+              kv_last_page_len,
+              batch_size_tensor,
+              num_tokens_tensor,
+              num_attention_heads,
+              num_key_value_heads,
+              head_dim,
+              page_size,
+              causal = True,
+              pos_encoding_mode="NONE",
+              q_data_type=torch.bfloat16)
+    attn.calc_batch_indices(raged_size)
+    g.replay()
+
+    kv_last_page_len[:1+batch_decode//2] = int(past_kv_0)
+    kv_last_page_len[1+batch_decode//2:] = int(past_kv_1)
+    for layer_idx in range(3):
+        for i in range(batch_decode + 1):
+
+            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
+            o_ref_i = flash_attn_with_kvcache(
+                qi.unsqueeze(0),
+                k_caches[layer_idx],
+                v_caches[layer_idx],
+                causal=True,
+                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
+                cache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)
+            )
+            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
+            print(layer_idx, i)
+            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
+
+
+def testAttentionFlashInfer():
+    batch_decode = 32
+    prefill_chunk = 64
+    past_kv_0 = 510
+    past_kv_1 = 512
+    raged_size = prefill_chunk + batch_decode
+    num_key_value_heads = 8
+    head_dim = 128
+    num_attention_heads = 64
+    cases = 1
+    page_size = 32
+    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
+    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
+    workspace_buffer = 
torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + qs = [] + kvs = [] + q_indptrs = [] + kv_indptrs = [] + kv_indicess = [] + kv_last_page_lens = [] + wrappers = [] + for case_id in range(cases): + kvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16)) + qs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16)) + q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device) + q_indptr[0] = 0 + q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32) + q_indptrs.append(q_indptr) + kv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq) + kv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)) + kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device) + kv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1) + kv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1) + kv_last_page_lens.append(kv_last_page_len) + wrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + qo_indptr_buf=q_indptrs[case_id], + paged_kv_indptr_buf=kv_indptrs[case_id], + paged_kv_indices_buf=kv_indicess[case_id], + paged_kv_last_page_len_buf=kv_last_page_lens[case_id], + )) + wrappers[case_id].plan( + q_indptrs[case_id], + kv_indptrs[case_id], + kv_indicess[case_id], + kv_last_page_lens[case_id], + num_attention_heads, + num_key_value_heads, + head_dim, + page_size, + causal = True, + pos_encoding_mode="ROPE_LLAMA", + q_data_type=torch.bfloat16 + ) + + def custom_forward(case_id): + out = wrappers[case_id].run(qs[case_id], kvs[case_id]) + + custom_forward(0) + +# testCudaGraph() +# pass \ No newline at end of file diff --git a/ktransformers/operators/gate.py b/ktransformers/operators/gate.py index 46b97c7..cf5799e 100644 --- a/ktransformers/operators/gate.py +++ b/ktransformers/operators/gate.py @@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase): self.e_score_correction_bias = None +class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase): + def __init__( + self, + key: str, + gguf_loader: GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module = None, + generate_device: str = "cuda", + generate_op: str| None = "KLinearMarlin", + prefill_device: str = "cuda", + prefill_op: str| None = "KLinearMarlin", + use_quant: bool = False, + **kwargs, + ): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs) + KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs) + self.generate_device = generate_device + self.prefill_device = prefill_device + self.generate_op = generate_op + self.prefill_op = prefill_op + self.is_windows = os.name == 'nt' + self.use_quant = use_quant + if not self.is_windows and use_quant: + self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device) + self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp", + gguf_loader, config, self.gate_linear, #orig_module + generate_device, generate_op, prefill_device, prefill_op) + else: + self.gate_linear = None + + def forward(self, hidden_states) -> torch.Tensor: + if self.is_windows: + return self.orig_module.forward(hidden_states) + + bsz, 
seq_len, h = hidden_states.shape
+        ### compute gating score
+        hidden_states = hidden_states.view(-1, h)
+        if self.use_quant:
+            logits = self.gate_linear.forward(hidden_states)
+        else:
+            logits = F.linear(
+                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
+            )
+
+        return grouped_topk(hidden_states, logits,
+                            self.top_k, self.norm_topk_prob,
+                            self.n_group, self.topk_group)
+
+    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
+        if device is None: device = self.device
+        if w is None: w = self.load_weights(device=device)
+
+        if isinstance(w, dict):
+            self.weight_type = w["weight_type"]
+            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
+            self.orig_module.weight = nn.Parameter(w["weight"])
+            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
+        else:
+            raise ValueError("Invalid weight type")
+        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
+        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
+        if not self.is_windows and self.use_quant:
+            self.gate_linear.load(self.orig_module.weight)
+
+    def unload(self):
+        if self.weight is not None:
+            self.weight = None
+        if self.e_score_correction_bias is not None:
+            self.e_score_correction_bias = None
\ No newline at end of file
diff --git a/ktransformers/operators/layernorm.py b/ktransformers/operators/layernorm.py
index 8e8cbc7..62c5cba 100644
--- a/ktransformers/operators/layernorm.py
+++ b/ktransformers/operators/layernorm.py
@@ -26,6 +26,8 @@ from transformers import PretrainedConfig
 import torch
 import torch.nn as nn
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
+from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
 from flashinfer.norm import (
@@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
\ No newline at end of file
+        return self.weight * hidden_states.to(input_dtype)
+
+
+class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):
+    def __init__(self,
+                 key: str,
+                 gguf_loader : GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
+        self.orig_module.__init__(config.hidden_size,
+                                  orig_module.variance_epsilon)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        batch_size_tensor: torch.Tensor = None,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        #return self.forward_native(x, residual)
+        if batch_size_tensor is None:
+            return self.forward_native(x)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+            #residual = x + residual
+            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
+            return x, residual
+        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device,
x.is_contiguous(), self.weight.data.is_contiguous()) + out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) + return out + + def forward_native( + self, hidden_states + ): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + self.orig_module.__init__(orig_module.hidden_size, + orig_module.variance_epsilon) + + def forward( + self, + x: torch.Tensor, + batch_size_tensor: torch.Tensor = None, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + #return self.forward_native(x, residual) + bsz, hidden_size = x.shape + x = x.view(-1, self.orig_module.hidden_size) + if batch_size_tensor is None: + return self.forward_native(x) + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + #residual = x + residual + #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon) + return x, residual + # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous()) + out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon) + out = out.view(bsz, hidden_size) + return out + + def forward_native( + self, hidden_states + ): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) diff --git a/ktransformers/operators/mlp.py b/ktransformers/operators/mlp.py index d8c502d..02648b1 100644 --- a/ktransformers/operators/mlp.py +++ b/ktransformers/operators/mlp.py @@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader from transformers import PretrainedConfig import torch.nn as nn from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP - - +from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): def __init__(self, key: str, @@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule): BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size) + def forward(self, x, bsz_tensor): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) + return down_proj +class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule): + def __init__(self, + key: str, + gguf_loader : GGUFLoader, + config: PretrainedConfig, + orig_module: nn.Module, + prefill_device: str = "cuda", + generate_device: str = "cuda", + **kwargs): + BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs) + 
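+        # Mirrors kDeepseekV3MLP above: re-initialize the wrapped Qwen2MoeMLP
+        # with its original config and intermediate size before injection.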
self.orig_module.__init__(orig_module.config, + orig_module.intermediate_size) def forward(self, x, bsz_tensor): down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor) return down_proj \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml index 622ad21..e1c61b8 100644 --- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml +++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml @@ -56,7 +56,7 @@ - match: name: "^model\\.layers\\..*\\.self_attn$" replace: - class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation + class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml index b51fee5..bc52e0e 100644 --- a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml +++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml @@ -50,7 +50,7 @@ - match: name: "^model\\.layers\\..*\\.self_attn$" replace: - class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation + class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation kwargs: generate_device: "cuda" prefill_device: "cuda" diff --git a/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml b/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml new file mode 100644 index 0000000..41b41a7 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml @@ -0,0 +1,95 @@ +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" + +# - match: +# name: "^model\\.layers\\..*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" +# generate_op: "VLinearMarlin" +# prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: 
"^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen2MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm + replace: + class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP + replace: + class: ktransformers.operators.mlp.KQwen2MoeMLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml b/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml new file mode 100644 index 0000000..63f67da --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml @@ -0,0 +1,95 @@ +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^lm_head$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "VLinearMarlin" + prefill_op: "KLinearTorch" + +# - match: +# name: "^model\\.layers\\..*$" # regular expression +# class: torch.nn.Linear # only match modules matching name and class simultaneously +# replace: +# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types +# kwargs: +# generate_device: "cuda" +# prefill_device: "cuda" +# generate_op: "VLinearMarlin" +# prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression + class: torch.nn.Linear # only match modules matching name and class simultaneously + replace: + class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + generate_op: "KLinearMarlin" + prefill_op: "KLinearTorch" +- match: + name: "^model\\.layers\\..*\\.mlp$" + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock + replace: + class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2 # mlp module with custom forward function + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + name: "^model\\.layers\\..*\\.mlp\\.experts$" + replace: + class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism + kwargs: + prefill_device: "cuda" + prefill_op: "KExpertsTorch" + 
generate_device: "cpu" + generate_op: "KExpertsCPU" + out_device: "cuda" + recursive: False # don't recursively inject submodules of this module +- match: + name: "^model\\.layers\\..*\\.self_attn$" + replace: + class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized MLA implementation + kwargs: + generate_device: "cuda" + prefill_device: "cuda" +- match: + name: "^model$" + replace: + class: "ktransformers.operators.models.KQwen2MoeModel" + kwargs: + per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill +- match: + name: "^model.embed_tokens" + replace: + class: "default" + kwargs: + generate_device: "cpu" + prefill_device: "cpu" + +- match: + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm + replace: + class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +- match: + class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP + replace: + class: ktransformers.operators.mlp.KQwen2MoeMLP + kwargs: + generate_device: "cuda" + prefill_device: "cuda" \ No newline at end of file diff --git a/ktransformers/server/args.py b/ktransformers/server/args.py index 0536ec9..95934e4 100644 --- a/ktransformers/server/args.py +++ b/ktransformers/server/args.py @@ -20,6 +20,7 @@ class ArgumentParser: parser.add_argument( "--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter" ) + parser.add_argument("--architectures", type=str, default=self.cfg.model_name) parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path) parser.add_argument("--optimize_config_path", default=None, type=str, required=False) parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer) @@ -93,6 +94,7 @@ class ArgumentParser: parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm) parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think) parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph) + # parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=False) # web config parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain) @@ -137,7 +139,7 @@ class ArgumentParser: self.cfg.server_port = args.port self.cfg.user_force_think = args.force_think - args.gpu_memory_size = args.cache_lens*2*576*61 + args.gpu_memory_size = 4*1024*1024*1024 # TODO: set this to the actual GPU memory size self.cfg.gpu_memory_size = args.gpu_memory_size free_ports = get_free_ports(3, [args.port]) args.sched_port = free_ports[0] diff --git a/ktransformers/server/backend/interfaces/balance_serve.py b/ktransformers/server/backend/interfaces/balance_serve.py index 0abe2e0..6301e97 100644 --- a/ktransformers/server/backend/interfaces/balance_serve.py +++ b/ktransformers/server/backend/interfaces/balance_serve.py @@ -1,5 +1,5 @@ from typing import Any, AsyncIterator, List, Optional, Set -from ktransformers.models.custom_cache import KDeepSeekV3Cache +from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache from transformers import ( AutoTokenizer, AutoConfig, @@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM from ktransformers.models.custom_modeling_deepseek_v2 import 
KDeepseekV2ForCausalLM +from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM +from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM +from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig from ktransformers.server.balance_serve.inference.model_runner import ModelRunner from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions from ktransformers.server.balance_serve.inference.query_manager import QueryManager @@ -53,8 +56,10 @@ ktransformer_rules_dir = ( os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/") ) default_optimize_rules = { - "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml", - "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml", + "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml", + # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml", + "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml", + "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml", } @@ -105,7 +110,7 @@ class Engine: model_runner: ModelRunner sampler: Sampler query_manager: QueryManager - cache: KDeepSeekV3Cache + cache: KDeepSeekV3Cache | KGQACache def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None): self.args = args @@ -117,17 +122,32 @@ class Engine: self.device = self.args.device self.sched_client = SchedulerClient(args.sched_port) self.updates = [] - config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) - self.cache = KDeepSeekV3Cache(config, self.args.page_size) + + try: + config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) + except: + if args.model_name == "Qwen3Moe": + config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True) + else: + assert False, f"model {args.model_name} not supported" + self.gen_queue = generated_token_queue with torch.device("meta"): if config.architectures[0] == "DeepseekV3ForCausalLM": + self.cache = KDeepSeekV3Cache(config, self.args.page_size) self.model = KDeepseekV3ForCausalLM(config, self.cache) elif config.architectures[0] == "DeepseekV2ForCausalLM": + self.cache = KDeepSeekV3Cache(config, self.args.page_size) self.model = KDeepseekV2ForCausalLM(config, self.cache) - # print(self.block_num) + elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM": + self.cache = KGQACache(config, self.args.page_size) + if config.architectures[0] == "Qwen2MoeForCausalLM": + self.model = KQwen2MoeForCausalLM(config, self.cache) + else: + self.model = KQwen3MoeForCausalLM(config, self.cache) + context = zmq.Context() @@ -176,9 +196,12 @@ class Engine: self.block_num = inference_context.k_cache[0].size(1) #@TODO add config - self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) + if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM": + self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024 ,args.max_batch_size, self.block_num) # TODO: 1024 is a magic number(max_batch_tokens) + else: + self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num) - self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, 
page_size = args.page_size) + self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num) self.sampler = Sampler() self.query_manager = QueryManager(device = self.device, page_size = args.page_size) @@ -231,7 +254,7 @@ class Engine: if self.batch is not None: self.model_runner.sync() - print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms") + print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000/self.model_runner.model_time:.3f} tokens/s") # if self.rank == 0: generated_tokens, probs = self.sampling( self.model_runner.output) diff --git a/ktransformers/server/balance_serve/inference/forward_batch.py b/ktransformers/server/balance_serve/inference/forward_batch.py index 4f79bc3..7022d9e 100644 --- a/ktransformers/server/balance_serve/inference/forward_batch.py +++ b/ktransformers/server/balance_serve/inference/forward_batch.py @@ -281,4 +281,4 @@ class ForwardBatchOutput: self.generated_tokens_num = [] self.top_ps = [] self.temperatures = [] - pass \ No newline at end of file + self.num_batchs = 1 \ No newline at end of file diff --git a/ktransformers/server/balance_serve/inference/model_runner.py b/ktransformers/server/balance_serve/inference/model_runner.py index 386307b..03e18d1 100644 --- a/ktransformers/server/balance_serve/inference/model_runner.py +++ b/ktransformers/server/balance_serve/inference/model_runner.py @@ -27,6 +27,8 @@ from ktransformers.server.balance_serve.inference.forward_batch import ForwardBa from ktransformers.server.config.config import Config from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM +from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM +from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM from ktransformers.server.balance_serve.inference.query_manager import QueryManager from ktransformers.server.balance_serve.settings import sched_ext @@ -40,11 +42,11 @@ def deduplicate_and_sort(lst): class ModelRunner: """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.""" - model: KDeepseekV3ForCausalLM + model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM input: ForwardBatchInput | list[ForwardBatchInput] output: ForwardBatchOutput - def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256): + def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8): self.stream = torch.cuda.Stream(device=device) # 先注释掉 @@ -58,120 +60,92 @@ class ModelRunner: self.use_cuda_graph = use_cuda_graph self.model_time = 0 self.page_size = page_size + self.block_num = block_num # GPU timing for model execution self.start_model_event = torch.cuda.Event(enable_timing=True) self.end_model_event = torch.cuda.Event(enable_timing=True) - if isinstance(self.cuda_graphs, list): - self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))] - self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] - self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in 
range(len(self.cuda_graphs))] - else: - self.graphs = torch.cuda.CUDAGraph() - self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device) - self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device) + + self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))] + self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] + self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))] + self.num_mini_batches = num_mini_batches self.max_chunk_size = max_chunk_size self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device) + + def model_attn_plan(self, batch, cuda_graph_idx=0): + if isinstance(self.model, KDeepseekV3ForCausalLM): + self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, + head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, + sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) + elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM): + self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf, + num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads, + head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads, + page_size=self.model.cache.page_size, causal=True, + q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx) + else: + assert False, "model type not supported" + + def warmup(self): - def capture_graphs(cuda_graph_idx=-1): - if cuda_graph_idx != -1: - with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream): - self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx) - self.graph_memory_pool = self.graphs[cuda_graph_idx].pool() - else: - with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream): - self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) - self.graph_memory_pool = self.graphs.pool() + def capture_graphs(cuda_graph_idx): + with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream): + self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx) + self.graph_memory_pool = self.graphs[cuda_graph_idx].pool() - if isinstance(self.cuda_graphs, list): - self.input = [] - self.features_buf = [] - self.outputs_buf = [] - self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) - self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) - for i in range(len(self.cuda_graphs)): 
- prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch - self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens = self.cuda_graphs[i])) + self.input = [] + self.features_buf = [] + self.outputs_buf = [] + self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) + self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device) + for i in range(len(self.cuda_graphs)): + prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0 #@TODO only supprot 2 prefill batch + self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i])) - self.features_buf.append(self.model.batch_embeddings(self.input[i])) - batch_size = self.input[i].minibatch.q_indptr.size(0)-1 - num_tokens = self.features_buf[i][0].size(0) - print("capturing cuda graph", batch_size, num_tokens) - self.bsz_tensor_buf[0] = batch_size - self.num_tokens_tensor_buf[0] = num_tokens + self.features_buf.append(self.model.batch_embeddings(self.input[i])) + batch_size = self.input[i].minibatch.q_indptr.size(0)-1 + num_tokens = self.features_buf[i][0].size(0) + print("capturing cuda graph", batch_size, num_tokens) - self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, - num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, - head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, - sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - - page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf) + if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM): + self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens) - self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens]) - self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens]) - self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) - - self.outputs_buf.append(None) - - torch.cuda.synchronize() - for warm_up_iters in range(11): - with torch.cuda.stream(self.stream): - self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i]) - torch.cuda.synchronize() + self.bsz_tensor_buf[0] = batch_size + self.num_tokens_tensor_buf[0] = num_tokens - capture_graphs(i) - - with torch.cuda.stream(self.stream): - self.graphs[i].replay() - - self.sync(calc_time=False) - print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.") - else: - self.input 
= ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches) - - self.features_buf = self.model.batch_embeddings(self.input) - batch_size = self.input.minibatch.q_indptr.size(0)-1 - num_tokens = self.features_buf[0].size(0) + self.model_attn_plan(self.input[i], i) + + page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf) - self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device) - self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device) + self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens]) - - self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf, - num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, - head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, - sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - - page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf) - self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens]) - self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens]) - self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) - - + self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) + + self.outputs_buf.append(None) + torch.cuda.synchronize() for warm_up_iters in range(11): with torch.cuda.stream(self.stream): - self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) + self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i) torch.cuda.synchronize() - def capture_graphs(): - with torch.cuda.graph(self.graphs, stream=self.stream): - self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf) - # self.graph_memory_pool = self.graphs.pool() + self.outputs_buf[i].num_batchs = batch_size - - capture_graphs() + capture_graphs(i) with torch.cuda.stream(self.stream): - self.graphs.replay() + self.graphs[i].replay() self.sync(calc_time=False) - print("warmup finished.") + print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.") def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None): with torch.cuda.stream(self.stream): @@ -189,107 +163,54 @@ class ModelRunner: - if isinstance(self.cuda_graphs, list): - # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens - cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs)) - if cuda_graph_idx == len(self.cuda_graphs): - assert False, "num_tokens is too large" - else: - cuda_graph_idx = -1 + # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > 
num_tokens + cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs)) + if not self.use_cuda_graph: + cuda_graph_idx = 0 + # if cuda_graph_idx == len(self.cuda_graphs): + # assert False, "num_tokens is too large" if self.use_cuda_graph: - if cuda_graph_idx != -1: - self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size) - else: - self.input.fill(batch, query_manager, self.page_size) + self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size) else: - self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device) - + self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)] + - if cuda_graph_idx != -1 and self.use_cuda_graph: + if self.use_cuda_graph: self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device) else: - self.features = self.model.batch_embeddings(self.input, device=self.device) - + self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device) + self.bsz_tensor_buf.copy_(batch_size) self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device)) if self.use_cuda_graph: - if cuda_graph_idx != -1: - self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True) - else: - self.features_buf[0].copy_(self.features[0], non_blocking=True) - """ - if num_tokens_0 > 64: - padded_num_tokens_0 = pad_num_tokens(num_tokens_0) - self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0 - """ - #self.input.forward_minibatchs[0].print() - # print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches]) - # print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}") + self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True) - # self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors) - - """ + self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx) + self.start_model_event.record(self.stream) + page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf) if self.use_cuda_graph: - print("before replay features_buf", self.features_buf[0]) - print("features_buf addr", self.features_buf[0].data_ptr()) + self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens]) + self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens]) + + self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) + self.replay(cuda_graph_idx) + self.output = ForwardBatchOutput() + + self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps) + self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures) + + + self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone()) else: - print("before run features", self.features[0]) - """ - if cuda_graph_idx != -1 and self.use_cuda_graph: - self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, - num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, - 
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, - sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - self.start_model_event.record(self.stream) - page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf) - if self.use_cuda_graph: - self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens]) - self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens]) - self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) - self.replay(cuda_graph_idx) - self.output = ForwardBatchOutput() - - self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps) - self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures) + self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) + self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start] + self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps) + self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures) + self.end_model_event.record(self.stream) - self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone()) - else: - self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) - self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start] - self.end_model_event.record(self.stream) - else: - self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf, - num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, - head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True, - sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16) - self.start_model_event.record(self.stream) - page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf) - if self.use_cuda_graph: - self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens]) - self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens]) - self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1) - self.replay(cuda_graph_idx) - self.output = ForwardBatchOutput() - - self.output.top_ps.append(self.input.minibatch.top_ps) - self.output.temperatures.append(self.input.minibatch.temperatures) - - self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone()) - else: - self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset) - self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start] - self.output.top_ps.append(self.input.minibatch.top_ps) - self.output.temperatures.append(self.input.minibatch.temperatures) - - 
diff --git a/ktransformers/server/balance_serve/settings.py b/ktransformers/server/balance_serve/settings.py
index a79cdac..540dc1c 100644
--- a/ktransformers/server/balance_serve/settings.py
+++ b/ktransformers/server/balance_serve/settings.py
@@ -11,6 +11,8 @@ from time import sleep
 import sched_ext
 from transformers import AutoConfig

+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
+
 def create_sched_settings(args):
     default_sample_options = sched_ext.SampleOptions()
     model_name = os.path.basename(os.path.normpath(args.model_dir))
@@ -64,7 +66,111 @@ def create_sched_settings(args):

     return settings

+def create_sched_settings_qwen2moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
+
+
+def create_sched_settings_qwen3moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
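`create_sched_settings_qwen2moe` and `create_sched_settings_qwen3moe` are line-for-line identical except for the class that parses the checkpoint config (`AutoConfig` vs. `Qwen3MoeConfig`); everything else, including the hard-coded `k_head_dim = 128` and the `/mnt/data/persist-kvc` root path, is duplicated. If the two are meant to stay in sync, they could share one helper parameterized by the config loader; a sketch under that assumption (`_create_sched_settings_moe` is hypothetical, not part of this patch):

# Hypothetical consolidation of the two builders above.
def _create_sched_settings_moe(args, config_cls):
    model_config = config_cls.from_pretrained(args.model_dir, trust_remote_code=True)
    input_model_settings = sched_ext.ModelSettings()
    input_model_settings.model_path = args.model_dir
    input_model_settings.layer_count = model_config.num_hidden_layers
    input_model_settings.num_k_heads = model_config.num_key_value_heads
    input_model_settings.k_head_dim = 128
    # ... remaining fields exactly as in the two functions above ...
    settings = sched_ext.Settings()
    settings.model_settings = input_model_settings
    settings.auto_derive()
    return settings

def create_sched_settings_qwen2moe(args):
    return _create_sched_settings_moe(args, AutoConfig)

def create_sched_settings_qwen3moe(args):
    return _create_sched_settings_moe(args, Qwen3MoeConfig)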
diff --git a/ktransformers/server/config/config.py b/ktransformers/server/config/config.py
index 055be06..0d7c17b 100644
--- a/ktransformers/server/config/config.py
+++ b/ktransformers/server/config/config.py
@@ -100,6 +100,7 @@ class Config(metaclass=Singleton):
         # to make sure it consistent with previous version
         self.model_path: str = self.model_dir
         self.model_name: str = self.model.get("name", "")
+        self.architectures: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
         self.use_cuda_graph = self.model.get("use_cuda_graph", True)
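Note that `self.architectures` is filled from the same `name` key as `self.model_name`, so the scheduler dispatch above only fires when the configured model name is literally the architecture string (`Qwen2MoeForCausalLM` or `Qwen3MoeForCausalLM`). If the intent is ever to detect the architecture itself, it could be read from the checkpoint's `config.json` instead; a hypothetical alternative, not what this patch does:

# Hypothetical: derive the architecture from the model's own config.json
# rather than reusing the user-facing `name` field.
import json
import os

def detect_architecture(model_dir: str) -> str:
    with open(os.path.join(model_dir, "config.json")) as f:
        cfg = json.load(f)
    # HF configs store a list, e.g. ["Qwen3MoeForCausalLM"]; take the first entry.
    return cfg.get("architectures", [""])[0]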
diff --git a/ktransformers/server/requirements.txt b/ktransformers/server/requirements.txt
index 76377d5..b4fc2cf 100644
--- a/ktransformers/server/requirements.txt
+++ b/ktransformers/server/requirements.txt
@@ -1,5 +1,5 @@
 torch >= 2.3.0
-transformers == 4.43.2
+transformers == 4.51.3
 fastapi >= 0.111.0
 langchain >= 0.2.0
 blessed >= 1.20.0
diff --git a/ktransformers/util/custom_gguf.py b/ktransformers/util/custom_gguf.py
index 84ada15..b3d98d3 100644
--- a/ktransformers/util/custom_gguf.py
+++ b/ktransformers/util/custom_gguf.py
@@ -912,6 +912,9 @@ def translate_name_to_gguf(name):
     name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
     name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
     name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
+
+    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
+    name = name.replace(".self_attn.k_norm", ".attn_k_norm")

     name = name.replace(".shared_expert.", ".shared_experts.")
     name = name.replace(".shared_expert_", ".shared_experts_")
@@ -922,17 +925,23 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
     name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
     name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
+
+    name = name.replace(".mlp.experts", "")
-    name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
-    name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
-    name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")

     name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
     name = name.replace(".block_sparse_moe.experts", "")

+    name = name.replace(".feed_forward.experts", "")
+    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
+    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
+    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
+    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")
+
     return name

+
 if __name__ == '__main__':
     gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
     loader = GGUFLoader(gguf_path)
diff --git a/pyproject.toml b/pyproject.toml
index 028c6a3..b307e61 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,7 @@ dynamic = ["version"]

 dependencies = [
     "torch >= 2.3.0",
-    "transformers == 4.43.2",
+    "transformers == 4.51.3",
     "fastapi >= 0.111.0",
     "uvicorn >= 0.30.1",
     "langchain >= 0.2.0",
diff --git a/requirements-local_chat.txt b/requirements-local_chat.txt
index 855b360..4ace3b4 100644
--- a/requirements-local_chat.txt
+++ b/requirements-local_chat.txt
@@ -1,5 +1,5 @@
 fire
-transformers==4.43.2
+transformers==4.51.3
 numpy
 torch>=2.3.0
 packaging
diff --git a/third_party/custom_flashinfer b/third_party/custom_flashinfer
index fd94393..af4259e 160000
--- a/third_party/custom_flashinfer
+++ b/third_party/custom_flashinfer
@@ -1 +1 @@
-Subproject commit fd94393fb5b8ba8bae9c0bd6ab1c2a429d81ac76
+Subproject commit af4259e8a33f095b419d1fd1733a50b22fc84c49
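The new `translate_name_to_gguf` rules cover Qwen3's per-head QK norms and the `feed_forward`-style MoE tensor naming. Shown in isolation below (the full function also rewrites layer prefixes and other tensor names, so real outputs differ in the prefix):

# What the added replacement rules do, demonstrated as plain string ops.
def apply_new_rules(name: str) -> str:
    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
    name = name.replace(".self_attn.k_norm", ".attn_k_norm")
    name = name.replace(".mlp.experts", "")
    name = name.replace(".feed_forward.experts", "")
    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
    return name

assert apply_new_rules("layers.0.self_attn.q_norm.weight") == "layers.0.attn_q_norm.weight"
assert apply_new_rules("layers.0.mlp.experts.ffn_up_exps.weight") == "layers.0.ffn_up_exps.weight"
assert apply_new_rules("layers.0.feed_forward.router.weight") == "layers.0.ffn_gate_inp.weight"

Dropping the `.mlp.experts` / `.feed_forward.experts` infix leaves the stacked-expert GGUF names (`ffn_up_exps` etc.) intact, which is what the expert loaders expect.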