Fix: Correct model_args usage in parallelize_model call ()

This commit is contained in:
Bocheng Li 2025-02-25 06:40:38 +08:00 committed by GitHub
parent fc3399ef40
commit a6ed14f689
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -317,7 +317,7 @@ def train(args: TrainArgs):
model = parallelize_model(
model,
world_mesh,
args.model,
model_args,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
tp_parallelize=tp_parallelize,