From a6ed14f689368fa9e466c91901a2c6ff98a59503 Mon Sep 17 00:00:00 2001
From: Bocheng Li <251156266@qq.com>
Date: Tue, 25 Feb 2025 06:40:38 +0800
Subject: [PATCH] Fix: Correct model_args usage in parallelize_model call (#69)

---
 bytelatent/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bytelatent/train.py b/bytelatent/train.py
index ad74b44..eb1c700 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -317,7 +317,7 @@ def train(args: TrainArgs):
         model = parallelize_model(
             model,
             world_mesh,
-            args.model,
+            model_args,
             args.distributed,
             fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
             tp_parallelize=tp_parallelize,