mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 13:02:14 +00:00
parent
7cf8fab49b
commit
6fbaf7266f
|
@ -4,14 +4,15 @@ import json
|
|||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoolArgs:
|
||||
class StoolArgs(BaseModel):
|
||||
name: str = None
|
||||
dump_dir: str = None
|
||||
config: Any = None
|
||||
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
||||
script: str = "apps.main.train" # The script to run.
|
||||
|
@ -64,7 +65,7 @@ source activate {conda_env_path}
|
|||
export OMP_NUM_THREADS=1
|
||||
export LAUNCH_WITH="SBATCH"
|
||||
export DUMP_DIR={dump_dir}
|
||||
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
|
||||
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
|
||||
"""
|
||||
|
||||
|
||||
|
@ -150,8 +151,8 @@ def validate_args(args) -> None:
|
|||
def launch_job(args: StoolArgs):
|
||||
# Set up args default and validate them depending on the cluster or partition requested
|
||||
validate_args(args)
|
||||
dump_dir = args.config["dump_dir"]
|
||||
job_name = args.config["name"]
|
||||
job_name = args.name or args.config["name"]
|
||||
dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
|
||||
print("Creating directories...")
|
||||
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
||||
if args.override:
|
||||
|
@ -230,8 +231,7 @@ if __name__ == "__main__":
|
|||
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
||||
or just name=tictac for top level attributes.
|
||||
"""
|
||||
raise NotImplementedError("Update this to blt code")
|
||||
args = OmegaConf.from_cli()
|
||||
args.config = OmegaConf.load(args.config)
|
||||
args = dataclass_from_dict(StoolArgs, args)
|
||||
args = StoolArgs.model_validate(args)
|
||||
launch_job(args)
|
||||
|
|
Loading…
Reference in a new issue