Make it possible to specify multiple config files

Summary:

Test Plan:

Test that this interpolates in the right order: config -> configs -> CLI args

```
# All three sources
python -m bytelatent.print_config config=bytelatent/configs/debug.yaml configs=[internal/configs/s3_debug.yaml] eval=null

# What worked before
python -m bytelatent.print_config config=internal/configs/s3_debug.yaml eval=null
```
This commit is contained in:
Pedro Rodriguez 2025-02-12 19:24:49 +00:00
parent 3e3193c1d4
commit ece82cb960
3 changed files with 26 additions and 7 deletions

View file

@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
def parse_args(args_cls):
    """Build a validated ``args_cls`` instance from defaults, files and CLI.

    Sources are merged so that later ones override earlier ones:

    1. defaults obtained from ``args_cls().model_dump()``
    2. the file named by the ``config=<path>`` CLI argument, if given
    3. each file listed in ``configs=[<path>, ...]``, in the order given
    4. every remaining ``key=value`` CLI argument

    Args:
        args_cls: Class that can be instantiated without arguments and that
            provides ``model_dump()`` and ``model_validate()``.

    Returns:
        The object returned by ``args_cls.model_validate`` for the merged,
        fully resolved configuration.
    """
    cli_args = OmegaConf.from_cli()

    # 'config' and 'configs' only say where to read from; they are removed
    # from cli_args because the underlying args class does not have them.
    config_paths = []
    if "config" in cli_args:
        config_paths.append(cli_args["config"])
        del cli_args["config"]
    if "configs" in cli_args:
        extra_paths = cli_args["configs"]
        del cli_args["configs"]
        if isinstance(extra_paths, str):
            # `configs=a.yaml` (no brackets) arrives as a plain string;
            # iterating it would yield single characters, so treat it as
            # one path.
            config_paths.append(extra_paths)
        else:
            config_paths.extend(extra_paths)

    file_cfgs = [OmegaConf.load(path) for path in config_paths]
    default_cfg = OmegaConf.create(args_cls().model_dump())

    # Merge order defines precedence: defaults < config < configs < CLI.
    cfg = OmegaConf.merge(default_cfg, *file_cfgs, cli_args)
    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    return args_cls.model_validate(cfg)

View file

@ -56,13 +56,11 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
attn_bias_type: "block_causal"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512

View file

@ -0,0 +1,10 @@
from bytelatent.args import TrainArgs, parse_args
def main():
    """Parse a ``TrainArgs`` via ``parse_args`` and print it as JSON.

    The JSON is produced by ``model_dump_json`` with an indent of 4.
    """
    resolved_args = parse_args(TrainArgs)
    rendered = resolved_args.model_dump_json(indent=4)
    print(rendered)


if __name__ == "__main__":
    main()