Mirror of https://github.com/kvcache-ai/ktransformers.git (synced 2025-09-14 17:19:42 +00:00)

Commit 3f9bbf1181 (parent f3d842a0ca): "support qwen3, dont speak human language"

30 changed files with 3696 additions and 290 deletions
ktransformers/models/configuration_qwen2_moe.py (new file, 177 lines)
@@ -0,0 +1,177 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2MoE model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Qwen2MoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen2MoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 4):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 60):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
            Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.

    ```python
    >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig

    >>> # Initializing a Qwen2MoE style configuration
    >>> configuration = Qwen2MoeConfig()

    >>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration
    >>> model = Qwen2MoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=1408,
        shared_expert_intermediate_size=5632,
        num_experts_per_tok=4,
        num_experts=60,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
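Not part of the diff above: a minimal usage sketch of this configuration class, assuming the ktransformers package is importable; the field values are illustrative and not tied to any released checkpoint.

```python
from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig

# Illustrative, made-up MoE layout for experimentation.
config = Qwen2MoeConfig(
    num_hidden_layers=8,
    num_experts=16,
    num_experts_per_tok=2,
    moe_intermediate_size=704,
    mlp_only_layers=[0],  # layer 0 uses the dense Qwen2MoeMLP instead of the sparse MoE block
)
# sliding_window is stored as None because use_sliding_window defaults to False.
print(config.model_type, config.num_experts, config.sliding_window)  # qwen2_moe 16 None
```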
ktransformers/models/configuration_qwen3_moe.py (new file, 233 lines)
@@ -0,0 +1,233 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3MoE model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Qwen3MoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen3MoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. If not specified, will default to `4096`.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use
            full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 768):
            Intermediate size of the routed expert.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 128):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
            Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.
            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.

    ```python
    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig

    >>> # Initializing a Qwen3MoE style configuration
    >>> configuration = Qwen3MoeConfig()

    >>> # Initializing a model from the Qwen3-15B-A2B style configuration
    >>> model = Qwen3MoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3Moe`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=6144,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=768,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["Qwen3MoeConfig"]
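Not part of the diff above: a sketch of the `rope_scaling` validation path in this config's `__init__`. The scaling factor and extended context length are made-up values for illustration, not settings from any released checkpoint.

```python
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig

# Hypothetical YaRN-style context extension; rope_config_validation checks the dict at construction time.
config = Qwen3MoeConfig(
    max_position_embeddings=131072,
    rope_scaling={"rope_type": "yarn", "factor": 4.0},
)
print(config.rope_scaling)  # {'rope_type': 'yarn', 'factor': 4.0}
```

Passing the legacy key `"type"` instead of `"rope_type"` is also accepted; the constructor copies it into `"rope_type"` before validation.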
@@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):

        return page_idx, page_offset


class KGQACache(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        page_size: int = 256,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

    def load(self, inference_context: sched_ext.InferenceContext):
        # Collect the per-layer K/V page tensors provided by the scheduler's inference context.
        print(self.config.num_hidden_layers)
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(
                inference_context.k_cache[0][i]
            )
            self.v_caches.append(
                inference_context.v_cache[0][i]
            )

        # Total addressable cache slots, computed from the first two dims of the layer-0 K cache.
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
        # Map each flat cache position to (physical page index, in-page offset):
        # the logical page of a token is translated through the owning query's slice of kv_indices.
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]

        return page_idx, page_offset

    def get_k_cache(self, layer_idx):
        return self.k_caches[layer_idx]

    def get_v_cache(self, layer_idx):
        return self.v_caches[layer_idx]
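To make the paged-KV bookkeeping above concrete, here is a self-contained toy example of the page/offset split that `get_page_table` performs; the tensor values are invented purely for illustration.

```python
import torch

page_size = 256
# Hypothetical flat cache positions of four tokens within their sequences.
cache_position = torch.tensor([3, 258, 511, 700])

page_offset = cache_position % page_size       # offset of the token inside its page
page_idx_local = cache_position // page_size   # logical page number within the sequence

print(page_idx_local.tolist())  # [0, 1, 1, 2]
print(page_offset.tolist())     # [3, 2, 255, 188]
```

The logical page number is then mapped to a physical page id via `kv_indices[kv_indptr[query_id] + page_idx_local]`, which is exactly the lookup done in the loop above.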
ktransformers/models/custom_modeling_qwen2_moe.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""

import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from torch import nn
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KGQACache
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel
from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer


class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):

    cache: KGQACache
    use_cuda_graph = False

    def __init__(
        self,
        config: Qwen2MoeConfig,
        cache,
    ):
        super().__init__(config)
        self.model = Qwen2MoeModel(config)
        self.config = config
        self.cache = cache
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # One flashinfer attention wrapper slot per CUDA-graph index.
        self.attn = [None] * 10

    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx=0):
        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)

    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
        features = []
        for i in range(batch.batch_size):
            tokens = batch.minibatch.tokens.contiguous()
            # Embedding lookup runs on CPU; the result is moved to the target device in bfloat16.
            feature = (
                self.model.embed_tokens(tokens.to(torch.device('cpu')))
                .to(torch.bfloat16)
                .to(device=device)
            )
            features.append(feature)

        return features

    def forward(
        self,
        batch: ForwardBatchInput | None = None,
        features: List[torch.Tensor] | None = None,
        bsz_tensors: torch.Tensor | None = None,
        num_tokens_tensors: torch.Tensor | None = None,
        page_idx: torch.Tensor | None = None,
        page_offset: torch.Tensor | None = None,
        cuda_graph_idx: int | None = 0,
    ) -> ForwardBatchOutput:
        current_stream = torch.cuda.current_stream()

        forward_batch_output = ForwardBatchOutput()

        hidden_states = features[0]
        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])

        with torch.cuda.stream(current_stream):
            residual = torch.zeros_like(hidden_states)
            for i, decode_layer in enumerate(self.model.layers):
                # If the layer map places this layer on another device, switch stream/device
                # and move the activations (and position ids) over.
                if self.model.transfer_map is not None and i in self.model.transfer_map:
                    prev_stream = torch.cuda.current_stream()
                    cur_device = self.model.transfer_map[i]
                    if cur_device not in self.model.stream_device_map:
                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                    torch.cuda.set_device(cur_device)
                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
                    hidden_states = hidden_states.to(
                        self.model.transfer_map[i], non_blocking=True
                    )

                    batch.minibatch.position_ids = (
                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
                        if batch.minibatch.position_ids is not None
                        else None
                    )
                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                       position_ids=batch.minibatch.position_ids,
                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
                                                       page_idx=page_idx,
                                                       page_offset=page_offset
                                                       )

                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
                hidden_states = hidden_states.squeeze(0)
        forward_batch_output = ForwardBatchOutput()
        with torch.cuda.stream(current_stream):
            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
            forward_batch_output.logits.append(local_logit)

        return forward_batch_output

    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
                              num_q_heads: int,
                              num_kv_heads: int,
                              head_dim: int,
                              page_size: int,
                              causal: bool,
                              q_data_type: torch.dtype,
                              kv_data_type: torch.dtype,
                              cuda_graph_idx: int = 0
                              ):
        minibatch = batch.minibatch
        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                                       minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
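Not part of the diff above: a hedged sketch of how these pieces might be wired together before serving. It assumes `flashinfer` and the balance_serve scheduler extension are installed and a CUDA device is available; all sizes are tiny, made-up values, and real deployments construct the config from a checkpoint instead.

```python
import torch
from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.custom_cache import KGQACache
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM

# Tiny, illustrative dimensions so the example stays cheap.
config = Qwen2MoeConfig(
    hidden_size=256, intermediate_size=512, num_hidden_layers=2,
    num_attention_heads=8, num_key_value_heads=8,
    num_experts=8, num_experts_per_tok=2, moe_intermediate_size=128,
)
cache = KGQACache(config, page_size=256, dtype=torch.bfloat16)
model = KQwen2MoeForCausalLM(config, cache)
model.init_wrapper(use_cuda_graph=False, device="cuda:0",
                   max_batch_token=512, max_batch_size=4, max_pages=64)
# cache.load(inference_context) is called later with the scheduler's sched_ext.InferenceContext,
# after which forward() can be driven with ForwardBatchInput batches.
```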
ktransformers/models/custom_modeling_qwen3_moe.py (new file, 133 lines)
@@ -0,0 +1,133 @@
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""

import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import List, Optional, Tuple, Union
import torch.utils.checkpoint
from torch import nn
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KGQACache
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer


class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):

    cache: KGQACache
    use_cuda_graph = False

    def __init__(
        self,
        config: Qwen3MoeConfig,
        cache=None,
    ):
        super().__init__(config)
        self.model = Qwen3MoeModel(config)
        self.config = config
        self.cache = cache
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # One flashinfer attention wrapper slot per CUDA-graph index.
        self.attn = [None] * 10

    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx=0):
        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)

    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
        features = []
        for i in range(batch.batch_size):
            tokens = batch.minibatch.tokens.contiguous()
            # Embedding lookup runs on CPU; the result is moved to the target device in bfloat16.
            feature = (
                self.model.embed_tokens(tokens.to(torch.device('cpu')))
                .to(torch.bfloat16)
                .to(device=device)
            )
            features.append(feature)

        return features

    def forward(
        self,
        batch: ForwardBatchInput | None = None,
        features: List[torch.Tensor] | None = None,
        bsz_tensors: torch.Tensor | None = None,
        num_tokens_tensors: torch.Tensor | None = None,
        page_idx: torch.Tensor | None = None,
        page_offset: torch.Tensor | None = None,
        cuda_graph_idx: int | None = 0,
    ) -> ForwardBatchOutput:
        current_stream = torch.cuda.current_stream()

        forward_batch_output = ForwardBatchOutput()

        hidden_states = features[0]
        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])

        with torch.cuda.stream(current_stream):
            residual = torch.zeros_like(hidden_states)
            for i, decode_layer in enumerate(self.model.layers):
                # If the layer map places this layer on another device, switch stream/device
                # and move the activations (and position ids) over.
                if self.model.transfer_map is not None and i in self.model.transfer_map:
                    prev_stream = torch.cuda.current_stream()
                    cur_device = self.model.transfer_map[i]
                    if cur_device not in self.model.stream_device_map:
                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                    torch.cuda.set_device(cur_device)
                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
                    hidden_states = hidden_states.to(
                        self.model.transfer_map[i], non_blocking=True
                    )

                    batch.minibatch.position_ids = (
                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
                        if batch.minibatch.position_ids is not None
                        else None
                    )
                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                       position_ids=batch.minibatch.position_ids,
                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
                                                       page_idx=page_idx,
                                                       page_offset=page_offset
                                                       )

                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
                hidden_states = hidden_states.squeeze(0)
        forward_batch_output = ForwardBatchOutput()
        with torch.cuda.stream(current_stream):
            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
            forward_batch_output.logits.append(local_logit)

        return forward_batch_output

    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
                              num_q_heads: int,
                              num_kv_heads: int,
                              head_dim: int,
                              page_size: int,
                              causal: bool,
                              q_data_type: torch.dtype,
                              kv_data_type: torch.dtype,
                              cuda_graph_idx: int = 0
                              ):
        minibatch = batch.minibatch
        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                                       minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
ktransformers/models/modeling_qwen3_moe.py (new file, 1472 lines)
File diff suppressed because it is too large.