Mirror of https://github.com/kvcache-ai/ktransformers.git, synced 2025-09-15 09:39:42 +00:00
Commit 613f0b7c37 (parent b66d96db97): support smt and glm4
8 changed files with 115 additions and 28 deletions
@@ -25,19 +25,20 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
-from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
-from ...masking_utils import create_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import GradientCheckpointingLayer
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+# from transformers.utils import auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs
 from .configuration_glm4_moe import Glm4MoeConfig
 
 
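The hunk above rewrites every relative import as an absolute transformers import. A likely motivation, not stated in the commit message, is that ktransformers carries this modeling file outside the transformers source tree, where relative imports such as from ...activations import ACT2FN cannot resolve. A minimal sketch of loading such a vendored file standalone; the module name and path below are hypothetical:

import importlib.util

# With relative imports, exec_module() would fail with
# "attempted relative import with no known parent package";
# absolute `from transformers.*` imports resolve from any location.
spec = importlib.util.spec_from_file_location(
    "modeling_glm4_moe_vendored",        # hypothetical module name
    "path/to/modeling_glm4_moe.py",      # hypothetical path to the vendored file
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)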
@@ -61,7 +62,7 @@ def eager_attention_forward(
     attention_mask: Optional[torch.Tensor],
     scaling: float,
     dropout: float = 0.0,
-    **kwargs: Unpack[TransformersKwargs],
+    # **kwargs: Unpack[TransformersKwargs],
 ):
     key_states = repeat_kv(key, module.num_key_value_groups)
     value_states = repeat_kv(value, module.num_key_value_groups)
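For context on the repeat_kv calls visible at the end of this hunk: in upstream transformers it expands grouped key/value heads so they match the number of query heads. A sketch close to the upstream helper (reconstructed here, not part of the commit):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    # so grouped-query KV heads line up with the query heads before the attention matmul.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)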
@@ -268,7 +269,7 @@ class Glm4MoeTopkRouter(nn.Module):
         return topk_indices, topk_weights
 
 
-@use_kernel_forward_from_hub("RMSNorm")
+# @use_kernel_forward_from_hub("RMSNorm")
 class Glm4MoeRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
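With @use_kernel_forward_from_hub("RMSNorm") commented out, Glm4MoeRMSNorm falls back to its plain PyTorch forward rather than a kernel fetched from the Hub. The class body is cut off in the hunk; a sketch of the standard RMSNorm as it appears in recent transformers modeling files (a reconstruction, not the committed code):

import torch
from torch import nn

class Glm4MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize by the root-mean-square over the last dimension (computed in
        # float32 for stability), then rescale by the learned weight.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)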
@@ -369,7 +370,7 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
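The tail of this hunk shows the start of the pre-norm residual pattern: the input is saved as residual before input_layernorm is applied. An illustrative, self-contained version of that pattern; the stand-in modules below are not the committed code:

import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)      # stand-in for the RMSNorm
        self.sublayer = nn.Linear(hidden_size, hidden_size)   # stand-in for attention / MoE MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states                       # keep the un-normalized input
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.sublayer(hidden_states)
        return residual + hidden_states                # add the skip connection back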
@@ -487,7 +488,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    @check_model_inputs
+    # @check_model_inputs
     @auto_docstring
     def forward(
         self,
@@ -498,7 +499,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
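The ^ check at the end of this hunk enforces that exactly one of input_ids / inputs_embeds is supplied. A small standalone illustration of the same condition:

def exactly_one_provided(input_ids, inputs_embeds):
    # Raises unless exactly one of the two arguments is non-None,
    # mirroring the XOR check in the hunk above.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

exactly_one_provided(input_ids=[1, 2, 3], inputs_embeds=None)   # ok: only input_ids given
# exactly_one_provided(None, None)           would raise: neither given
# exactly_one_provided([1], embeds_tensor)   would raise: both given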
@@ -594,7 +595,7 @@ class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[TransformersKwargs],
+        # **kwargs: Unpack[TransformersKwargs],
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
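logits_to_keep in the signature above follows the usual transformers convention: 0 keeps logits for every position, an integer n keeps only the last n positions, and a tensor selects explicit indices. A sketch of how it is typically applied in the causal-LM head (assumed upstream behavior, not code from this commit; the sizes are toy values):

import torch

hidden_states = torch.randn(2, 16, 64)           # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(64, 100, bias=False)   # hypothetical vocab size of 100

logits_to_keep = 4                                # keep logits for the last 4 positions
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                               # torch.Size([2, 4, 100])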