diff --git a/bytelatent/train.py b/bytelatent/train.py
index ad74b44..eb1c700 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -317,7 +317,7 @@ def train(args: TrainArgs):
         model = parallelize_model(
             model,
             world_mesh,
-            args.model,
+            model_args,
             args.distributed,
             fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
             tp_parallelize=tp_parallelize,