modeling_deepseek_v3: fix GenerationMixin warning

Fix GenerationMixin warning introduced by upgrading transformers to 4.51.3.
This commit is contained in:
Aubrey Li 2025-05-01 07:48:15 +08:00
parent 7530491f5b
commit def1ec7683

View file

@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import ( from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter, AttentionMaskConverter,
_prepare_4d_attention_mask, _prepare_4d_attention_mask,
@ -1598,7 +1599,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
return causal_mask return causal_mask
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
def __init__(self, config): def __init__(self, config):