Patch for bytelatent/stool.py: StoolArgs becomes a pydantic BaseModel with top-level
`name` and `dump_dir` fields, and the srun command passes dump_dir/name to the script.
Review fix in launch_job: args.dump_dir is joined with the job name only when it is set;
otherwise config["dump_dir"] is used. The previous `os.path.join(...) or ...` form raised
TypeError for the default dump_dir=None and could never reach its fallback.
NOTE(review): the post-image hash on the index line is carried over from the original patch.
diff --git a/bytelatent/stool.py b/bytelatent/stool.py
index 965f4cb..b156177 100644
--- a/bytelatent/stool.py
+++ b/bytelatent/stool.py
@@ -4,14 +4,15 @@ import json
 import os
 import shutil
 import subprocess
-from dataclasses import dataclass
+from pydantic import BaseModel
 from typing import Any, Dict
 
 from omegaconf import OmegaConf
 
 
-@dataclass
-class StoolArgs:
+class StoolArgs(BaseModel):
+    name: str = None
+    dump_dir: str = None
     config: Any = None
     launcher: str = "sbatch"  # Can be sbatch or bash if already in salloc
     script: str = "apps.main.train"  # The script to run.
@@ -64,7 +65,7 @@ source activate {conda_env_path}
 export OMP_NUM_THREADS=1
 export LAUNCH_WITH="SBATCH"
 export DUMP_DIR={dump_dir}
-srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
 """
 
 
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
 def launch_job(args: StoolArgs):
     # Set up args default and validate them depending on the cluster or partition requested
     validate_args(args)
-    dump_dir = args.config["dump_dir"]
-    job_name = args.config["name"]
+    job_name = args.name or args.config["name"]
+    dump_dir = os.path.join(args.dump_dir, job_name) if args.dump_dir else args.config["dump_dir"]
     print("Creating directories...")
     os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
     if args.override:
@@ -230,8 +231,7 @@ if __name__ == "__main__":
     Then you can pass model.dim=32 to change values in LMTransformerArgs
     or just name=tictac for top level attributes.
     """
-    raise NotImplementedError("Update this to blt code")
     args = OmegaConf.from_cli()
     args.config = OmegaConf.load(args.config)
-    args = dataclass_from_dict(StoolArgs, args)
+    args = StoolArgs.model_validate(args)
     launch_job(args)