Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-03 19:20:04 +00:00.

Commit 3f9bbf1181 (parent f3d842a0ca): "support qwen3, dont speak human language"

30 changed files with 3696 additions and 290 deletions.
ktransformers/models/configuration_qwen2_moe.py (new file, 177 lines)

@@ -0,0 +1,177 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2MoE model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Qwen2MoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a
    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen2MoE model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen2MoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top
            layers use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 4):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 60):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
            Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.

    ```python
    >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig

    >>> # Initializing a Qwen2MoE style configuration
    >>> configuration = Qwen2MoeConfig()

    >>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration
    >>> model = Qwen2MoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=1408,
        shared_expert_intermediate_size=5632,
        num_experts_per_tok=4,
        num_experts=60,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
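A note on how the last three MoE knobs interact: in the upstream Hugging Face implementation, a decoder layer gets a sparse MoE block only when it is not excluded by `mlp_only_layers` and its index falls on the `decoder_sparse_step` grid. A minimal sketch of that rule (the helper name is illustrative, not part of this file):

```python
# Sketch of the layer-sparsity rule documented above; `uses_moe` is a
# hypothetical helper, not something defined in configuration_qwen2_moe.py.
def uses_moe(layer_idx: int, config) -> bool:
    return (
        layer_idx not in config.mlp_only_layers
        and config.num_experts > 0
        and (layer_idx + 1) % config.decoder_sparse_step == 0
    )

# With the defaults (decoder_sparse_step=1, mlp_only_layers=[]),
# every one of the 24 layers is a sparse MoE layer.
```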
ktransformers/models/configuration_qwen3_moe.py (new file, 233 lines)

@@ -0,0 +1,233 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3MoE model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Qwen3MoeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a
    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of
    [Qwen/Qwen3-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the Qwen3MoE model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`Qwen3MoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
            type and you expect the model to work on longer `max_position_embeddings`, we recommend you update this
            value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation.
                    If unspecified, it defaults to the value recommended by the implementation, using the `factor`
                    field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp
                    function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp
                    function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top
            layers use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 768):
            Intermediate size of the routed expert.
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 128):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `False`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
            Indicates which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.
            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.

    ```python
    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig

    >>> # Initializing a Qwen3MoE style configuration
    >>> configuration = Qwen3MoeConfig()

    >>> # Initializing a model from the Qwen3-15B-A2B style configuration
    >>> model = Qwen3MoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen3_moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3Moe`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=6144,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=768,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["Qwen3MoeConfig"]
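Since this configuration validates `rope_scaling` via `rope_config_validation`, a long-context variant could be configured as below. This is a sketch: the factor values are invented for illustration, and a legacy `"type"` key would be rewritten to `"rope_type"` by the BC shim in `__init__`.

```python
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig

# Hypothetical YaRN scaling to 4x the 32768-token pre-training window,
# using the rope_scaling keys documented in the docstring above.
config = Qwen3MoeConfig(
    max_position_embeddings=131072,
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    },
)
```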
@@ -275,3 +275,59 @@ class KDeepSeekV3Cache(nn.Module):

        return page_idx, page_offset


class KGQACache(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        page_size: int = 256,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

    def load(self, inference_context: sched_ext.InferenceContext):
        print(self.config.num_hidden_layers)
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(
                inference_context.k_cache[0][i]
            )
            self.v_caches.append(
                inference_context.v_cache[0][i]
            )

        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]

        return page_idx, page_offset

    def get_k_cache(self, layer_idx):
        return self.k_caches[layer_idx]

    def get_v_cache(self, layer_idx):
        return self.v_caches[layer_idx]
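The index arithmetic in `get_page_table` splits a token's absolute cache position into a logical page (`cache_position // page_size`) and a slot within it (`cache_position % page_size`), then maps the logical page to a physical one through `kv_indices`. A toy example with invented page assignments:

```python
import torch

page_size = 256  # default in KGQACache above
# Two tokens of one request at absolute positions 255 and 256; the request's
# logical pages 0 and 1 were given physical pages 7 and 3 (made-up values).
cache_position = torch.tensor([255, 256])
kv_indices = torch.tensor([7, 3])

page_offset = cache_position % page_size      # tensor([255,   0])
page_idx_local = cache_position // page_size  # tensor([  0,   1])
page_idx = kv_indices[page_idx_local]         # tensor([  7,   3])
# Token 255 lands in slot 255 of physical page 7; token 256 in slot 0 of page 3.
```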
ktransformers/models/custom_modeling_qwen2_moe.py (new file, 133 lines)

@@ -0,0 +1,133 @@
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import functional as F

from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KGQACache
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel
from ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer


class KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):

    cache: KGQACache
    use_cuda_graph = False

    def __init__(
        self,
        config: Qwen2MoeConfig,
        cache,
    ):
        super().__init__(config)
        self.model = Qwen2MoeModel(config)
        self.config = config
        self.cache = cache
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.attn = [None] * 10

    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx=0):
        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)

    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
        features = []
        for i in range(batch.batch_size):
            tokens = batch.minibatch.tokens.contiguous()
            feature = (
                self.model.embed_tokens(tokens.to(torch.device('cpu')))
                .to(torch.bfloat16)
                .to(device=device)
            )
            features.append(feature)

        return features

    def forward(
        self,
        batch: ForwardBatchInput | None = None,
        features: List[torch.Tensor] | None = None,
        bsz_tensors: torch.Tensor | None = None,
        num_tokens_tensors: torch.Tensor | None = None,
        page_idx: torch.Tensor | None = None,
        page_offset: torch.Tensor | None = None,
        cuda_graph_idx: int | None = 0,
    ) -> ForwardBatchOutput:
        current_stream = torch.cuda.current_stream()

        forward_batch_output = ForwardBatchOutput()

        hidden_states = features[0]
        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])

        with torch.cuda.stream(current_stream):
            residual = torch.zeros_like(hidden_states)
            for i, decode_layer in enumerate(self.model.layers):
                if self.model.transfer_map is not None and i in self.model.transfer_map:
                    prev_stream = torch.cuda.current_stream()
                    cur_device = self.model.transfer_map[i]
                    if cur_device not in self.model.stream_device_map:
                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                    torch.cuda.set_device(cur_device)
                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
                    hidden_states = hidden_states.to(
                        self.model.transfer_map[i], non_blocking=True
                    )

                    batch.minibatch.position_ids = (
                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
                        if batch.minibatch.position_ids is not None
                        else None
                    )
                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                       position_ids=batch.minibatch.position_ids,
                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
                                                       page_idx=page_idx,
                                                       page_offset=page_offset
                                                       )

                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
                hidden_states = hidden_states.squeeze(0)
        forward_batch_output = ForwardBatchOutput()
        with torch.cuda.stream(current_stream):
            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
            forward_batch_output.logits.append(local_logit)

        return forward_batch_output

    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
                              num_q_heads: int,
                              num_kv_heads: int,
                              head_dim: int,
                              page_size: int,
                              causal: bool,
                              q_data_type: torch.dtype,
                              kv_data_type: torch.dtype,
                              cuda_graph_idx: int = 0
                              ):
        minibatch = batch.minibatch
        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                                       minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
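Taken together, a decode step on this class would roughly follow the call order below. This is a hypothetical driver sketch: `batch`, the page tables, and the tensor arguments are assumed to come from the balance_serve scheduler, and the head counts are the Qwen2MoE defaults (16 heads, hidden size 2048, hence head_dim 128).

```python
import torch

# Illustrative only; mirrors the method signatures defined above.
model.init_wrapper(use_cuda_graph=False, device="cuda:0",
                   max_batch_token=4096, max_batch_size=32, max_pages=1024)

features = model.batch_embeddings(batch, device="cuda:0")
model.flash_infer_attn_plan(batch, bsz_tensors, num_tokens_tensors,
                            num_q_heads=16, num_kv_heads=16, head_dim=128,
                            page_size=256, causal=True,
                            q_data_type=torch.bfloat16,
                            kv_data_type=torch.bfloat16)
output = model.forward(batch, features, bsz_tensors, num_tokens_tensors,
                       page_idx, page_offset)
logits = output.logits[0]
```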
ktransformers/models/custom_modeling_qwen3_moe.py (new file, 133 lines)

@@ -0,0 +1,133 @@
"""
Date: 2024-11-06 10:05:11
LastEditors: djw
LastEditTime: 2024-11-13 07:50:51
"""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import functional as F

from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.models.custom_cache import KGQACache
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
import flashinfer


class KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):

    cache: KGQACache
    use_cuda_graph = False

    def __init__(
        self,
        config: Qwen3MoeConfig,
        cache=None,
    ):
        super().__init__(config)
        self.model = Qwen3MoeModel(config)
        self.config = config
        self.cache = cache
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.attn = [None] * 10

    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx=0):
        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)

    def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
        features = []
        for i in range(batch.batch_size):
            tokens = batch.minibatch.tokens.contiguous()
            feature = (
                self.model.embed_tokens(tokens.to(torch.device('cpu')))
                .to(torch.bfloat16)
                .to(device=device)
            )
            features.append(feature)

        return features

    def forward(
        self,
        batch: ForwardBatchInput | None = None,
        features: List[torch.Tensor] | None = None,
        bsz_tensors: torch.Tensor | None = None,
        num_tokens_tensors: torch.Tensor | None = None,
        page_idx: torch.Tensor | None = None,
        page_offset: torch.Tensor | None = None,
        cuda_graph_idx: int | None = 0,
    ) -> ForwardBatchOutput:
        current_stream = torch.cuda.current_stream()

        forward_batch_output = ForwardBatchOutput()

        hidden_states = features[0]
        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])

        with torch.cuda.stream(current_stream):
            residual = torch.zeros_like(hidden_states)
            for i, decode_layer in enumerate(self.model.layers):
                if self.model.transfer_map is not None and i in self.model.transfer_map:
                    prev_stream = torch.cuda.current_stream()
                    cur_device = self.model.transfer_map[i]
                    if cur_device not in self.model.stream_device_map:
                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                    torch.cuda.set_device(cur_device)
                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)
                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])
                    hidden_states = hidden_states.to(
                        self.model.transfer_map[i], non_blocking=True
                    )

                    batch.minibatch.position_ids = (
                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
                        if batch.minibatch.position_ids is not None
                        else None
                    )
                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.self_attn(hidden_states, self.cache,
                                                       position_ids=batch.minibatch.position_ids,
                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
                                                       page_idx=page_idx,
                                                       page_offset=page_offset
                                                       )

                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)
                hidden_states = hidden_states.squeeze(0)
        forward_batch_output = ForwardBatchOutput()
        with torch.cuda.stream(current_stream):
            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
            forward_batch_output.logits.append(local_logit)

        return forward_batch_output

    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
                              num_q_heads: int,
                              num_kv_heads: int,
                              head_dim: int,
                              page_size: int,
                              causal: bool,
                              q_data_type: torch.dtype,
                              kv_data_type: torch.dtype,
                              cuda_graph_idx: int = 0
                              ):
        minibatch = batch.minibatch
        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
                                       minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
ktransformers/models/modeling_qwen3_moe.py (new file, 1472 lines)

File diff suppressed because it is too large.

@@ -411,4 +411,30 @@ class RotaryEmbeddingV4(BaseInjectedModule):

        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        # self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings


class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        # device: str = "cuda",
        generate_device: str = "cuda",
        prefill_device: str = "cuda",
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
        )
        self.orig_module.__init__(
            config,
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        self.orig_module.__init__(
            self.orig_module.config
        )
@@ -762,92 +762,3 @@ class KLlamaAttention(BaseInjectedModule):

        attn_weights = None

        return attn_output, attn_weights, past_key_value


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            # del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output
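For context on `get_absorbed` (this class is deleted here and re-added in balance_serve_attention.py below): rather than decompressing the latent KV cache into per-head keys and values, MLA folds the `kv_b_proj` up-projection into the query path (`q_absorb`) and the output path (`out_absorb`), so attention runs directly in the `kv_lora_rank`-dimensional latent space. A standalone shape check, with sizes chosen to resemble DeepSeek-V3 (illustrative only):

```python
import torch

num_heads, qk_nope, v_dim, kv_lora_rank, q_len = 128, 128, 128, 512, 1
q_nope = torch.randn(q_len, num_heads, qk_nope)
q_absorb = torch.randn(num_heads, qk_nope, kv_lora_rank)   # slice of kv_b_proj
out_absorb = torch.randn(num_heads, v_dim, kv_lora_rank)   # slice of kv_b_proj

# Queries move into the latent space once, so no per-head K/V is materialized.
q_latent = torch.matmul(q_nope.transpose(0, 1), q_absorb).transpose(0, 1)
assert q_latent.shape == (q_len, num_heads, kv_lora_rank)

# After latent-space attention, out_absorb maps back to the per-head value dim.
attn_latent = torch.randn(q_len, num_heads, kv_lora_rank)
attn_out = torch.matmul(attn_latent.transpose(0, 1), out_absorb.mT).transpose(0, 1)
assert attn_out.shape == (q_len, num_heads, v_dim)
```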
ktransformers/operators/balance_serve_attention.py (new file, 287 lines)

@@ -0,0 +1,287 @@
'''
Description  :
Author       : Boxin Zhang
Version      : 0.2.5
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
import logging
from transformers.configuration_utils import PretrainedConfig
from flashinfer import BatchMLAPagedAttentionWrapper
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache

logger = logging.getLogger("attention")


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
            self.q_absorb.weight.data = q_absorb
            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
            self.out_absorb.weight.data = out_absorb
            # del self.orig_module.kv_b_proj
        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
        return q_absorb, out_absorb

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KDeepSeekV3Cache,
                position_ids: torch.Tensor,
                wrapper: BatchMLAPagedAttentionWrapper,
                num_tokens_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states, num_tokens_tensors)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
        q = q.view(q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = compressed_kv.contiguous()
        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)

        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
        q_pe = q_pe.squeeze(0)
        if kv_cache is not None:
            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
            cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset}  # Specific to RoPE models
            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
            compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
            k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)

        q_absorb, out_absorb = self.get_absorbed()
        q_nope = q_nope.transpose(0, 1)  # q_len is 1, no GPU overhead, same below
        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM
        q_nope = q_nope.transpose(0, 1)
        # q_nope.squeeze_(1)
        # q_pe.squeeze_(1)

        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
        attn_output = attn_output.transpose(0, 1)
        attn_output = torch.matmul(attn_output, out_absorb.mT)  # [self.num_heads, q_len, self.v_head_dim]
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output, num_tokens_tensors)
        return attn_output


class KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
                note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
                Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1
                makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q
                and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                position_ids: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, bsz_tensors)
        key_states = self.k_proj(hidden_states, bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)

        return attn_output


class KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.layer_idx)
        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.

    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example,
                note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
                Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1
                makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q
                and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self,
                hidden_states: torch.Tensor,
                kv_cache: KGQACache,
                position_ids: torch.Tensor,
                wrapper: flashInferAttn,
                bsz_tensors: torch.Tensor,
                page_idx: torch.Tensor,
                page_offset: torch.Tensor,
                ):
        q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors)
        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors)
        value_states = self.v_proj(hidden_states, bsz_tensors)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)

        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))
        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )
        value_states = value_states.view(
            q_len, self.num_key_value_heads, self.head_dim
        )

        k_cache = kv_cache.get_k_cache(self.layer_idx)
        v_cache = kv_cache.get_v_cache(self.layer_idx)

        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)

        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)

        return attn_output
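Both GQA attention classes above hand FlashInfer queries with all `num_heads` while K/V carry only `num_key_value_heads`; with the Qwen3MoE defaults, 32 query heads share 4 KV heads, 8 per group. A small self-contained check of the shapes and of the `rotate_half` helper (applying it twice negates the input):

```python
import torch

def rotate_half(x):  # same definition as in the file above
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

q_len, num_heads, num_kv_heads, head_dim = 5, 32, 4, 128
query_states = torch.randn(q_len, num_heads, head_dim)
key_states = torch.randn(q_len, num_kv_heads, head_dim)

assert num_heads % num_kv_heads == 0          # 8 query heads per KV head
assert torch.equal(rotate_half(rotate_half(query_states)), -query_states)
```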

@@ -689,6 +689,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):

from ktransformers.models.modeling_deepseek import DeepseekV2MoE
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
@@ -1267,3 +1268,229 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
            self.unload()
        else:
            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")

class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = self.gate(hidden_states, bsz_tensor)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # generate (decode) phase only
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():  # TODO: this branch causes a JIT bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            y += y_
            y.resize_(*orig_shape)
            return y

        y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        y_ = (
            F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: may be buggy here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: may be buggy here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        y += y_
        return y
    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO: may be buggy here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO: may be buggy here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
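The argsort-based bucketing in moe_infer above groups every (token, expert) assignment so that each expert sees one contiguous slice of tokens. A tiny standalone example of the indexing (illustrative values only):

import torch

topk_ids = torch.tensor([[0, 2], [2, 1]])   # 2 tokens, top-2 experts each
idxs = topk_ids.view(-1).argsort()          # tensor([0, 3, 1, 2]): slots sorted by expert id
token_of_slot = idxs // topk_ids.shape[1]   # tensor([0, 1, 0, 1]): which token feeds each slot
# tokens_per_expert = [1, 1, 2]: expert 0 gets token 0, expert 1 gets token 1,
# and expert 2 gets tokens 0 and 1 as one contiguous batch.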
class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
        orig_shape = hidden_states.shape
        sequence_length = orig_shape[1]
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        router_logits = self.gate(hidden_states, bsz_tensor)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # generate (decode) phase only
        if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing():  # TODO: this branch causes a JIT bug
            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
            # Qwen3MoE has no shared expert, so the Qwen2MoE shared-expert path stays disabled:
            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_

            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)

            # y += y_
            y.resize_(*orig_shape)
            return y

        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
        # y_ = (
        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
        # )

        if isinstance(self.experts, KExpertsBase):
            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
        elif hidden_states.size(0) > 10:
            # TODO: may be buggy here
            y = (
                self.moe_infer(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        else:
            # TODO: may be buggy here
            y = (
                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
                .view(*orig_shape)
                .to(device=hidden_states.device)
            )
        # y += y_
        return y
    @torch.no_grad()
    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
        return outs

    @torch.no_grad()
    # TODO: may be buggy here
    def moe_infer_simple(
        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
    ) -> torch.Tensor:
        """
        x: [num_tokens, hidden_size]
        topk_ids, topk_weight: [num_tokens, num_selected_experts]
        """
        outs = torch.zeros_like(x)
        for token_idx in range(topk_ids.size(0)):
            for expert_idx in range(topk_ids.size(1)):
                expert = self.experts[topk_ids[token_idx, expert_idx]]
                outs[token_idx] += (
                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
                )
        return outs

    @torch.no_grad()
    # TODO: may be buggy here
    def moe_infer(self, x, topk_ids, topk_weight):
        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
        cnts.scatter_(1, topk_ids, 1)
        tokens_per_expert = cnts.sum(dim=0)
        idxs = topk_ids.view(-1).argsort()
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        tokens_per_expert = tokens_per_expert.cpu().numpy()

        outputs = []
        start_idx = 0
        for i, num_tokens in enumerate(tokens_per_expert):
            end_idx = start_idx + num_tokens
            if num_tokens == 0:
                continue
            expert = self.experts[i + self.ep_rank * self.experts_per_rank]
            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
            expert_out = expert.forward(tokens_for_this_expert)
            outputs.append(expert_out)
            start_idx = end_idx

        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)

        new_x = torch.empty_like(outs)
        new_x[idxs] = outs
        final_out = (
            new_x.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(new_x.dtype)
        )
        return final_out
324
ktransformers/operators/flashinfer_batch_prefill_wrapper.py
Normal file
@@ -0,0 +1,324 @@
import torch
import flashinfer
import gc
try:
    from flash_attn import flash_attn_with_kvcache
    print("found flash_attn")
except ImportError:
    print("flash_attn not found; the flashinfer unit test needs it. If you are using balance serve, ignore this.")

from typing import Union, Optional

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

setup_seed(998244353)

torch.set_grad_enabled(False)
torch.set_default_dtype(torch.bfloat16)
global_dtype = torch.bfloat16
global_device = torch.device("cuda", 0)
torch.cuda.set_device(0)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
class flashInferAttn():

    float_workspace_buffer = None

    def __init__(self,
                 max_batch_token,
                 max_batch_size,
                 max_pages,
                 device = "cuda:0",
                 kv_layout: str = "NHD",
                 use_cuda_graph: bool = False,
                 ) -> None:
        self.device = device
        self.max_batch_token = max_batch_token
        self.kv_layout = kv_layout
        self.use_cuda_graph = use_cuda_graph
        if flashInferAttn.float_workspace_buffer is None:
            flashInferAttn.float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=device)
        self.qo_indptr_buf = torch.empty((max_batch_size + 1,), dtype=torch.int32, device=device)
        self.paged_kv_indptr_buf = torch.empty((max_batch_size + 1,), dtype=torch.int32, device=device)
        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
        self.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)
        self.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)
        self.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device)

        # TODO: custom mask
        self.custom_mask_buf = None
        self.qk_indptr_buf = None
        self.wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            flashInferAttn.float_workspace_buffer,
            self.kv_layout,
            use_cuda_graph=self.use_cuda_graph,
            qo_indptr_buf=self.qo_indptr_buf,
            paged_kv_indptr_buf=self.paged_kv_indptr_buf,
            paged_kv_indices_buf=self.paged_kv_indices_buf,
            paged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf,
            backend = "fa2",
        )
    def plan(self,
             qo_indptr: torch.Tensor,
             paged_kv_indptr: torch.Tensor,
             paged_kv_indices: torch.Tensor,
             paged_kv_last_page_len: torch.Tensor,
             batch_size_tensor: torch.Tensor,
             num_tokens_tensor: torch.Tensor,
             num_qo_heads: int,
             num_kv_heads: int,
             head_dim: int,
             page_size: int,
             causal: bool = True,
             pos_encoding_mode: str = "NONE",
             q_data_type: Union[str, torch.dtype] = torch.bfloat16,
             kv_data_type: Optional[Union[str, torch.dtype]] = None):

        self.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True)
        self.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True)
        self.page_size = page_size
        self.wrapper.plan(
            qo_indptr,
            paged_kv_indptr,
            paged_kv_indices,
            paged_kv_last_page_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal = causal,
            pos_encoding_mode = pos_encoding_mode,
            q_data_type = q_data_type,
            kv_data_type = kv_data_type
        )
    def calc_batch_indices(self, ragged_size = None):
        if self.use_cuda_graph:
            self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
                self.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token)
        else:
            self.batch_indices, self.positions = flashinfer.get_batch_indices_positions(
                self.wrapper._qo_indptr_buf, flashinfer.get_seq_lens(self.wrapper._paged_kv_indptr_buf, self.wrapper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)
    def forward(self, q, k_cache, v_cache, k, v):
        if self.use_cuda_graph:
            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
            return self.wrapper.run(q, (k_cache, v_cache))
        else:
            flashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.wrapper._paged_kv_indices_buf, self.wrapper._paged_kv_indptr_buf, self.wrapper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)
            return self.wrapper.run(q, (k_cache, v_cache))
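Taken together, the intended per-step lifecycle of this wrapper, as exercised by testCudaGraph below, is roughly the following sketch (argument values are placeholders):

# Sketch of the call sequence; not a standalone benchmark:
attn = flashInferAttn(max_batch_token, max_batch_size, max_pages, use_cuda_graph=True)
attn.plan(q_indptr, kv_indptr, kv_indices, kv_last_page_len,
          batch_size_tensor, num_tokens_tensor,
          num_qo_heads, num_kv_heads, head_dim, page_size)
attn.calc_batch_indices(ragged_size)            # map ragged tokens onto KV pages
out = attn.forward(q, k_cache, v_cache, k, v)   # append new KV pairs, then run attention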
def testCudaGraph():

    # use the max batch to size the buffers
    batch_decode = 8
    prefill_chunk = 48
    past_kv_0 = 4090
    past_kv_1 = 4096
    ragged_size = prefill_chunk + batch_decode
    num_key_value_heads = 8
    head_dim = 128
    num_attention_heads = 64
    page_size = 256
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    attn = flashInferAttn(ragged_size, batch_decode + 1, total_num_pages, use_cuda_graph=True)

    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
    num_tokens_tensor = torch.tensor([prefill_chunk + batch_decode], device=global_device, dtype=torch.int32)  # number of ragged tokens; plan() requires it before the head counts

    k_caches = []
    v_caches = []
    ks = []
    vs = []
    qs = []
    for layer_idx in range(3):
        k_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        v_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        ks.append(torch.randn(ragged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        vs.append(torch.randn(ragged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        qs.append(torch.randn(ragged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))

    # warm up and capture a small batch
    past_kv_0 = 250
    past_kv_1 = 256
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
    q_indptr[0] = 0
    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
    kv_last_page_len[:1 + batch_decode // 2] = int((past_kv_0 - 1) % page_size + 1)
    kv_last_page_len[1 + batch_decode // 2:] = int((past_kv_1 - 1) % page_size + 1)

    print(q_indptr)
    print(kv_indptr)
    print(kv_indices)
    print(kv_last_page_len)
    attn.plan(q_indptr,
              kv_indptr,
              kv_indices,
              kv_last_page_len,
              batch_size_tensor,
              num_tokens_tensor,
              num_attention_heads,
              num_key_value_heads,
              head_dim,
              page_size,
              causal = True,
              pos_encoding_mode="NONE",
              q_data_type=torch.bfloat16)

    attn.calc_batch_indices(ragged_size)
    for layer_idx in range(3):
        attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])
    torch.cuda.synchronize()

    outs = []
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for layer_idx in range(3):
            outs.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]))
    g.replay()

    kv_last_page_len[:1 + batch_decode // 2] = int(past_kv_0)
    kv_last_page_len[1 + batch_decode // 2:] = int(past_kv_1)
    for layer_idx in range(3):
        for i in range(batch_decode + 1):
            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            o_ref_i = flash_attn_with_kvcache(
                qi.unsqueeze(0),
                k_caches[layer_idx],
                v_caches[layer_idx],
                causal=True,
                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
                cache_seqlens=torch.tensor([past_kv_0 if i < 1 + batch_decode // 2 else past_kv_1], device=global_device, dtype=torch.int32)
            )
            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            print(layer_idx, i)
            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
    # run another batch size using the captured CUDA graph
    past_kv_0 = 4090
    past_kv_1 = 4096
    prefill_chunk = 24
    batch_decode = 4
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    batch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)
    num_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)

    q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
    q_indptr[0] = 0
    q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
    kv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq
    kv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)
    kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
    kv_last_page_len[:1 + batch_decode // 2] = int((past_kv_0 - 1) % page_size + 1)
    kv_last_page_len[1 + batch_decode // 2:] = int((past_kv_1 - 1) % page_size + 1)
    attn.plan(q_indptr,
              kv_indptr,
              kv_indices,
              kv_last_page_len,
              batch_size_tensor,
              num_tokens_tensor,
              num_attention_heads,
              num_key_value_heads,
              head_dim,
              page_size,
              causal = True,
              pos_encoding_mode="NONE",
              q_data_type=torch.bfloat16)
    attn.calc_batch_indices(ragged_size)
    g.replay()

    kv_last_page_len[:1 + batch_decode // 2] = int(past_kv_0)
    kv_last_page_len[1 + batch_decode // 2:] = int(past_kv_1)
    for layer_idx in range(3):
        for i in range(batch_decode + 1):
            qi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            o_ref_i = flash_attn_with_kvcache(
                qi.unsqueeze(0),
                k_caches[layer_idx],
                v_caches[layer_idx],
                causal=True,
                block_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),
                cache_seqlens=torch.tensor([past_kv_0 if i < 1 + batch_decode // 2 else past_kv_1], device=global_device, dtype=torch.int32)
            )
            o_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]
            print(layer_idx, i)
            torch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)
def testAttentionFlashInfer():
    batch_decode = 32
    prefill_chunk = 64
    past_kv_0 = 510
    past_kv_1 = 512
    ragged_size = prefill_chunk + batch_decode
    num_key_value_heads = 8
    head_dim = 128
    num_attention_heads = 64
    cases = 1
    page_size = 32
    num_pages_per_seq = (past_kv_1 + page_size - 1) // page_size
    total_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    qs = []
    kvs = []
    q_indptrs = []
    kv_indptrs = []
    kv_indicess = []
    kv_last_page_lens = []
    wrappers = []
    for case_id in range(cases):
        kvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        qs.append(torch.randn(ragged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))
        q_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)
        q_indptr[0] = 0
        q_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)
        q_indptrs.append(q_indptr)
        kv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq)
        kv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32))
        kv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)
        kv_last_page_len[:1 + batch_decode // 2] = int((past_kv_0 - 1) % page_size + 1)
        kv_last_page_len[1 + batch_decode // 2:] = int((past_kv_1 - 1) % page_size + 1)
        kv_last_page_lens.append(kv_last_page_len)
        wrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer,
            "NHD",
            use_cuda_graph=True,
            qo_indptr_buf=q_indptrs[case_id],
            paged_kv_indptr_buf=kv_indptrs[case_id],
            paged_kv_indices_buf=kv_indicess[case_id],
            paged_kv_last_page_len_buf=kv_last_page_lens[case_id],
        ))
        wrappers[case_id].plan(
            q_indptrs[case_id],
            kv_indptrs[case_id],
            kv_indicess[case_id],
            kv_last_page_lens[case_id],
            num_attention_heads,
            num_key_value_heads,
            head_dim,
            page_size,
            causal = True,
            pos_encoding_mode="ROPE_LLAMA",
            q_data_type=torch.bfloat16
        )

    def custom_forward(case_id):
        return wrappers[case_id].run(qs[case_id], kvs[case_id])

    custom_forward(0)

# testCudaGraph()
# pass
@@ -122,3 +122,72 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
            self.e_score_correction_bias = None

class KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        generate_device: str = "cuda",
        generate_op: str | None = "KLinearMarlin",
        prefill_device: str = "cuda",
        prefill_op: str | None = "KLinearMarlin",
        use_quant: bool = False,
        **kwargs,
    ):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
        self.generate_device = generate_device
        self.prefill_device = prefill_device
        self.generate_op = generate_op
        self.prefill_op = prefill_op
        self.is_windows = os.name == 'nt'
        self.use_quant = use_quant
        if not self.is_windows and use_quant:
            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)
            self.gate_linear = KTransformersLinear(key + ".ffn_gate_inp",
                                                   gguf_loader, config, self.gate_linear,  # orig_module
                                                   generate_device, generate_op, prefill_device, prefill_op)
        else:
            self.gate_linear = None

    def forward(self, hidden_states) -> torch.Tensor:
        if self.is_windows:
            return self.orig_module.forward(hidden_states)

        bsz, seq_len, h = hidden_states.shape
        ### compute gating score
        hidden_states = hidden_states.view(-1, h)
        if self.use_quant:
            logits = self.gate_linear.forward(hidden_states)  # quantized gate path; the input here must be hidden_states
        else:
            logits = F.linear(
                hidden_states.type(torch.float32), self.weight.type(torch.float32), None
            )

        return grouped_topk(hidden_states, logits,
                            self.top_k, self.norm_topk_prob,
                            self.n_group, self.topk_group)

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):
        if device is None: device = self.device
        if w is None: w = self.load_weights(device=device)

        if isinstance(w, dict):
            self.weight_type = w["weight_type"]
            self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
            self.orig_module.weight = nn.Parameter(w["weight"])
            self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
        else:
            raise ValueError("Invalid weight type")
        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
        if not self.is_windows and self.use_quant:
            self.gate_linear.load(self.orig_module.weight)

    def unload(self):
        if self.weight is not None:
            self.weight = None
        if self.e_score_correction_bias is not None:
            self.e_score_correction_bias = None
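For context, grouped_topk as called above implements DeepSeek-style group-limited routing: experts are partitioned into n_group groups, only the topk_group best groups stay eligible, and top_k experts are then picked among them. A minimal sketch of the idea (hypothetical helper, not the ktransformers implementation, ignoring the e_score_correction_bias term):

import torch

def grouped_topk_sketch(scores, top_k, norm_topk_prob, n_group, topk_group):
    # scores: [num_tokens, num_experts] routing scores
    t, e = scores.shape
    group_scores = scores.view(t, n_group, e // n_group).amax(dim=-1)  # best expert per group
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_scores.topk(topk_group, dim=-1).indices, 1)
    masked = scores.masked_fill(group_mask.repeat_interleave(e // n_group, dim=1) == 0, float("-inf"))
    weights, ids = masked.topk(top_k, dim=-1)
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids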
@@ -26,6 +26,8 @@ from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
@@ -75,4 +77,89 @@ class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(config.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # the fused kernel updates x and residual in place
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
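A note on the fused path above: flashinfer's fused_add_rmsnorm updates x and residual in place, which is why the method can return the (x, residual) pair without allocating a new output. A rough native equivalent of that step (a sketch only; rmsnorm_native stands in for the plain RMSNorm):

# Sketch of what the in-place fused kernel computes, not the kernel itself:
residual = x + residual          # residual stream carried to the next layer
x = rmsnorm_native(residual)     # normalized activations, returned as x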
class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.hidden_size,
                                  orig_module.variance_epsilon)

    def forward(
        self,
        x: torch.Tensor,
        batch_size_tensor: torch.Tensor = None,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, hidden_size = x.shape
        x = x.view(-1, self.orig_module.hidden_size)
        if batch_size_tensor is None:
            return self.forward_native(x)
        if residual is not None:
            # the fused kernel updates x and residual in place
            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
        out = out.view(bsz, hidden_size)
        return out

    def forward_native(
        self, hidden_states
    ):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
@@ -4,8 +4,7 @@ from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
@@ -18,6 +17,21 @@ class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.hidden_size, orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
    def __init__(self,
                 key: str,
                 gguf_loader: GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
                                  orig_module.intermediate_size)
    def forward(self, x, bsz_tensor):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
        return down_proj
@@ -56,7 +56,7 @@
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
@@ -50,7 +50,7 @@
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
95
ktransformers/optimize/optimize_rules/Qwen2-serve.yaml
Normal file
@@ -0,0 +1,95 @@
- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"        # regular expression
    class: torch.nn.Linear   # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"

# - match:
#     name: "^model\\.layers\\..*$"   # regular expression
#     class: torch.nn.Linear          # only match modules matching name and class simultaneously
#   replace:
#     class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
#     kwargs:
#       generate_device: "cuda"
#       prefill_device: "cuda"
#       generate_op: "VLinearMarlin"
#       prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$"  # regular expression
    class: torch.nn.Linear   # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
  replace:
    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention  # optimized GQA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KQwen2MoeModel"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm
  replace:
    class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP
  replace:
    class: ktransformers.operators.mlp.KQwen2MoeMLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
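The negative lookahead in the Linear match rule above is what keeps the tiny mlp.shared_expert_gate on the default implementation while every other Linear in the layers gets replaced. A quick, self-contained check of the pattern:

import re

pat = re.compile(r"^model\.layers\.(?!.*mlp\.shared_expert_gate).*$")
assert pat.match("model.layers.0.self_attn.q_proj")            # replaced
assert not pat.match("model.layers.0.mlp.shared_expert_gate")  # left as torch.nn.Linear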
95
ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml
Normal file
@@ -0,0 +1,95 @@
- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^lm_head$"        # regular expression
    class: torch.nn.Linear   # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "VLinearMarlin"
      prefill_op: "KLinearTorch"

# - match:
#     name: "^model\\.layers\\..*$"   # regular expression
#     class: torch.nn.Linear          # only match modules matching name and class simultaneously
#   replace:
#     class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
#     kwargs:
#       generate_device: "cuda"
#       prefill_device: "cuda"
#       generate_op: "VLinearMarlin"
#       prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$"  # regular expression
    class: torch.nn.Linear   # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformersLinear  # optimized kernel on quantized data types
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      generate_op: "KLinearMarlin"
      prefill_op: "KLinearTorch"
- match:
    name: "^model\\.layers\\..*\\.mlp$"
    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
  replace:
    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2  # MLP module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
    class: ktransformers.operators.experts.KTransformersExpertsV2  # custom MoE kernel with expert parallelism
    kwargs:
      prefill_device: "cuda"
      prefill_op: "KExpertsTorch"
      generate_device: "cpu"
      generate_op: "KExpertsCPU"
      out_device: "cuda"
  recursive: False  # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention  # optimized GQA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.models.KQwen2MoeModel"
    kwargs:
      per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
- match:
    name: "^model.embed_tokens"
  replace:
    class: "default"
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"

- match:
    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm
  replace:
    class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"

- match:
    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP
  replace:
    class: ktransformers.operators.mlp.KQwen2MoeMLP
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
@@ -20,6 +20,7 @@ class ArgumentParser:
        parser.add_argument(
            "--device", type=str, default=self.cfg.model_device, help="Warning: this parameter is deprecated and ignored"
        )
        parser.add_argument("--architectures", type=str, default=self.cfg.model_name)
        parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
        parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
        parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
@@ -93,6 +94,7 @@ class ArgumentParser:
        parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)
        # parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=False)

        # web config
        parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
@@ -137,7 +139,7 @@ class ArgumentParser:
        self.cfg.server_port = args.port
        self.cfg.user_force_think = args.force_think

        args.gpu_memory_size = args.cache_lens*2*576*61
        args.gpu_memory_size = 4*1024*1024*1024  # TODO: set this to the actual GPU memory size
        self.cfg.gpu_memory_size = args.gpu_memory_size
        free_ports = get_free_ports(3, [args.port])
        args.sched_port = free_ports[0]
@@ -1,5 +1,5 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache
from transformers import (
    AutoTokenizer,
    AutoConfig,
@@ -22,6 +22,9 @@ from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@@ -53,8 +56,10 @@ ktransformer_rules_dir = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
    "DeepseekV3ForCausalLM": ktransformer_rules_dir + "Moonlight-16B-A3B-serve.yaml",
    # "DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
    "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
    "Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
}
@@ -105,7 +110,7 @@ class Engine:
    model_runner: ModelRunner
    sampler: Sampler
    query_manager: QueryManager
    cache: KDeepSeekV3Cache
    cache: KDeepSeekV3Cache | KGQACache
    def __init__(self, args: ConfigArgs = default_args, generated_token_queue: Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):
        self.args = args
@@ -117,17 +122,32 @@ class Engine:
        self.device = self.args.device
        self.sched_client = SchedulerClient(args.sched_port)
        self.updates = []
        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        self.cache = KDeepSeekV3Cache(config, self.args.page_size)
        try:
            config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
        except Exception:
            if args.model_name == "Qwen3Moe":
                config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
            else:
                assert False, f"model {args.model_name} not supported"

        self.gen_queue = generated_token_queue

        with torch.device("meta"):
            if config.architectures[0] == "DeepseekV3ForCausalLM":
                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                self.model = KDeepseekV3ForCausalLM(config, self.cache)
            elif config.architectures[0] == "DeepseekV2ForCausalLM":
                self.cache = KDeepSeekV3Cache(config, self.args.page_size)
                self.model = KDeepseekV2ForCausalLM(config, self.cache)
                # print(self.block_num)
            elif config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
                self.cache = KGQACache(config, self.args.page_size)
                if config.architectures[0] == "Qwen2MoeForCausalLM":
                    self.model = KQwen2MoeForCausalLM(config, self.cache)
                else:
                    self.model = KQwen3MoeForCausalLM(config, self.cache)

        context = zmq.Context()
@@ -176,9 +196,12 @@ class Engine:
        self.block_num = inference_context.k_cache[0].size(1)
        #@TODO add config
        self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
        if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
            self.model.init_wrapper(self.args.use_cuda_graph, self.device, 1024, args.max_batch_size, self.block_num)  # TODO: 1024 is a magic number (max_batch_tokens)
        else:
            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)

        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
        self.sampler = Sampler()
        self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
@@ -231,7 +254,7 @@ class Engine:

        if self.batch is not None:
            self.model_runner.sync()
            print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
            print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000/self.model_runner.model_time:.3f} tokens/s")
            # if self.rank == 0:

        generated_tokens, probs = self.sampling(self.model_runner.output)
@@ -281,4 +281,4 @@ class ForwardBatchOutput:
        self.generated_tokens_num = []
        self.top_ps = []
        self.temperatures = []
        pass
        self.num_batchs = 1
@@ -27,6 +27,8 @@ from ktransformers.server.balance_serve.inference.forward_batch import ForwardBa
from ktransformers.server.config.config import Config
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
@@ -40,11 +42,11 @@ def deduplicate_and_sort(lst):
class ModelRunner:
    """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

    model: KDeepseekV3ForCausalLM
    model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
    input: ForwardBatchInput | list[ForwardBatchInput]
    output: ForwardBatchOutput

    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256):
    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):

        self.stream = torch.cuda.Stream(device=device)
        # commented out for now
@ -58,120 +60,92 @@ class ModelRunner:
|
|||
self.use_cuda_graph = use_cuda_graph
|
||||
self.model_time = 0
|
||||
self.page_size = page_size
|
||||
self.block_num = block_num
|
||||
# GPU timing for model execution
|
||||
self.start_model_event = torch.cuda.Event(enable_timing=True)
|
||||
self.end_model_event = torch.cuda.Event(enable_timing=True)
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
|
||||
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
else:
|
||||
self.graphs = torch.cuda.CUDAGraph()
|
||||
self.page_idx_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
self.page_offset_buf = torch.zeros([self.cuda_graphs], dtype=torch.int32, device = self.device)
|
||||
|
||||
self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]
|
||||
self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]
|
||||
|
||||
self.num_mini_batches = num_mini_batches
|
||||
|
||||
self.max_chunk_size = max_chunk_size
|
||||
|
||||
self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)
|
||||
|
||||
def model_attn_plan(self, batch, cuda_graph_idx=0):
|
||||
if isinstance(self.model, KDeepseekV3ForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
|
||||
head_dim=self.model.config.hidden_size // self.model.config.num_attention_heads,
|
||||
page_size=self.model.cache.page_size, causal=True,
|
||||
q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)
|
||||
else:
|
||||
assert False, "model type not supported"
|
||||
|
||||
|
||||
def warmup(self):
|
||||
|
||||
def capture_graphs(cuda_graph_idx=-1):
|
||||
if cuda_graph_idx != -1:
|
||||
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
|
||||
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
|
||||
else:
|
||||
with torch.cuda.graph(self.graphs, pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
|
||||
self.graph_memory_pool = self.graphs.pool()
|
||||
def capture_graphs(cuda_graph_idx):
|
||||
with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):
|
||||
self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)
|
||||
self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()
|
||||
|
||||
if isinstance(self.cuda_graphs, list):
|
||||
self.input = []
|
||||
self.features_buf = []
|
||||
self.outputs_buf = []
|
||||
self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
|
||||
for i in range(len(self.cuda_graphs)):
|
||||
         self.input = []
         self.features_buf = []
         self.outputs_buf = []
         self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
         self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)
         for i in range(len(self.cuda_graphs)):
             prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0  # @TODO: only supports 2 prefill batches
             self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches=self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i]))
             self.features_buf.append(self.model.batch_embeddings(self.input[i]))
             batch_size = self.input[i].minibatch.q_indptr.size(0) - 1
             num_tokens = self.features_buf[i][0].size(0)
             print("capturing cuda graph", batch_size, num_tokens)
             self.bsz_tensor_buf[0] = batch_size
             self.num_tokens_tensor_buf[0] = num_tokens
-            self.model.flash_infer_attn_plan(self.input[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
+            self.model_attn_plan(self.input[i], i)
             page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)
+            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
+                self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens, batch_size, self.block_num, i)  # TODO: 1024 is a magic number (max_batch_tokens)
             self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])
             self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])
             self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
             self.outputs_buf.append(None)
             torch.cuda.synchronize()
             for warm_up_iters in range(11):
                 with torch.cuda.stream(self.stream):
-                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i])
+                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i)
                 torch.cuda.synchronize()

-            def capture_graphs():
-                with torch.cuda.graph(self.graphs, stream=self.stream):
-                    self.outputs_buf = self.model(self.input, self.features_buf, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf, self.page_offset_buf)
-                    # self.graph_memory_pool = self.graphs.pool()
-            capture_graphs()
+            self.outputs_buf[i].num_batchs = batch_size
+            capture_graphs(i)
             with torch.cuda.stream(self.stream):
-                self.graphs.replay()
+                self.graphs[i].replay()
             self.sync(calc_time=False)
-            print("warmup finished.")
+            print(f"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.")
         else:
             self.input = ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches=self.num_mini_batches)
             self.features_buf = self.model.batch_embeddings(self.input)
             batch_size = self.input.minibatch.q_indptr.size(0) - 1
             num_tokens = self.features_buf[0].size(0)
             self.bsz_tensor_buf = torch.tensor([batch_size], dtype=torch.int32, device=self.device)
             self.num_tokens_tensor_buf = torch.tensor([num_tokens], dtype=torch.int32, device=self.device)
             self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                              num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                              head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                              sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
             self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
             self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
             self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
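[Editor's note] The warmup above follows the standard torch.cuda.CUDAGraph capture pattern: run the model a few times on a side stream so allocator and kernel state settle, capture one graph per batch-size bucket into fixed input/output buffers, then replay after copying fresh data into those buffers. A minimal self-contained sketch of that pattern (the toy module and bucket sizes are assumptions for illustration, not the ModelRunner API):

    import torch

    net = torch.nn.Linear(64, 64).cuda()
    buckets = [1, 4, 16]                      # stand-in for self.cuda_graphs
    static_in, static_out, graphs = {}, {}, {}

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        for n in buckets:
            static_in[n] = torch.zeros(n, 64, device="cuda")
            for _ in range(3):                # warmup iterations before capture
                net(static_in[n])
    torch.cuda.current_stream().wait_stream(s)

    for n in buckets:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):             # record kernels into the graph
            static_out[n] = net(static_in[n])
        graphs[n] = g

    static_in[4].copy_(torch.randn(4, 64, device="cuda"))  # refill the fixed buffer
    graphs[4].replay()                        # result appears in static_out[4]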

     def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):
         with torch.cuda.stream(self.stream):
@@ -189,107 +163,54 @@ class ModelRunner:

-        if isinstance(self.cuda_graphs, list):
-            # cuda graph idx is the smallest i such that self.cuda_graphs[i] >= num_tokens
-            cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
-            if cuda_graph_idx == len(self.cuda_graphs):
-                assert False, "num_tokens is too large"
-        else:
-            cuda_graph_idx = -1
+        # cuda graph idx is the smallest i such that self.cuda_graphs[i] >= num_tokens
+        cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))
+        if not self.use_cuda_graph:
+            cuda_graph_idx = 0
+        # if cuda_graph_idx == len(self.cuda_graphs):
+        #     assert False, "num_tokens is too large"
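[Editor's note] The bucket pick above is a linear scan for the first captured graph size that can hold num_tokens. A tiny standalone illustration with assumed values (not taken from this diff):

    # Hypothetical bucket sizes and a request of 37 tokens.
    cuda_graphs = [1, 2, 4, 8, 16, 32, 64, 128]
    num_tokens = 37
    # Smallest i with cuda_graphs[i] >= num_tokens -> index 6 (size 64).
    idx = next((i for i, n in enumerate(cuda_graphs) if n >= num_tokens), len(cuda_graphs))
    assert idx == 6 and cuda_graphs[idx] == 64
    # The same bucket size also drives the warmup-time prefill split: with assumed
    # Config().max_decode_batch_size = 32 and max_prefill_batch_size = 2, the
    # 64-token bucket gets prefill_query_length = (64 - 32) // 2 = 16.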
         if self.use_cuda_graph:
-            if cuda_graph_idx != -1:
-                self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
-            else:
-                self.input.fill(batch, query_manager, self.page_size)
+            self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)
         else:
-            self.input = ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)
+            self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)]

-        if cuda_graph_idx != -1 and self.use_cuda_graph:
-            self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)
-        else:
-            self.features = self.model.batch_embeddings(self.input, device=self.device)
+        self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)

         self.bsz_tensor_buf.copy_(batch_size)
         self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))

         if self.use_cuda_graph:
-            if cuda_graph_idx != -1:
-                self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
-            else:
-                self.features_buf[0].copy_(self.features[0], non_blocking=True)
+            self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)
         """
         if num_tokens_0 > 64:
             padded_num_tokens_0 = pad_num_tokens(num_tokens_0)
             self.features_buf[0][num_tokens_0:padded_num_tokens_0] = 0
         """
         # self.input.forward_minibatchs[0].print()
         # print([[hash(k[i].float().cpu().numpy().tobytes()) for i in self.input.forward_minibatchs[0].kv_indices] for k in self.model.cache.k_caches])
         # print(f"overlap: {overlap}, is_compute_bound: {is_compute_bound}")

         # self.model.flash_infer_attn_plan(self.input, self.bsz_tensors, self.num_tokens_tensors)

         """
         self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx)
         self.start_model_event.record(self.stream)
         page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
         if self.use_cuda_graph:
             print("before replay features_buf", self.features_buf[0])
             print("features_buf addr", self.features_buf[0].data_ptr())
             self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
             self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
             self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
             self.replay(cuda_graph_idx)
             self.output = ForwardBatchOutput()
             self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
             self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
             self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
         else:
             print("before run features", self.features[0])
         """

         if cuda_graph_idx != -1 and self.use_cuda_graph:
             self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,
                                              num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
                                              head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
                                              sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
             self.start_model_event.record(self.stream)
             page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)
             if self.use_cuda_graph:
                 self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
                 self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
                 self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
                 self.replay(cuda_graph_idx)
                 self.output = ForwardBatchOutput()
                 self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                 self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
                 self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())
             else:
                 self.output = self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
                 self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]
                 self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)
                 self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)
             self.end_model_event.record(self.stream)
-        else:
-            self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
-                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
-                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
-                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
-            self.start_model_event.record(self.stream)
-            page_idx, page_offset = self.model.cache.get_page_table(self.input.minibatch.position_ids, self.input.minibatch.q_indptr, self.input.minibatch.kv_indptr, self.input.minibatch.kv_indices, self.num_tokens_tensor_buf)
-            if self.use_cuda_graph:
-                self.page_idx_buf[:num_tokens].copy_(page_idx[:num_tokens])
-                self.page_offset_buf[:num_tokens].copy_(page_offset[:num_tokens])
-                self.page_idx_buf[num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size - 1)
-                self.replay(cuda_graph_idx)
-                self.output = ForwardBatchOutput()
-                self.output.top_ps.append(self.input.minibatch.top_ps)
-                self.output.temperatures.append(self.input.minibatch.temperatures)
-                self.output.logits.append(self.outputs_buf.logits[0][self.input.minibatch.logits_start].clone())
-            else:
-                self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)
-                self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]
-                self.output.top_ps.append(self.input.minibatch.top_ps)
-                self.output.temperatures.append(self.input.minibatch.temperatures)
-            self.end_model_event.record(self.stream)

         if not self.use_cuda_graph:
             self.output.num_batchs = self.input.batch_size
         else:
             self.output.num_batchs = self.input[cuda_graph_idx].batch_size

     def replay(self, cuda_graph_idx=-1):
@@ -10,7 +10,7 @@ current_file_path = os.path.abspath(__file__)
 # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 import pickle
 import argparse
-from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings
+from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe


@@ -209,5 +209,10 @@ if __name__ == '__main__':
     args = parser.parse_args()
     with open(args.config, "rb") as f:
         main_args = pickle.load(f)
-    settings = create_sched_settings(main_args)
+    if main_args.architectures == "Qwen2MoeForCausalLM":
+        settings = create_sched_settings_qwen2moe(main_args)
+    elif main_args.architectures == "Qwen3MoeForCausalLM":
+        settings = create_sched_settings_qwen3moe(main_args)
+    else:
+        settings = create_sched_settings(main_args)
     start_server(settings, main_args)
@@ -11,6 +11,8 @@ from time import sleep
 import sched_ext
 from transformers import AutoConfig
+
+from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig

 def create_sched_settings(args):
     default_sample_options = sched_ext.SampleOptions()
     model_name = os.path.basename(os.path.normpath(args.model_dir))
@@ -64,7 +66,111 @@ def create_sched_settings(args):
     return settings


+def create_sched_settings_qwen2moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads  # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
+
+
+def create_sched_settings_qwen3moe(args):
+    default_sample_options = sched_ext.SampleOptions()
+    model_name = os.path.basename(os.path.normpath(args.model_dir))
+    input_model_settings = sched_ext.ModelSettings()
+    input_model_settings.model_path = args.model_dir
+    input_model_settings.params_count = int(0)
+    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
+    input_model_settings.layer_count = model_config.num_hidden_layers
+    input_model_settings.num_k_heads = model_config.num_key_value_heads  # model_config["num_key_value_heads"]
+    input_model_settings.k_head_dim = 128
+    input_model_settings.bytes_per_params = 2
+    input_model_settings.bytes_per_kv_cache_element = 2
+    settings = sched_ext.Settings()
+    settings.model_name = model_name
+    settings.quant_type = "BF16"
+    settings.model_settings = input_model_settings
+    settings.page_size = args.page_size
+    settings.gpu_device_count = 1  # tp
+    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]
+    # settings.gpu_memory_size = args.cache_lens*576*2
+    settings.gpu_memory_size = args.gpu_memory_size
+    settings.memory_utilization_percentage = args.utilization_percentage
+    max_batch_size = args.max_batch_size
+    chunk_size = args.chunk_size
+
+    max_decode_batch_size = max_batch_size - 2
+
+    settings.max_batch_size = max_batch_size
+    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2
+    settings.sample_options = default_sample_options
+    settings.sched_metrics_port = args.sched_metrics_port
+    settings.gpu_only = args.memory_gpu_only
+    settings.use_self_defined_head_dim = False
+    settings.self_defined_head_dim = 576
+    settings.full_kv_cache_on_each_gpu = True
+    settings.k_cache_on = True
+    settings.v_cache_on = True
+
+    settings.kvc2_root_path = '/mnt/data/persist-kvc'
+    settings.kvc2_config_path = args.kvc2_config_dir
+    settings.memory_pool_size_GB = args.cpu_memory_size_GB
+    settings.evict_count = 40
+    settings.kvc2_metrics_port = args.kvc2_metrics_port
+    settings.load_from_disk = False
+    settings.save_to_disk = True
+
+    settings.strategy_name = args.sched_strategy
+
+    settings.auto_derive()
+    return settings
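[Editor's note] create_sched_settings_qwen2moe and create_sched_settings_qwen3moe above are identical except for the config loader (AutoConfig vs. Qwen3MoeConfig). A follow-up could collapse them into one parameterized helper; a possible sketch (not part of this commit, and showing only the shared ModelSettings prologue):

    from functools import partial

    def _moe_model_settings(args, load_config):
        # Shared prologue of the two builders; `load_config` carries the only
        # difference between them.
        ms = sched_ext.ModelSettings()
        ms.model_path = args.model_dir
        ms.params_count = 0
        model_config = load_config(args.model_dir, trust_remote_code=True)
        ms.layer_count = model_config.num_hidden_layers
        ms.num_k_heads = model_config.num_key_value_heads
        ms.k_head_dim = 128
        ms.bytes_per_params = 2
        ms.bytes_per_kv_cache_element = 2
        return ms

    qwen2moe_model_settings = partial(_moe_model_settings, load_config=AutoConfig.from_pretrained)
    qwen3moe_model_settings = partial(_moe_model_settings, load_config=Qwen3MoeConfig.from_pretrained)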
@@ -100,6 +100,7 @@ class Config(metaclass=Singleton):
         # to make sure it consistent with previous version
         self.model_path: str = self.model_dir
         self.model_name: str = self.model.get("name", "")
+        self.architectures: str = self.model.get("name", "")
         self.model_device: str = self.model.get("device", "cuda:0")
         self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
         self.use_cuda_graph = self.model.get("use_cuda_graph", True)
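[Editor's note] The new `architectures` field is read from the same "name" key as `model_name`, which suggests the configured model name must literally be the Hugging Face architecture string for the dispatch added in balance_serve.py above to take effect. A sketch with an assumed config dict:

    model = {"name": "Qwen3MoeForCausalLM", "device": "cuda:0"}
    architectures = model.get("name", "")
    assert architectures == "Qwen3MoeForCausalLM"  # matches the elif branch above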
@@ -1,5 +1,5 @@
 torch >= 2.3.0
-transformers == 4.43.2
+transformers == 4.51.3
 fastapi >= 0.111.0
 langchain >= 0.2.0
 blessed >= 1.20.0
@@ -912,6 +912,9 @@ def translate_name_to_gguf(name):
     name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
     name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
     name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
+
+    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
+    name = name.replace(".self_attn.k_norm", ".attn_k_norm")

     name = name.replace(".shared_expert.", ".shared_experts.")
     name = name.replace(".shared_expert_", ".shared_experts_")
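[Editor's note] Illustration of the two rules added above. The other replace rules in translate_name_to_gguf still apply to the full tensor name; this shows just the new Qwen3 q/k-norm mapping:

    name = "model.layers.0.self_attn.q_norm.weight"
    name = name.replace(".self_attn.q_norm", ".attn_q_norm")
    assert name == "model.layers.0.attn_q_norm.weight"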
@@ -922,17 +925,23 @@ def translate_name_to_gguf(name):
     name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
     name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
     name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")

     name = name.replace(".mlp.experts", "")
     name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
     name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
     name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")

     name = name.replace(".block_sparse_moe.gate.", ".ffn_gate_inp.")
     name = name.replace(".block_sparse_moe.experts", "")

+    name = name.replace(".feed_forward.experts", "")
+    name = name.replace(".feed_forward.router", ".ffn_gate_inp")
+    name = name.replace(".feed_forward.shared_experts.down_proj", ".ffn_down_shexp")
+    name = name.replace(".feed_forward.shared_experts.gate_proj", ".ffn_gate_shexp")
+    name = name.replace(".feed_forward.shared_experts.up_proj", ".ffn_up_shexp")

     return name


 if __name__ == '__main__':
     gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
     loader = GGUFLoader(gguf_path)
@ -16,7 +16,7 @@ dynamic = ["version"]
|
|||
|
||||
dependencies = [
|
||||
"torch >= 2.3.0",
|
||||
"transformers == 4.43.2",
|
||||
"transformers == 4.51.3",
|
||||
"fastapi >= 0.111.0",
|
||||
"uvicorn >= 0.30.1",
|
||||
"langchain >= 0.2.0",
|
||||
|
|
|
@@ -1,5 +1,5 @@
 fire
-transformers==4.43.2
+transformers==4.51.3
 numpy
 torch>=2.3.0
 packaging
third_party/custom_flashinfer (vendored submodule)
@@ -1 +1 @@
-Subproject commit fd94393fb5b8ba8bae9c0bd6ab1c2a429d81ac76
+Subproject commit af4259e8a33f095b419d1fd1733a50b22fc84c49