Fix stool job launcher: pydantic-based StoolArgs with name/dump_dir arguments; re-enable CLI entry point (#44)

Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
Srinivasan Iyer 2025-02-05 17:18:40 -08:00 committed by GitHub
parent 7cf8fab49b
commit 6fbaf7266f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -4,14 +4,15 @@ import json
 import os
 import shutil
 import subprocess
-from dataclasses import dataclass
+from pydantic import BaseModel
 from typing import Any, Dict
 
 from omegaconf import OmegaConf
 
 
-@dataclass
-class StoolArgs:
+class StoolArgs(BaseModel):
+    name: str = None
+    dump_dir: str = None
     config: Any = None
     launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
     script: str = "apps.main.train" # The script to run.
@@ -64,7 +65,7 @@ source activate {conda_env_path}
 export OMP_NUM_THREADS=1
 export LAUNCH_WITH="SBATCH"
 export DUMP_DIR={dump_dir}
-srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
 """
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
 def launch_job(args: StoolArgs):
     # Set up args default and validate them depending on the cluster or partition requested
     validate_args(args)
-    dump_dir = args.config["dump_dir"]
-    job_name = args.config["name"]
+    job_name = args.name or args.config["name"]
+    dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
     print("Creating directories...")
     os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
     if args.override:
@@ -230,8 +231,7 @@ if __name__ == "__main__":
     Then you can pass model.dim=32 to change values in LMTransformerArgs
     or just name=tictac for top level attributes.
     """
-    raise NotImplementedError("Update this to blt code")
     args = OmegaConf.from_cli()
     args.config = OmegaConf.load(args.config)
-    args = dataclass_from_dict(StoolArgs, args)
+    args = StoolArgs.model_validate(args)
     launch_job(args)