mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-10 15:29:39 +00:00
support smt and glm4
This commit is contained in:
parent
1677e90092
commit
b66d96db97
18 changed files with 3519 additions and 16 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -27,3 +27,7 @@ mmlu_result*
|
|||
ktransformers/ktransformers_ext/cuda_musa/
|
||||
test_prompt.txt
|
||||
csrc/demo
|
||||
build*
|
||||
CMakeFiles/
|
||||
kvc2/
|
||||
sched/
|
1
ktransformers/ktransformers
Symbolic link
1
ktransformers/ktransformers
Symbolic link
|
@ -0,0 +1 @@
|
|||
/home/djw/py311_717/ktransformers/ktransformers
|
242
ktransformers/models/configuration_glm4_moe.py
Normal file
242
ktransformers/models/configuration_glm4_moe.py
Normal file
|
@ -0,0 +1,242 @@
|
|||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_glm4_moe.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
|
||||
|
||||
class Glm4MoeConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a
|
||||
Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B).
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151552):
|
||||
Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Glm4MoeModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 10944):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 46):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 96):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
|
||||
The factor of the partial rotary position.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 131072):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1408):
|
||||
Intermediate size of the routed expert.
|
||||
num_experts_per_tok (`int`, *optional*, defaults to 8):
|
||||
number of experts per token.
|
||||
n_shared_experts (`int`, *optional*, defaults to 1):
|
||||
Number of shared experts.
|
||||
n_routed_experts (`int`, *optional*, defaults to 128):
|
||||
Number of routed experts.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor or routed experts.
|
||||
n_group (`int`, *optional*, defaults to 1):
|
||||
Number of groups for routed experts.
|
||||
topk_group (`int`, *optional*, defaults to 1):
|
||||
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 1):
|
||||
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||
\--k dense layers--/
|
||||
norm_topk_prob (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the topk probabilities.
|
||||
use_qk_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use query-key normalization in the attention
|
||||
```python
|
||||
>>> from transformers import Glm4MoeModel, Glm4MoeConfig
|
||||
|
||||
>>> # Initializing a Glm4Moe style configuration
|
||||
>>> configuration = Glm4MoeConfig()
|
||||
|
||||
>>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration
|
||||
>>> model = Glm4MoeModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "glm4_moe"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Glm4Moe`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.experts.*.gate_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.up_proj": "colwise",
|
||||
"layers.*.mlp.experts.*.down_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151552,
|
||||
hidden_size=4096,
|
||||
intermediate_size=10944,
|
||||
num_hidden_layers=46,
|
||||
num_attention_heads=96,
|
||||
partial_rotary_factor=0.5,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=131072,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
moe_intermediate_size=1408,
|
||||
num_experts_per_tok=8,
|
||||
n_shared_experts=1,
|
||||
n_routed_experts=128,
|
||||
routed_scaling_factor=1.0,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
first_k_dense_replace=1,
|
||||
norm_topk_prob=True,
|
||||
use_qk_norm=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
# MoE arguments
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Glm4MoeConfig"]
|
177
ktransformers/models/configuration_smallthinker.py
Normal file
177
ktransformers/models/configuration_smallthinker.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
# coding=utf-8
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
class SmallthinkerConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`SmallthinkerModel`].
|
||||
It is used to instantiate a Smallthinker model according to the specified arguments, defining the model architecture.
|
||||
The default values for each of the parameters are the same as the ones used in the original Smallthinker 4B model.
|
||||
|
||||
General configs:
|
||||
- model_type: "smallthinker"
|
||||
- model_name
|
||||
- num_hidden_layers
|
||||
- hidden_size
|
||||
|
||||
Tokenizer configs:
|
||||
- pad_token_id
|
||||
- bos_token_id
|
||||
- eos_token_id
|
||||
|
||||
Embedding configs:
|
||||
- vocab_size
|
||||
|
||||
RMSNorm configs:
|
||||
- rms_norm_eps
|
||||
|
||||
Attention configs:
|
||||
- num_attention_heads
|
||||
- num_key_value_heads
|
||||
- head_dim
|
||||
- use_cache
|
||||
- use_qk_norm
|
||||
- rope_layout: array of 0 or 1s, 0 for nope, 1 for rope
|
||||
- rope_theta
|
||||
- max_position_embeddings
|
||||
- sliding_window_layout: array of 0 or 1s, 0 for normal attention, 1 for SWA
|
||||
- sliding_window_size
|
||||
|
||||
General FFN configs:
|
||||
- moe_layer_layout: array of 0 or 1s, 0 for dense layer, 1 for MoE layer
|
||||
|
||||
Dense FFN configs:
|
||||
- dense_ffn_hidden_size
|
||||
|
||||
MoE FFN configs:
|
||||
- moe_num_primary_experts
|
||||
- moe_shared_primary_experts
|
||||
- moe_ffn_hidden_size
|
||||
- moe_enable_early_router: Use attention output as router input if true
|
||||
- moe_primary_router_use_sigmoid: Use normalized sigmoid
|
||||
- moe_num_active_primary_experts
|
||||
- moe_enable_secondary_experts
|
||||
- moe_num_secondary_experts
|
||||
- moe_secondary_expert_size
|
||||
|
||||
LM Head configs:
|
||||
- tie_word_embeddings
|
||||
|
||||
Visibility configs:
|
||||
- profile_sparsity
|
||||
|
||||
Other configs:
|
||||
- initializer_range
|
||||
"""
|
||||
def __init__(self,
|
||||
model_type = "smallthinker",
|
||||
model_name="smallthinker_4b_base",
|
||||
num_hidden_layers=32,
|
||||
hidden_size=1536,
|
||||
pad_token_id=None,
|
||||
bos_token_id=151643,
|
||||
eos_token_id=[151643,151645],
|
||||
vocab_size=151936,
|
||||
rms_norm_eps=1e-6,
|
||||
num_attention_heads=12,
|
||||
num_key_value_heads=2,
|
||||
head_dim=128,
|
||||
use_cache=True,
|
||||
use_qk_norm=False,
|
||||
rope_layout=[1]*32,
|
||||
rope_theta=1e6,
|
||||
max_position_embeddings=4096 * 32,
|
||||
sliding_window_layout=[0]*32,
|
||||
sliding_window_size=4096,
|
||||
moe_layer_layout=[1]*32,
|
||||
dense_ffn_hidden_size=4096,
|
||||
moe_num_primary_experts=32,
|
||||
moe_shared_primary_experts=0,
|
||||
moe_ffn_hidden_size=768,
|
||||
moe_enable_early_router=True,
|
||||
moe_primary_router_apply_softmax=False,
|
||||
moe_num_active_primary_experts=4,
|
||||
moe_enable_secondary_experts=False,
|
||||
moe_num_secondary_experts=0,
|
||||
moe_secondary_expert_size=0,
|
||||
tie_word_embeddings=True,
|
||||
initializer_range=0.02,
|
||||
**kwargs,
|
||||
):
|
||||
# Configuration sanitizers
|
||||
assert num_attention_heads % num_key_value_heads == 0, "[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads"
|
||||
assert len(rope_layout) == num_hidden_layers, "[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers"
|
||||
assert len(sliding_window_layout) == num_hidden_layers, "[Smallthinker config sanitizer] sliding_window_layout must have the same length as num_hidden_layers"
|
||||
assert len(moe_layer_layout) == num_hidden_layers, "[Smallthinker config sanitizer] moe_layer_layout must have the same length as num_hidden_layers"
|
||||
|
||||
if any(moe_layer_layout):
|
||||
assert moe_num_primary_experts != 0, "[Smallthinker config sanitizer] moe_num_primary_experts must be set non-zero if there is any MoE layer"
|
||||
assert moe_ffn_hidden_size != 0, "[Smallthinker config sanitizer] moe_ffn_hidden_size must be set non-zero if there is any MoE layer"
|
||||
assert moe_num_active_primary_experts != 0, "[Smallthinker config sanitizer] moe_num_active_primary_experts must be set non-zero if there is any MoE layer"
|
||||
if moe_enable_secondary_experts:
|
||||
assert moe_num_secondary_experts != 0, "[Smallthinker config sanitizer] moe_num_secondary_experts must be set non-zero if moe_enable_secondary_experts is True"
|
||||
assert moe_secondary_expert_size != 0, "[Smallthinker config sanitizer] moe_secondary_expert_size must be set non-zero if moe_enable_secondary_experts is True"
|
||||
assert moe_num_secondary_experts * moe_secondary_expert_size == moe_ffn_hidden_size, "[Smallthinker config sanitizer] moe_num_secondary_experts * moe_secondary_expert_size must equal moe_ffn_hidden_size"
|
||||
|
||||
if not all(moe_layer_layout):
|
||||
assert dense_ffn_hidden_size != 0, "[Smallthinker config sanitizer] dense_ffn_hidden_size must be set non-zero if there is any dense FFN layer"
|
||||
|
||||
# General configs
|
||||
self.model_type = model_type
|
||||
self.model_name = model_name
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Tokenizer configs
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
# Embedding configs
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# RMSNorm configs
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
# Attention configs
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.use_cache = use_cache
|
||||
self.use_qk_norm = use_qk_norm
|
||||
self.rope_layout = rope_layout
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.sliding_window_layout = sliding_window_layout
|
||||
self.sliding_window_size = sliding_window_size
|
||||
|
||||
# General FFN configs
|
||||
self.moe_layer_layout = moe_layer_layout
|
||||
|
||||
# Dense FFN configs
|
||||
self.dense_ffn_hidden_size = dense_ffn_hidden_size
|
||||
|
||||
# MoE FFN configs
|
||||
self.moe_num_primary_experts = moe_num_primary_experts
|
||||
self.moe_shared_primary_experts = moe_shared_primary_experts
|
||||
self.moe_ffn_hidden_size = moe_ffn_hidden_size
|
||||
self.moe_enable_early_router = moe_enable_early_router
|
||||
self.moe_primary_router_apply_softmax = moe_primary_router_apply_softmax
|
||||
self.moe_num_active_primary_experts = moe_num_active_primary_experts
|
||||
self.moe_enable_secondary_experts = moe_enable_secondary_experts
|
||||
self.moe_num_secondary_experts = moe_num_secondary_experts
|
||||
self.moe_secondary_expert_size = moe_secondary_expert_size
|
||||
|
||||
# Logging configs
|
||||
# self.output_router_logits = False
|
||||
|
||||
# Other configs
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
self._attn_implementation = "eager" # SDPA is not allowed for now
|
||||
|
||||
# if self._attn_implementation != "flash_attention_2":
|
||||
# raise NotImplementedError("SDPA impl is buggy for now. NEVER TRY TO USE IT.")
|
||||
|
||||
__all__ = ["SmallthinkerConfig"]
|
140
ktransformers/models/custom_modeling_glm4_moe.py
Normal file
140
ktransformers/models/custom_modeling_glm4_moe.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Date: 2024-11-06 10:05:11
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-11-13 07:50:51
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
|
||||
from ktransformers.models.custom_cache import KGQACache
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeModel, Glm4MoePreTrainedModel
|
||||
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
|
||||
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
import flashinfer
|
||||
|
||||
class KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):
|
||||
|
||||
cache: KGQACache
|
||||
use_cuda_graph = False
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4MoeConfig,
|
||||
cache,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
self.model = Glm4MoeModel(config)
|
||||
self.config = config
|
||||
self.cache = cache
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.attn = [None] * 100
|
||||
|
||||
def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
|
||||
self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
|
||||
|
||||
|
||||
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
|
||||
features = []
|
||||
for i in range(batch.batch_size):
|
||||
tokens = batch.minibatch.tokens.contiguous()
|
||||
feature = (
|
||||
self.model.embed_tokens(tokens.to(torch.device('cpu')))
|
||||
.to(torch.bfloat16)
|
||||
.to(device=device)
|
||||
)
|
||||
features.append(feature)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: ForwardBatchInput | None = None,
|
||||
features: List[torch.Tensor] | None = None,
|
||||
bsz_tensors: torch.Tensor | None = None,
|
||||
num_tokens_tensors: torch.Tensor | None = None,
|
||||
page_idx: torch.Tensor | None = None,
|
||||
page_offset: torch.Tensor | None = None,
|
||||
cuda_graph_idx: int | None = 0
|
||||
) -> ForwardBatchOutput:
|
||||
current_stream = torch.cuda.current_stream()
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
|
||||
|
||||
hidden_states = features[0]
|
||||
self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])
|
||||
|
||||
freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
|
||||
|
||||
with torch.cuda.stream(current_stream):
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
for i, decode_layer in enumerate(self.model.layers):
|
||||
if self.model.transfer_map is not None and i in self.model.transfer_map:
|
||||
prev_stream = torch.cuda.current_stream()
|
||||
cur_device = self.model.transfer_map[i]
|
||||
if cur_device not in self.model.stream_device_map:
|
||||
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
||||
torch.cuda.set_device(cur_device)
|
||||
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
|
||||
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
|
||||
hidden_states = hidden_states.to(
|
||||
self.model.transfer_map[i], non_blocking=True
|
||||
)
|
||||
|
||||
batch.minibatch.position_ids = (
|
||||
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
|
||||
if batch.minibatch.position_ids is not None
|
||||
else None
|
||||
)
|
||||
router_input = hidden_states.clone()
|
||||
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
|
||||
freqs_cis,
|
||||
wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
|
||||
position_ids=batch.minibatch.position_ids
|
||||
)
|
||||
|
||||
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
if i < self.model.config.first_k_dense_replace:
|
||||
hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
|
||||
else:
|
||||
hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
# hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
with torch.cuda.stream(current_stream):
|
||||
local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
|
||||
forward_batch_output.logits.append(local_logit)
|
||||
|
||||
return forward_batch_output
|
||||
|
||||
|
||||
|
||||
def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
page_size: int,
|
||||
causal: bool,
|
||||
q_data_type: torch.dtype,
|
||||
kv_data_type: torch.dtype,
|
||||
cuda_graph_idx: int = 0
|
||||
):
|
||||
minibatch = batch.minibatch
|
||||
self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
|
||||
minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
|
||||
|
140
ktransformers/models/custom_modeling_smallthinker.py
Normal file
140
ktransformers/models/custom_modeling_smallthinker.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Date: 2024-11-06 10:05:11
|
||||
LastEditors: djw
|
||||
LastEditTime: 2024-11-13 07:50:51
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
|
||||
from ktransformers.models.custom_cache import KGQACache
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerModel, SmallthinkerPreTrainedModel
|
||||
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
|
||||
from ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
import flashinfer
|
||||
|
||||
class KSmallthinkerForCausalLM(SmallthinkerPreTrainedModel):
|
||||
|
||||
cache: KGQACache
|
||||
use_cuda_graph = False
|
||||
def __init__(
|
||||
self,
|
||||
config: SmallthinkerConfig,
|
||||
cache,
|
||||
):
|
||||
|
||||
super().__init__(config)
|
||||
self.model = SmallthinkerModel(config)
|
||||
self.config = config
|
||||
self.cache = cache
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.attn = [None] * 100
|
||||
|
||||
def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):
|
||||
self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)
|
||||
|
||||
|
||||
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
|
||||
features = []
|
||||
for i in range(batch.batch_size):
|
||||
tokens = batch.minibatch.tokens.contiguous()
|
||||
feature = (
|
||||
self.model.embed_tokens(tokens.to(torch.device('cpu')))
|
||||
.to(torch.bfloat16)
|
||||
.to(device=device)
|
||||
)
|
||||
features.append(feature)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: ForwardBatchInput | None = None,
|
||||
features: List[torch.Tensor] | None = None,
|
||||
bsz_tensors: torch.Tensor | None = None,
|
||||
num_tokens_tensors: torch.Tensor | None = None,
|
||||
page_idx: torch.Tensor | None = None,
|
||||
page_offset: torch.Tensor | None = None,
|
||||
cuda_graph_idx: int | None = 0
|
||||
) -> ForwardBatchOutput:
|
||||
current_stream = torch.cuda.current_stream()
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
|
||||
|
||||
hidden_states = features[0]
|
||||
self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])
|
||||
|
||||
freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))
|
||||
|
||||
with torch.cuda.stream(current_stream):
|
||||
residual = torch.zeros_like(hidden_states)
|
||||
for i, decode_layer in enumerate(self.model.layers):
|
||||
if self.model.transfer_map is not None and i in self.model.transfer_map:
|
||||
prev_stream = torch.cuda.current_stream()
|
||||
cur_device = self.model.transfer_map[i]
|
||||
if cur_device not in self.model.stream_device_map:
|
||||
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
|
||||
torch.cuda.set_device(cur_device)
|
||||
self.model.stream_device_map[cur_device].wait_stream(prev_stream)
|
||||
torch.cuda.set_stream(self.model.stream_device_map[cur_device])
|
||||
hidden_states = hidden_states.to(
|
||||
self.model.transfer_map[i], non_blocking=True
|
||||
)
|
||||
|
||||
batch.minibatch.position_ids = (
|
||||
batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)
|
||||
if batch.minibatch.position_ids is not None
|
||||
else None
|
||||
)
|
||||
router_input = hidden_states.clone()
|
||||
hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
hidden_states = decode_layer.self_attn(hidden_states, self.cache,
|
||||
freqs_cis if self.model.rope_layout[i] else None,
|
||||
wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors,
|
||||
position_ids=batch.minibatch.position_ids
|
||||
)
|
||||
|
||||
hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)
|
||||
if not self.config.moe_layer_layout[i]:
|
||||
hidden_states = decode_layer.feed_forward(router_input, hidden_states, num_tokens_tensors)
|
||||
else:
|
||||
hidden_states = decode_layer.feed_forward(hidden_states, num_tokens_tensors, cuda_graph_idx)
|
||||
# hidden_states = hidden_states.squeeze(0)
|
||||
|
||||
forward_batch_output = ForwardBatchOutput()
|
||||
with torch.cuda.stream(current_stream):
|
||||
local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)
|
||||
forward_batch_output.logits.append(local_logit)
|
||||
|
||||
return forward_batch_output
|
||||
|
||||
|
||||
|
||||
def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
page_size: int,
|
||||
causal: bool,
|
||||
q_data_type: torch.dtype,
|
||||
kv_data_type: torch.dtype,
|
||||
cuda_graph_idx: int = 0
|
||||
):
|
||||
minibatch = batch.minibatch
|
||||
self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
|
||||
minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)
|
||||
|
650
ktransformers/models/modeling_glm4_moe.py
Normal file
650
ktransformers/models/modeling_glm4_moe.py
Normal file
|
@ -0,0 +1,650 @@
|
|||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_glm4_moe.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations import use_kernel_forward_from_hub
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
|
||||
from ...utils.generic import check_model_inputs
|
||||
from .configuration_glm4_moe import Glm4MoeConfig
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
|
||||
# Keep half or full tensor for later concatenation
|
||||
rotary_dim = cos.shape[-1]
|
||||
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
|
||||
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
|
||||
|
||||
# Apply rotary embeddings on the first half or full tensor
|
||||
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
|
||||
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
|
||||
|
||||
# Concatenate back to full shape
|
||||
q_embed = torch.cat([q_embed, q_pass], dim=-1)
|
||||
k_embed = torch.cat([k_embed, k_pass], dim=-1)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Glm4MoeAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
if self.use_qk_norm: # main diff from Llama
|
||||
query_states = self.q_norm(query_states)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; position_ids needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Module):
|
||||
def __init__(self, config, hidden_size=None, intermediate_size=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
|
||||
self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size
|
||||
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
class Glm4MoeTopkRouter(nn.Module):
|
||||
def __init__(self, config: Glm4MoeConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.n_routed_experts = config.n_routed_experts
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_group = config.n_group
|
||||
self.topk_group = config.topk_group
|
||||
self.norm_topk_prob = config.norm_topk_prob
|
||||
|
||||
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
|
||||
self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts), dtype=torch.float32))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_topk_indices(self, scores):
|
||||
scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
|
||||
group_scores = (
|
||||
scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.topk(2, dim=-1)[0]
|
||||
.sum(dim=-1)
|
||||
)
|
||||
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
|
||||
group_mask = torch.zeros_like(group_scores)
|
||||
group_mask.scatter_(1, group_idx, 1)
|
||||
score_mask = (
|
||||
group_mask.unsqueeze(-1)
|
||||
.expand(-1, self.n_group, self.n_routed_experts // self.n_group)
|
||||
.reshape(-1, self.n_routed_experts)
|
||||
)
|
||||
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
|
||||
topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
|
||||
return topk_indices
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.view(-1, self.config.hidden_size)
|
||||
router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
|
||||
scores = router_logits.sigmoid()
|
||||
topk_indices = self.get_topk_indices(scores)
|
||||
topk_weights = scores.gather(1, topk_indices)
|
||||
if self.norm_topk_prob:
|
||||
denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
|
||||
topk_weights /= denominator
|
||||
topk_weights = topk_weights * self.routed_scaling_factor
|
||||
return topk_indices, topk_weights
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class Glm4MoeRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Glm4MoeRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Glm4MoeMoE(nn.Module):
|
||||
"""
|
||||
A mixed expert module containing shared experts.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.experts = nn.ModuleList(
|
||||
[
|
||||
Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)
|
||||
for _ in range(config.n_routed_experts)
|
||||
]
|
||||
)
|
||||
self.gate = Glm4MoeTopkRouter(config)
|
||||
self.shared_experts = Glm4MoeMLP(
|
||||
config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
|
||||
)
|
||||
|
||||
def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
|
||||
r"""
|
||||
CALL FOR CONTRIBUTION! I don't have time to optimise this right now, but expert weights need to be fused
|
||||
to not have to do a loop here (deepseek has 256 experts soooo yeah).
|
||||
"""
|
||||
final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
|
||||
expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
|
||||
expert_mask = expert_mask.permute(2, 0, 1)
|
||||
|
||||
for expert_idx in range(len(self.experts)):
|
||||
expert = self.experts[expert_idx]
|
||||
mask = expert_mask[expert_idx]
|
||||
token_indices, weight_indices = torch.where(mask)
|
||||
|
||||
if token_indices.numel() > 0:
|
||||
expert_weights = topk_weights[token_indices, weight_indices]
|
||||
expert_input = hidden_states[token_indices]
|
||||
expert_output = expert(expert_input)
|
||||
weighted_output = expert_output * expert_weights.unsqueeze(-1)
|
||||
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
||||
|
||||
# in original deepseek, the output of the experts are gathered once we leave this module
|
||||
# thus the moe module is itelsf an IsolatedParallel module
|
||||
# and all expert are "local" meaning we shard but we don't gather
|
||||
return final_hidden_states.type(hidden_states.dtype)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residuals = hidden_states
|
||||
orig_shape = hidden_states.shape
|
||||
topk_indices, topk_weights = self.gate(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||
hidden_states = hidden_states + self.shared_experts(residuals)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: Glm4MoeConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
if layer_idx >= config.first_k_dense_replace:
|
||||
self.mlp = Glm4MoeMoE(config)
|
||||
else:
|
||||
self.mlp = Glm4MoeMLP(config)
|
||||
|
||||
self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Glm4MoePreTrainedModel(PreTrainedModel):
|
||||
config: Glm4MoeConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["Glm4MoeDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_static_cache = False
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Glm4MoeDecoderLayer,
|
||||
"attentions": Glm4MoeAttention,
|
||||
}
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, Glm4MoeRMSNorm):
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, Glm4MoeTopkRouter):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
|
||||
|
||||
class Glm4MoeRotaryEmbedding(nn.Module):
|
||||
def __init__(self, config: Glm4MoeConfig, device=None):
|
||||
super().__init__()
|
||||
# BC: "rope_type" was originally "type"
|
||||
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
||||
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||
else:
|
||||
self.rope_type = "default"
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
|
||||
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.original_inv_freq = self.inv_freq
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Glm4MoeModel(Glm4MoePreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"model\.layers\.92.*", r"model\.layers\.46.*"]
|
||||
|
||||
def __init__(self, config: Glm4MoeConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = Glm4MoeRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
@check_model_inputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPast:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position: torch.Tensor = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = Glm4MoeModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Glm4MoeForCausalLM
|
||||
|
||||
>>> model = Glm4MoeForCausalLM.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-glm4_moe/Glm4Moe-2-7b-hf")
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
outputs: BaseModelOutputWithPast = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Glm4MoePreTrainedModel", "Glm4MoeModel", "Glm4MoeForCausalLM"]
|
1235
ktransformers/models/modeling_smallthinker.py
Normal file
1235
ktransformers/models/modeling_smallthinker.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -26,6 +26,8 @@ from ktransformers.operators.base_operator import BaseInjectedModule
|
|||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
from ktransformers.util.utils import InferenceState
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding
|
||||
import torch
|
||||
|
||||
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
|
||||
|
@ -438,3 +440,92 @@ class KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
|
|||
self.orig_module.__init__(
|
||||
self.orig_module.config
|
||||
)
|
||||
|
||||
|
||||
class KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||
)
|
||||
self.orig_module.__init__(
|
||||
config
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
self.orig_module.__init__(
|
||||
self.orig_module.config,
|
||||
device = self.generate_device,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
# print(inv_freq_expanded.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
return freqs_cis
|
||||
|
||||
class KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
gguf_loader: GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
# device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
prefill_device: str = "cuda",
|
||||
**kwargs,
|
||||
):
|
||||
BaseInjectedModule.__init__(
|
||||
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
|
||||
)
|
||||
self.orig_module.__init__(
|
||||
config
|
||||
)
|
||||
self.generate_device = generate_device
|
||||
self.prefill_device = prefill_device
|
||||
|
||||
def load(self):
|
||||
self.orig_module.__init__(
|
||||
self.orig_module.config,
|
||||
device = self.generate_device,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids):
|
||||
if "dynamic" in self.rope_type:
|
||||
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||
# Core RoPE block
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
# print(inv_freq_expanded.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
freqs_cis = freqs_cis * self.attention_scaling
|
||||
return freqs_cis
|
|
@ -9,6 +9,8 @@ from torch import nn
|
|||
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerAttention
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeAttention
|
||||
from typing import Optional, Tuple
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
|
@ -455,3 +457,230 @@ class deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):
|
|||
attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)
|
||||
final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)
|
||||
return final_attention_output
|
||||
|
||||
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
self,
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KGQACache,
|
||||
freqs_cis: torch.Tensor,
|
||||
wrapper: flashInferAttn,
|
||||
bsz_tensors: torch.Tensor,
|
||||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
if self.use_qk_norm:
|
||||
raise NotImplementedError("use_qk_norm is not implemented yet")
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
print(query_states.shape)
|
||||
print(key_states.shape)
|
||||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis:
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
||||
|
||||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
||||
|
||||
class KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
self,
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KGQACache,
|
||||
freqs_cis: torch.Tensor,
|
||||
wrapper: flashInferAttn,
|
||||
bsz_tensors: torch.Tensor,
|
||||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
if self.use_qk_norm:
|
||||
raise NotImplementedError("use_qk_norm is not implemented yet")
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
print(query_states.shape)
|
||||
print(key_states.shape)
|
||||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis:
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
||||
|
||||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
chunck_size: int = 1000,
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config,
|
||||
orig_module.layer_idx)
|
||||
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
self,
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KGQACache,
|
||||
freqs_cis: torch.Tensor,
|
||||
wrapper: flashInferAttn,
|
||||
bsz_tensors: torch.Tensor,
|
||||
position_ids: torch.Tensor = None,
|
||||
):
|
||||
|
||||
if self.use_qk_norm:
|
||||
query_states = self.q_norm(query_states)
|
||||
key_states = self.k_norm(key_states)
|
||||
|
||||
q_len, _ = hidden_states.size()
|
||||
query_states = self.q_proj(hidden_states, bsz_tensors)
|
||||
key_states = self.k_proj(hidden_states, bsz_tensors)
|
||||
value_states = self.v_proj(hidden_states, bsz_tensors)
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
# cos, sin = freqs_cis
|
||||
"""
|
||||
print(query_states.shape)
|
||||
print(key_states.shape)
|
||||
print(cos.shape)
|
||||
print(sin.shape)
|
||||
"""
|
||||
if freqs_cis:
|
||||
query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)
|
||||
|
||||
|
||||
|
||||
query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)
|
||||
key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
k_cache = kv_cache.get_k_cache(self.layer_idx)
|
||||
v_cache = kv_cache.get_v_cache(self.layer_idx)
|
||||
|
||||
|
||||
attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)
|
||||
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)
|
||||
|
||||
return attn_output
|
|
@ -729,6 +729,8 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE
|
|||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeMoE
|
||||
|
||||
|
||||
class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
|
||||
|
@ -1248,6 +1250,12 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
|
|||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
|
||||
|
||||
if prefill_op == 'None':
|
||||
prefill_op = None
|
||||
if generate_op == 'None':
|
||||
generate_op = None
|
||||
|
||||
if generate_op is not None:
|
||||
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
|
||||
else:
|
||||
|
@ -1464,6 +1472,264 @@ class KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):
|
|||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
else:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
# y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
|
||||
outs = torch.empty_like(x)
|
||||
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer_simple(
|
||||
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: [num_tokens, hidden_size]
|
||||
topk_ids, topk_weight: [num_tokens, num_selected_experts]
|
||||
"""
|
||||
outs = torch.zeros_like(x)
|
||||
for token_idx in range(topk_ids.size(0)):
|
||||
for expert_idx in range(topk_ids.size(1)):
|
||||
expert = self.experts[topk_ids[token_idx, expert_idx]]
|
||||
outs[token_idx] += (
|
||||
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
|
||||
)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert.forward(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
|
||||
class KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):
|
||||
def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
if bsz_tensor is None:
|
||||
if self.enable_early_router:
|
||||
router_logits = self.primary_router(router_input)
|
||||
else:
|
||||
router_logits = self.primary_router(hidden_states)
|
||||
else:
|
||||
if self.enable_early_router:
|
||||
router_logits = self.primary_router(router_input, bsz_tensor)
|
||||
else:
|
||||
router_logits = self.primary_router(hidden_states, bsz_tensor)
|
||||
|
||||
router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)
|
||||
|
||||
|
||||
if router_logits.device.type == "xpu":
|
||||
# TODO: support self.moe_primary_router_apply_softmax False case
|
||||
from ipex_llm.transformers.models.common import moe_softmax_topk
|
||||
selected_experts, routing_weights = moe_softmax_topk(
|
||||
router_logits.half(), self.top_k, self.norm_topk_prob
|
||||
)
|
||||
else:
|
||||
if self.moe_primary_router_apply_softmax:
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
else:
|
||||
routing_weights = F.sigmoid(router_logits)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
# only for generate phase
|
||||
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
|
||||
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
|
||||
|
||||
# y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = (
|
||||
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
else:
|
||||
# TODO may bugs here
|
||||
y = (
|
||||
self.moe_infer_simple(hidden_states, selected_experts, routing_weights)
|
||||
.view(*orig_shape)
|
||||
.to(device=hidden_states.device)
|
||||
)
|
||||
# y += y_
|
||||
return y
|
||||
|
||||
@torch.no_grad()
|
||||
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
|
||||
outs = torch.empty_like(x)
|
||||
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer_simple(
|
||||
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: [num_tokens, hidden_size]
|
||||
topk_ids, topk_weight: [num_tokens, num_selected_experts]
|
||||
"""
|
||||
outs = torch.zeros_like(x)
|
||||
for token_idx in range(topk_ids.size(0)):
|
||||
for expert_idx in range(topk_ids.size(1)):
|
||||
expert = self.experts[topk_ids[token_idx, expert_idx]]
|
||||
outs[token_idx] += (
|
||||
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
|
||||
)
|
||||
return outs
|
||||
|
||||
@torch.no_grad()
|
||||
# TODO may bugs here
|
||||
def moe_infer(self, x, topk_ids, topk_weight):
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
|
||||
cnts.scatter_(1, topk_ids, 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
sorted_tokens = x[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
|
||||
expert_out = expert.forward(tokens_for_this_expert)
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
|
||||
new_x = torch.empty_like(outs)
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.type(topk_weight.dtype)
|
||||
.mul_(topk_weight.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(new_x.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
|
||||
class KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):
|
||||
def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
sequence_length = orig_shape[1]
|
||||
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
if bsz_tensor is None:
|
||||
router_logits = self.gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.gate(hidden_states, bsz_tensor)
|
||||
|
||||
if router_logits.device.type == "xpu":
|
||||
# TODO: support self.moe_primary_router_apply_softmax False case
|
||||
from ipex_llm.transformers.models.common import moe_softmax_topk
|
||||
selected_experts, routing_weights = moe_softmax_topk(
|
||||
router_logits.half(), self.top_k, self.norm_topk_prob
|
||||
)
|
||||
else:
|
||||
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
if self.norm_topk_prob:
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
# only for generate phase
|
||||
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug
|
||||
self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)
|
||||
y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
|
||||
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
|
||||
|
||||
y += y_
|
||||
y.resize_(*orig_shape)
|
||||
return y
|
||||
|
||||
# y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)
|
||||
# y_ = (
|
||||
# F.sigmoid(self.shared_expert_gate(hidden_states)) * y_
|
||||
# )
|
||||
|
||||
|
||||
if isinstance(self.experts, KExpertsBase):
|
||||
y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
|
||||
elif hidden_states.size(0) > 10:
|
||||
|
|
|
@ -213,3 +213,4 @@ class KMoEGateIPEXLLM(KMoEGate):
|
|||
self.n_group, self.topk_group, self.top_k,
|
||||
self.norm_topk_prob, self.routed_scaling_factor)
|
||||
return topk_idx, topk_weight.to(x.dtype)
|
||||
|
||||
|
|
|
@ -28,6 +28,8 @@ import torch.nn as nn
|
|||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm
|
||||
from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm
|
||||
from ktransformers.operators.base_operator import BaseInjectedModule
|
||||
from ktransformers.util.custom_loader import GGUFLoader
|
||||
if not torch.xpu.is_available():
|
||||
|
@ -165,6 +167,94 @@ class KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):
|
|||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.hidden_size,
|
||||
orig_module.variance_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
batch_size_tensor: torch.Tensor = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
#return self.forward_native(x, residual)
|
||||
bsz, hidden_size = x.shape
|
||||
x = x.view(-1, self.orig_module.hidden_size)
|
||||
if batch_size_tensor is None:
|
||||
return self.forward_native(x)
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
|
||||
#residual = x + residual
|
||||
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
|
||||
return x, residual
|
||||
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
|
||||
out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
|
||||
out = out.view(bsz, hidden_size)
|
||||
return out
|
||||
|
||||
def forward_native(
|
||||
self, hidden_states
|
||||
):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.hidden_size,
|
||||
orig_module.variance_epsilon)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
batch_size_tensor: torch.Tensor = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
#return self.forward_native(x, residual)
|
||||
bsz, hidden_size = x.shape
|
||||
x = x.view(-1, self.orig_module.hidden_size)
|
||||
if batch_size_tensor is None:
|
||||
return self.forward_native(x)
|
||||
if residual is not None:
|
||||
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
|
||||
#residual = x + residual
|
||||
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
|
||||
return x, residual
|
||||
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
|
||||
out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)
|
||||
out = out.view(bsz, hidden_size)
|
||||
return out
|
||||
|
||||
def forward_native(
|
||||
self, hidden_states
|
||||
):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
|
||||
class DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
|
|
|
@ -5,6 +5,8 @@ from transformers import PretrainedConfig
|
|||
import torch.nn as nn
|
||||
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
|
||||
from ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP
|
||||
from ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock
|
||||
from ktransformers.models.modeling_glm4_moe import Glm4MoeMLP
|
||||
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
|
@ -35,3 +37,34 @@ class KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):
|
|||
def forward(self, x, bsz_tensor):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
|
||||
return down_proj
|
||||
|
||||
|
||||
class KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config)
|
||||
def forward(self, x, bsz_tensor):
|
||||
down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor)
|
||||
return down_proj
|
||||
|
||||
class KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):
|
||||
def __init__(self,
|
||||
key: str,
|
||||
gguf_loader : GGUFLoader,
|
||||
config: PretrainedConfig,
|
||||
orig_module: nn.Module,
|
||||
prefill_device: str = "cuda",
|
||||
generate_device: str = "cuda",
|
||||
**kwargs):
|
||||
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
|
||||
self.orig_module.__init__(orig_module.config)
|
||||
def forward(self, x, bsz_tensor):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
|
||||
return down_proj
|
90
ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
Normal file
90
ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml
Normal file
|
@ -0,0 +1,90 @@
|
|||
- match:
|
||||
class: ktransformers.models.modeling_glm4_moe.Glm4MoeRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.KGlm4MoeRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "VLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# - match:
|
||||
# name: "^model\\.layers\\..*$" # regular expression
|
||||
# class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
# replace:
|
||||
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
# kwargs:
|
||||
# generate_device: "cuda"
|
||||
# prefill_device: "cuda"
|
||||
# generate_op: "VLinearMarlin"
|
||||
# prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp$"
|
||||
class: ktransformers.models.modeling_glm4_moe.Glm4MoeMoE
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KGlm4MoeMoE
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.mlp\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: None
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
class: ktransformers.models.modeling_glm4_moe.Glm4MoeRMSNorm
|
||||
replace:
|
||||
class: ktransformers.operators.layernorm.KGlm4MoeRMSNorm
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
class: ktransformers.models.modeling_glm4_moe.Glm4MoeMLP
|
||||
replace:
|
||||
class: ktransformers.operators.mlp.KGlm4MoeMLP
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
|
@ -0,0 +1,90 @@
|
|||
- match:
|
||||
class: ktransformers.models.modeling_smallthinker.SmallthinkerRotaryEmbedding
|
||||
replace:
|
||||
class: ktransformers.operators.RoPE.KSmallthinkerRotaryEmbedding
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^lm_head$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "VLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
|
||||
# - match:
|
||||
# name: "^model\\.layers\\..*$" # regular expression
|
||||
# class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
# replace:
|
||||
# class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
# kwargs:
|
||||
# generate_device: "cuda"
|
||||
# prefill_device: "cuda"
|
||||
# generate_op: "VLinearMarlin"
|
||||
# prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\.(?!.*feed_forward\\.shared_expert_gate).*$" # regular expression
|
||||
class: torch.nn.Linear # only match modules matching name and class simultaneously
|
||||
replace:
|
||||
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
generate_op: "KLinearMarlin"
|
||||
prefill_op: "KLinearTorch"
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.block_sparse_moe$"
|
||||
class: ktransformers.models.modeling_smallthinker.SmallthinkerMoeBlock
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KSmallthinkerMoeBlock
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$"
|
||||
replace:
|
||||
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert paralleism
|
||||
kwargs:
|
||||
prefill_device: "cuda"
|
||||
prefill_op: None
|
||||
generate_device: "cpu"
|
||||
generate_op: "KExpertsCPU"
|
||||
out_device: "cuda"
|
||||
recursive: False # don't recursively inject submodules of this module
|
||||
- match:
|
||||
name: "^model\\.layers\\..*\\.self_attn$"
|
||||
replace:
|
||||
class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
name: "^model.embed_tokens"
|
||||
replace:
|
||||
class: "default"
|
||||
kwargs:
|
||||
generate_device: "cpu"
|
||||
prefill_device: "cpu"
|
||||
|
||||
- match:
|
||||
class: ktransformers.models.modeling_smallthinker.SmallthinkerRMSNorm
|
||||
replace:
|
||||
class: ktransformers.operators.layernorm.KSmallthinkerRMSNorm
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
||||
|
||||
- match:
|
||||
class: ktransformers.models.modeling_smallthinker.SmallthinkerDenseMlpBlock
|
||||
replace:
|
||||
class: ktransformers.operators.mlp.KSmallthinkerDenseMlpBlock
|
||||
kwargs:
|
||||
generate_device: "cuda"
|
||||
prefill_device: "cuda"
|
|
@ -24,7 +24,11 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
|
|||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
|
||||
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
|
||||
from ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig
|
||||
from ktransformers.models.configuration_smallthinker import SmallthinkerConfig
|
||||
from ktransformers.models.configuration_glm4_moe import Glm4MoeConfig
|
||||
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
|
||||
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
|
@ -60,6 +64,8 @@ default_optimize_rules = {
|
|||
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
|
||||
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-serve.yaml",
|
||||
"Qwen3MoeForCausalLM": ktransformer_rules_dir + "Qwen3Moe-serve.yaml",
|
||||
"SmallthinkerForCausalLM": ktransformer_rules_dir + "Smallthinker-serve.yaml",
|
||||
"Glm4MoeForCausalLM": ktransformer_rules_dir + "Glm4Moe-serve.yaml",
|
||||
}
|
||||
|
||||
|
||||
|
@ -123,13 +129,22 @@ class Engine:
|
|||
self.sched_client = SchedulerClient(args.sched_port)
|
||||
self.updates = []
|
||||
|
||||
print(f"args.model_name: {args.model_name}")
|
||||
|
||||
if args.model_name == "Qwen3MoeForCausalLM":
|
||||
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
elif args.model_name == "Glm4MoeForCausalLM":
|
||||
config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
elif args.model_name == "SmallthinkerForCausalLM":
|
||||
config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
config._attn_implementation = "eager"
|
||||
else:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
except:
|
||||
if args.model_name == "Qwen3Moe":
|
||||
config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)
|
||||
else:
|
||||
assert False, f"model {args.model_name} not supported"
|
||||
raise ValueError(f"Model {args.model_name} not supported. Please check your model directory or model name.")
|
||||
|
||||
|
||||
|
||||
|
||||
self.gen_queue = generated_token_queue
|
||||
|
@ -147,6 +162,13 @@ class Engine:
|
|||
self.model = KQwen2MoeForCausalLM(config, self.cache)
|
||||
else:
|
||||
self.model = KQwen3MoeForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "SmallthinkerForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
self.model = KSmallthinkerForCausalLM(config, self.cache)
|
||||
elif config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
self.cache = KGQACache(config, self.args.page_size)
|
||||
self.model = KGlm4MoeForCausalLM(config, self.cache)
|
||||
|
||||
|
||||
|
||||
context = zmq.Context()
|
||||
|
@ -197,7 +219,7 @@ class Engine:
|
|||
self.block_num = inference_context.k_cache[0].size(1)
|
||||
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
|
||||
#@TODO add config
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM":
|
||||
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallthinkerForCausalLM":
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num)
|
||||
else:
|
||||
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
|
||||
|
|
|
@ -29,6 +29,8 @@ from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausa
|
|||
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
|
||||
from ktransformers.models.custom_modeling_smallthinker import KSmallthinkerForCausalLM
|
||||
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
|
||||
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
|
||||
from ktransformers.server.balance_serve.settings import sched_ext
|
||||
|
||||
|
@ -53,7 +55,7 @@ def generate_cuda_graphs(chunk_size: int) -> list:
|
|||
class ModelRunner:
|
||||
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
|
||||
|
||||
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM
|
||||
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallthinkerForCausalLM | KGlm4MoeForCausalLM
|
||||
input: ForwardBatchInput | list[ForwardBatchInput]
|
||||
output: ForwardBatchOutput
|
||||
|
||||
|
@ -93,7 +95,7 @@ class ModelRunner:
|
|||
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
|
||||
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,
|
||||
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
|
||||
self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,
|
||||
num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,
|
||||
head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads,
|
||||
|
@ -124,7 +126,7 @@ class ModelRunner:
|
|||
num_tokens = self.features_buf[i][0].size(0)
|
||||
print("capturing cuda graph", batch_size, num_tokens)
|
||||
|
||||
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):
|
||||
if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallthinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM):
|
||||
self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)
|
||||
|
||||
self.bsz_tensor_buf[0] = batch_size
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue