fix stool (#44)

Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
Srinivasan Iyer 2025-02-05 17:18:40 -08:00 committed by GitHub
parent 7cf8fab49b
commit 6fbaf7266f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,14 +4,15 @@ import json
import os
import shutil
import subprocess
from dataclasses import dataclass
from pydantic import BaseModel
from typing import Any, Dict
from omegaconf import OmegaConf
@dataclass
class StoolArgs:
class StoolArgs(BaseModel):
name: str = None
dump_dir: str = None
config: Any = None
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
script: str = "apps.main.train" # The script to run.
@ -64,7 +65,7 @@ source activate {conda_env_path}
export OMP_NUM_THREADS=1
export LAUNCH_WITH="SBATCH"
export DUMP_DIR={dump_dir}
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
"""
@ -150,8 +151,8 @@ def validate_args(args) -> None:
def launch_job(args: StoolArgs):
# Set up args default and validate them depending on the cluster or partition requested
validate_args(args)
dump_dir = args.config["dump_dir"]
job_name = args.config["name"]
job_name = args.name or args.config["name"]
dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
print("Creating directories...")
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
if args.override:
@ -230,8 +231,7 @@ if __name__ == "__main__":
Then you can pass model.dim=32 to change values in LMTransformerArgs
or just name=tictac for top level attributes.
"""
raise NotImplementedError("Update this to blt code")
args = OmegaConf.from_cli()
args.config = OmegaConf.load(args.config)
args = dataclass_from_dict(StoolArgs, args)
args = StoolArgs.model_validate(args)
launch_job(args)