support smt and glm4

djw 2025-07-24 09:39:19 +00:00
parent b66d96db97
commit 613f0b7c37
8 changed files with 115 additions and 28 deletions

View file

@@ -19,8 +19,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ...configuration_utils import PretrainedConfig
-from ...modeling_rope_utils import rope_config_validation
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
 class Glm4MoeConfig(PretrainedConfig):
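
Note: switching from package-relative imports (`from ...configuration_utils import ...`) to absolute `transformers.*` imports lets this vendored config module load from outside the transformers source tree, where relative paths cannot resolve. A minimal sketch of the pattern, using a hypothetical stand-in class name:

    from transformers.configuration_utils import PretrainedConfig

    # Illustrative only: any vendored config can subclass PretrainedConfig this way,
    # regardless of where the file lives on disk.
    class VendoredMoeConfig(PretrainedConfig):
        model_type = "vendored_moe"  # hypothetical model_type

        def __init__(self, hidden_size=1024, **kwargs):
            self.hidden_size = hidden_size
            super().__init__(**kwargs)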

View file

@@ -154,6 +154,8 @@ class SmallthinkerConfig(PretrainedConfig):
         self.moe_num_primary_experts = moe_num_primary_experts
         self.moe_shared_primary_experts = moe_shared_primary_experts
         self.moe_ffn_hidden_size = moe_ffn_hidden_size
+        self.num_experts_per_tok = moe_num_active_primary_experts
+        self.moe_intermediate_size = moe_ffn_hidden_size
         self.moe_enable_early_router = moe_enable_early_router
         self.moe_primary_router_apply_softmax = moe_primary_router_apply_softmax
         self.moe_num_active_primary_experts = moe_num_active_primary_experts
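
Note: the two added assignments alias Smallthinker-specific fields to the attribute names that generic MoE code paths commonly read (`num_experts_per_tok`, `moe_intermediate_size`). A hedged, self-contained sketch of why the alias helps, with hypothetical names:

    class TinyMoeConfig:
        # Toy config mirroring the aliasing above (illustrative only).
        def __init__(self, moe_num_active_primary_experts=4, moe_ffn_hidden_size=768):
            self.moe_num_active_primary_experts = moe_num_active_primary_experts
            self.moe_ffn_hidden_size = moe_ffn_hidden_size
            # aliases expected by generic MoE code paths
            self.num_experts_per_tok = moe_num_active_primary_experts
            self.moe_intermediate_size = moe_ffn_hidden_size

    def experts_per_token(cfg):
        # reads only the standard name, unaware of Smallthinker-specific fields
        return cfg.num_experts_per_tok

    assert experts_per_token(TinyMoeConfig()) == 4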

View file

@@ -25,19 +25,20 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+# from transformers.utils import auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig
@@ -61,7 +62,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
+    # **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
@@ -268,7 +269,7 @@ class Glm4MoeTopkRouter(nn.Module):
         return topk_indices, topk_weights
-@use_kernel_forward_from_hub("RMSNorm")
+# @use_kernel_forward_from_hub("RMSNorm")
 class Glm4MoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -369,7 +370,7 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -487,7 +488,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
-    @check_model_inputs
+    # @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -498,7 +499,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -594,7 +595,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
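
Note: throughout this file the `**kwargs: Unpack[TransformersKwargs]` annotations and the `@check_model_inputs` / `@use_kernel_forward_from_hub("RMSNorm")` decorators are commented out, presumably because the pinned transformers release does not expose these helpers in a compatible form. A hedged alternative sketch (not part of this commit) that keeps the module importable either way:

    # Guarded imports: fall back to no-op stand-ins when the helpers are unavailable.
    try:
        from transformers.utils import TransformersKwargs  # may be missing in some releases
    except ImportError:
        TransformersKwargs = dict  # placeholder so annotations still resolve

    try:
        from transformers.utils.generic import check_model_inputs
    except ImportError:
        def check_model_inputs(func):
            return func  # no-op decorator fallback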

View file

@@ -1315,6 +1315,80 @@ class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
         else:
             raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+class KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):
+    def __init__(self,
+                 key: str,
+                 gguf_loader: GGUFLoader,
+                 config: PretrainedConfig,
+                 orig_module: nn.Module,
+                 # device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 prefill_op: str | None = "KExpertsTorch",
+                 generate_device: str = "cpu",
+                 generate_op: str | None = "KExpertsCPU",
+                 **kwargs):
+        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        if generate_op is not None:
+            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
+        else:
+            self.generate_experts = None
+        if prefill_op is not None:
+            self.prefill_experts = None
+        self.gpu_mlp_type = prefill_op
+        self.cpu_mlp_type = generate_op
+        self.mode = InferenceState.UNLOAD
+
+    def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
+        # TODO support w as input
+        if not mode: mode = InferenceState.GENERATE
+        if mode == InferenceState.GENERATE:
+            # self.prefill_experts.unload()
+            self.generate_experts.load(w, warmup=warmup)
+            self.device = self.generate_experts.device
+            self.mode = mode
+        elif mode == InferenceState.PREFILL:
+            self.generate_experts.unload()
+            self.prefill_experts.load(w, warmup=warmup)
+            self.device = self.prefill_experts.device
+            self.mode = mode
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+            self.mode = mode
+            self.device = self.generate_experts.device
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
+    def unload(self):
+        if self.generate_experts is not None:
+            self.generate_experts.unload()
+        if self.prefill_experts is not None:
+            self.prefill_experts.unload()
+        self.device = self.generate_experts.device
+
+    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
+        if self.mode == InferenceState.GENERATE:
+            assert self.generate_experts is not None, "generate_experts is None"
+            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        elif self.mode == InferenceState.PREFILL:
+            assert self.prefill_experts is not None, "prefill_experts is None"
+            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
+        else:
+            raise ValueError("load or set_inference_mode before forward")
+
+    def set_inference_mode(self, mode: InferenceState):
+        if mode == InferenceState.GENERATE:
+            self.load(mode=InferenceState.GENERATE, warmup=False)
+        elif mode == InferenceState.PREFILL:
+            self.load(mode=InferenceState.PREFILL, warmup=False)
+        elif mode == InferenceState.UNLOAD:
+            self.unload()
+        else:
+            raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
+
 class KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
     def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
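
Note: `KSmallthinkerExperts` appears to follow the same load/unload pattern as `KTransformersExpertsV2`, dispatching `forward` to whichever expert backend the current `InferenceState` selects, with explicit `expert_ids`/`weights` tensors. A toy, self-contained sketch of that dispatch pattern (hypothetical names, illustrative only):

    from enum import Enum

    class Mode(Enum):  # stand-in for InferenceState
        GENERATE = 1
        PREFILL = 2
        UNLOAD = 3

    class ExpertsSwitch:
        # Dispatches forward() to whichever backend the current mode selects.
        def __init__(self, generate_backend, prefill_backend=None):
            self.generate_backend = generate_backend
            self.prefill_backend = prefill_backend
            self.mode = Mode.UNLOAD

        def set_inference_mode(self, mode):
            self.mode = mode

        def forward(self, x):
            if self.mode == Mode.GENERATE:
                return self.generate_backend(x)
            if self.mode == Mode.PREFILL:
                return self.prefill_backend(x)
            raise ValueError("load or set_inference_mode before forward")

    switch = ExpertsSwitch(generate_backend=lambda x: x + 1)
    switch.set_inference_mode(Mode.GENERATE)
    assert switch.forward(1) == 2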

View file

@@ -49,7 +49,7 @@
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe\\.experts$"
   replace:
-    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism
+    class: ktransformers.operators.experts.KSmallthinkerExperts     # custom MoE Kernel with expert paralleism
     kwargs:
       prefill_device: "cuda"
       prefill_op: None
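
Note: this injection rule now swaps every matched `block_sparse_moe.experts` module for the new `KSmallthinkerExperts` operator. A hedged sketch (illustrative, hypothetical module names) of how such a name regex selects replacement targets:

    import re

    pattern = re.compile(r"^model\.layers\..*\.block_sparse_moe\.experts$")

    module_names = [
        "model.layers.0.block_sparse_moe.experts",
        "model.layers.0.block_sparse_moe.router",
        "model.layers.11.block_sparse_moe.experts",
    ]

    # Only the experts containers match and would get the injected class.
    matched = [name for name in module_names if pattern.match(name)]
    assert matched == ["model.layers.0.block_sparse_moe.experts",
                       "model.layers.11.block_sparse_moe.experts"]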

View file

@@ -138,8 +138,12 @@ class SafeTensorLoader(ModelLoader):
         base_key = key  # e.g. "model.layers.3.mlp.experts"
         experts_count = 0
+        key_no_proj = False
+        if self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
+            key_no_proj = True
         # First, count how many experts we have by checking for expert 0's up_proj
-        while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight"):
+        while self.has_tensor(f"{base_key}.{experts_count}.up_proj.weight") or self.has_tensor(f"{base_key}.{experts_count}.up.weight"):
             experts_count += 1
 
         if experts_count == 0:
@@ -152,6 +156,12 @@ class SafeTensorLoader(ModelLoader):
         # Load all expert weights
         for expert_id in range(experts_count):
-            up_key = f"{base_key}.{expert_id}.up_proj.weight"
-            gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
-            down_key = f"{base_key}.{expert_id}.down_proj.weight"
+            if key_no_proj:
+                up_key = f"{base_key}.{expert_id}.up.weight"
+                gate_key = f"{base_key}.{expert_id}.gate.weight"
+                down_key = f"{base_key}.{expert_id}.down.weight"
+            else:
+                up_key = f"{base_key}.{expert_id}.up_proj.weight"
+                gate_key = f"{base_key}.{expert_id}.gate_proj.weight"
+                down_key = f"{base_key}.{expert_id}.down_proj.weight"
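
Note: the loader now also recognizes the shorter `up`/`gate`/`down` tensor names (no `_proj` suffix) that the Smallthinker checkpoints apparently use. A hedged sketch of the same probing idea (hypothetical helper, not from this commit):

    def expert_weight_keys(has_tensor, base_key, expert_id):
        # Probe which naming scheme the checkpoint uses, then build the three keys.
        if has_tensor(f"{base_key}.{expert_id}.up.weight"):
            suffixes = ("up", "gate", "down")                     # short names
        else:
            suffixes = ("up_proj", "gate_proj", "down_proj")      # conventional names
        return [f"{base_key}.{expert_id}.{s}.weight" for s in suffixes]

    available = {"model.layers.3.mlp.experts.0.up.weight"}
    keys = expert_weight_keys(available.__contains__, "model.layers.3.mlp.experts", 0)
    assert keys == ["model.layers.3.mlp.experts.0.up.weight",
                    "model.layers.3.mlp.experts.0.gate.weight",
                    "model.layers.3.mlp.experts.0.down.weight"]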

View file

@ -16,7 +16,7 @@ dynamic = ["version"]
dependencies = [ dependencies = [
"torch >= 2.3.0", "torch >= 2.3.0",
"transformers == 4.51.3", "transformers == 4.53.3",
"fastapi >= 0.111.0", "fastapi >= 0.111.0",
"uvicorn >= 0.30.1", "uvicorn >= 0.30.1",
"langchain >= 0.2.0", "langchain >= 0.2.0",

View file

@@ -1,5 +1,5 @@
 fire
-transformers==4.51.3
+transformers==4.53.3
 numpy
 torch>=2.3.0
 packaging
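
Note: pyproject.toml and this requirements file both move the transformers pin from 4.51.3 to 4.53.3. A small hedged sketch (illustrative) of failing fast when the installed version drifts from the pin, using the `packaging` dependency already listed above:

    import transformers
    from packaging.version import Version

    PINNED = Version("4.53.3")
    installed = Version(transformers.__version__)
    assert installed == PINNED, f"expected transformers {PINNED}, found {installed}"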