diff --git a/bytelatent/train.py b/bytelatent/train.py index ad74b44..eb1c700 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -317,7 +317,7 @@ def train(args: TrainArgs): model = parallelize_model( model, world_mesh, - args.model, + model_args, args.distributed, fsdp_grouping_plan=build_fsdp_grouping_plan(model_args), tp_parallelize=tp_parallelize,