mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 05:22:16 +00:00
fix stool
This commit is contained in:
parent
7cf8fab49b
commit
8212e9b6f2
|
@ -4,14 +4,15 @@ import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from dataclasses import dataclass
|
from pydantic import BaseModel
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class StoolArgs(BaseModel):
|
||||||
class StoolArgs:
|
name: str = None
|
||||||
|
dump_dir: str = None
|
||||||
config: Any = None
|
config: Any = None
|
||||||
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
|
||||||
script: str = "apps.main.train" # The script to run.
|
script: str = "apps.main.train" # The script to run.
|
||||||
|
@ -64,7 +65,7 @@ source activate {conda_env_path}
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
export LAUNCH_WITH="SBATCH"
|
export LAUNCH_WITH="SBATCH"
|
||||||
export DUMP_DIR={dump_dir}
|
export DUMP_DIR={dump_dir}
|
||||||
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
|
srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,8 +151,8 @@ def validate_args(args) -> None:
|
||||||
def launch_job(args: StoolArgs):
|
def launch_job(args: StoolArgs):
|
||||||
# Set up args default and validate them depending on the cluster or partition requested
|
# Set up args default and validate them depending on the cluster or partition requested
|
||||||
validate_args(args)
|
validate_args(args)
|
||||||
dump_dir = args.config["dump_dir"]
|
job_name = args.name or args.config["name"]
|
||||||
job_name = args.config["name"]
|
dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
|
||||||
print("Creating directories...")
|
print("Creating directories...")
|
||||||
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
|
||||||
if args.override:
|
if args.override:
|
||||||
|
@ -230,8 +231,7 @@ if __name__ == "__main__":
|
||||||
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
Then you can pass model.dim=32 to change values in LMTransformerArgs
|
||||||
or just name=tictac for top level attributes.
|
or just name=tictac for top level attributes.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("Update this to blt code")
|
|
||||||
args = OmegaConf.from_cli()
|
args = OmegaConf.from_cli()
|
||||||
args.config = OmegaConf.load(args.config)
|
args.config = OmegaConf.load(args.config)
|
||||||
args = dataclass_from_dict(StoolArgs, args)
|
args = StoolArgs.model_validate(args)
|
||||||
launch_job(args)
|
launch_job(args)
|
||||||
|
|
Loading…
Reference in a new issue