diff --git a/bytelatent/train.py b/bytelatent/train.py index a7ca405..ed88200 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD from bytelatent.profiling import maybe_run_profiler from bytelatent.stool import StoolArgs, launch_job from bytelatent.transformer import ( + LMTransformer, build_fsdp_grouping_plan, get_no_recompute_ops, get_num_flop_per_token,