support smt and glm4

djw 2025-07-24 09:39:19 +00:00
parent b66d96db97
commit 613f0b7c37
8 changed files with 115 additions and 28 deletions


@@ -25,19 +25,20 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+# from transformers.utils import auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig
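
The hunk above rewrites the in-tree relative imports (from ...activations, from ...cache_utils, and so on) as absolute transformers.* imports, so the modeling file can be vendored outside the transformers source tree; only the sibling configuration import stays relative. A minimal sketch of the same idea with a fallback branch, assuming both layouts need to keep working (the try/except is illustrative and not part of this commit):

# Sketch only: prefer the absolute imports introduced by this commit, and fall
# back to the in-tree relative form when the file still sits under
# transformers/models/. The fallback branch is an assumption, not in the diff.
try:
    from transformers.activations import ACT2FN
    from transformers.cache_utils import Cache, DynamicCache
except ImportError:
    from ...activations import ACT2FN
    from ...cache_utils import Cache, DynamicCache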
@@ -61,7 +62,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
+    # **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
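
Here and in the later hunks the **kwargs: Unpack[TransformersKwargs] parameter is commented out rather than kept, which removes the dependency on TransformersKwargs, a helper that only recent transformers releases provide. A hedged alternative that keeps the signatures intact behind guarded imports (an assumption about intent, not what the commit does):

# Sketch (assumption, not what this commit does): keep the **kwargs annotation
# but degrade gracefully when the installed transformers predates TransformersKwargs.
try:
    from transformers.utils import TransformersKwargs
except ImportError:
    from typing_extensions import TypedDict

    class TransformersKwargs(TypedDict, total=False):
        """Empty stand-in so Unpack[TransformersKwargs] still evaluates."""

try:
    from transformers.processing_utils import Unpack
except ImportError:
    from typing_extensions import Unpack  # same symbol, taken straight from typing_extensions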
@@ -268,7 +269,7 @@ class Glm4MoeTopkRouter(nn.Module):
         return topk_indices, topk_weights
-@use_kernel_forward_from_hub("RMSNorm")
+# @use_kernel_forward_from_hub("RMSNorm")
 class Glm4MoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -369,7 +370,7 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -487,7 +488,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
-    @check_model_inputs
+    # @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -498,7 +499,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@@ -594,7 +595,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
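
The same compatibility concern applies to the @check_model_inputs decorator commented out on Glm4MoeModel.forward: it lives in transformers.utils.generic only in recent releases. A no-op fallback is one hedged alternative to deleting the call site (an assumption, not part of this commit):

# Sketch (assumption): use @check_model_inputs where available, otherwise swap in
# an identity decorator so the decorated forward definition stays unchanged.
try:
    from transformers.utils.generic import check_model_inputs
except ImportError:
    def check_model_inputs(func):
        return func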