modeling_deepseek_v3: fix GenerationMixin warning

Fix GenerationMixin warning introduced by upgrading transformers to 4.51.3.
This commit is contained in:
Aubrey Li 2025-05-01 07:48:15 +08:00
parent 7530491f5b
commit def1ec7683

View file

@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
@@ -1598,7 +1599,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
return causal_mask
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):