Make it possible to specify multiple config files (#54)

Summary:

Make it possible to specify multiple config files.
Parsing CLI is not a special case anymore, just uses the same config inheritance method.

Test Plan:

Test that this interpolates in the right order via unit tests.

Sample usage: this loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order (lowest to highest) is:

- Default pydantic args
- Included configs, eg `config`
- CLI args

```
python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null

```


Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-02-18 10:42:44 -08:00 committed by GitHub
parent 9f29e0de18
commit 82ab5930ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 286 additions and 27 deletions

View file

@ -5,7 +5,6 @@ from typing import Any
import numpy as np import numpy as np
import yaml import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from bytelatent.checkpoint import CheckpointArgs from bytelatent.checkpoint import CheckpointArgs
@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state return np.random.default_rng((seed, rank, world_size)).bit_generator.state
def parse_args(args_cls):
    """Build an ``args_cls`` instance from defaults, a config file, and CLI overrides.

    The CLI must supply a ``config`` entry pointing at a file loadable by
    ``OmegaConf.load``. Values are merged in the order: model defaults, file
    contents, remaining CLI arguments. The merged result is resolved (missing
    mandatory values raise) and validated through ``args_cls.model_validate``.
    """
    overrides = OmegaConf.from_cli()
    from_file = OmegaConf.load(overrides.config)
    # 'config' is not a field of the target model, so drop it before merging.
    del overrides.config
    defaults = OmegaConf.create(args_cls().model_dump())
    merged = OmegaConf.to_container(
        OmegaConf.merge(defaults, from_file, overrides),
        resolve=True,
        throw_on_missing=True,
    )
    return args_cls.model_validate(merged)
TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"

View file

@ -0,0 +1,73 @@
import copy
from typing import Type, TypeVar
import omegaconf
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel
def parse_file_config(path: str) -> DictConfig:
    """Load the config file at ``path`` and require a mapping at its top level.

    Raises:
        ValueError: if the loaded object is not a ``DictConfig`` (for
            example, when the file's top level is a list).
    """
    loaded = OmegaConf.load(path)
    if isinstance(loaded, DictConfig):
        return loaded
    raise ValueError(
        f"File paths must parse to DictConfig, but it was: {type(loaded)}"
    )
def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
    """Expand the ``config`` includes of ``cfg`` into an ordered list of layers.

    If ``cfg`` has no ``"config"`` key it is returned unchanged as the only
    layer. Otherwise ``"config"`` must be a string path or a list of string
    paths; each referenced file is loaded and expanded recursively. The
    returned list holds the layers from all included files first (in the
    order the paths were given) and ends with a deep copy of ``cfg`` that has
    its ``"config"`` key removed.

    Raises:
        ValueError: if ``"config"`` is neither a string nor a list of strings,
            or if an included file does not parse to a ``DictConfig``.
    """
    if "config" not in cfg:
        return [cfg]

    # Work on a copy so the caller's config keeps its "config" entry.
    own_layer = copy.deepcopy(cfg)
    config_arg = own_layer["config"]
    del own_layer["config"]

    included_layers: list[DictConfig] = []
    if isinstance(config_arg, str):
        included_layers.extend(
            recursively_parse_config(parse_file_config(config_arg))
        )
    elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
        # Validation is interleaved with loading: earlier paths are parsed
        # before a later non-string entry is rejected.
        for path in config_arg:
            if not isinstance(path, str):
                raise ValueError(
                    f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
                )
            included_layers.extend(
                recursively_parse_config(parse_file_config(path))
            )
    else:
        raise ValueError(
            f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
        )

    # Included files come first; this config's own keys form the last layer.
    return [*included_layers, own_layer]
def parse_args_with_default(
    *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
):
    """Merge defaults, included config files, and CLI args into a plain container.

    Args:
        default_cfg: optional base layer, placed before every other layer
            in the merge.
        cli_args: the CLI-level config; when ``None`` it is read with
            ``OmegaConf.from_cli()``. Its ``config`` key (if any) is expanded
            via ``recursively_parse_config``.

    Returns:
        The merged config converted with ``OmegaConf.to_container`` using
        ``resolve=True`` and ``throw_on_missing=True``.
    """
    cli_args = OmegaConf.from_cli() if cli_args is None else cli_args
    assert isinstance(
        cli_args, DictConfig
    ), f"CLI Args must be a DictConfig, not {type(cli_args)}"

    layers = recursively_parse_config(cli_args)
    if default_cfg is not None:
        # Defaults go first in the merge order.
        layers = [default_cfg, *layers]

    merged = OmegaConf.merge(*layers)
    return OmegaConf.to_container(merged, resolve=True, throw_on_missing=True)
# Any pydantic model type; lets the parser return the same class it was given.
T = TypeVar("T", bound=BaseModel)


def parse_args_to_pydantic_model(
    args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
    """Parse configuration into a validated instance of ``args_cls``.

    The defaults come from dumping a default-constructed ``args_cls``; they
    are passed as the base layer to ``parse_args_with_default`` together with
    ``cli_args``. The merged result is validated with
    ``args_cls.model_validate``.
    """
    defaults = OmegaConf.create(args_cls().model_dump())
    merged = parse_args_with_default(default_cfg=defaults, cli_args=cli_args)
    return args_cls.model_validate(merged)

View file

@ -56,13 +56,11 @@ model:
recompute_attn: false recompute_attn: false
custom_bwd: false custom_bwd: false
layer_ckpt: "none" layer_ckpt: "none"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true use_local_encoder_transformer: true
init_use_gaussian: true init_use_gaussian: true
init_use_depth: "current" init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers" attn_impl: "xformers"
attn_bias_type: "block_causal"
alpha_depth: "disabled" alpha_depth: "disabled"
max_length: 256 max_length: 256
local_attention_window_len: 512 local_attention_window_len: 512

View file

@ -2,9 +2,10 @@
# Evals can be activated by uncommenting its config # Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
dump_dir: /tmp/ dump_dir: /tmp/blt-entropy
name: "debug" name: "debug"
steps: 100_000 steps: 100_000
max_steps: null
probe_freq: null probe_freq: null
seed: 777 seed: 777
optim: optim:
@ -35,7 +36,6 @@ entropy_model:
attn_impl: "xformers" attn_impl: "xformers"
data: data:
s3_profile: blt
root_dir: ??? root_dir: ???
sources: sources:
dclm_baseline_1.0: 1.0 dclm_baseline_1.0: 1.0

View file

@ -5,18 +5,15 @@ import logging
import os import os
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Any
import torch import torch
from lm_eval import simple_evaluate from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.api.model import LM from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.args import EvalArgs, ValidationArgs, parse_args from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs from bytelatent.data.file_util import get_fs
from bytelatent.distributed import ( from bytelatent.distributed import (
DistributedArgs, DistributedArgs,
@ -29,7 +26,6 @@ from bytelatent.generate import (
PackedCausalTransformerGenerator, PackedCausalTransformerGenerator,
load_consolidated_model_and_tokenizer, load_consolidated_model_and_tokenizer,
) )
from bytelatent.transformer import LMTransformer, LMTransformerArgs
EVAL_FOLDER_NAME = "{:010d}" EVAL_FOLDER_NAME = "{:010d}"

View file

@ -0,0 +1,11 @@
from bytelatent.args import TrainArgs
from bytelatent.config_parser import parse_args_to_pydantic_model
def main():
    """Parse ``TrainArgs`` from the command line and print it as indented JSON."""
    parsed = parse_args_to_pydantic_model(TrainArgs)
    rendered = parsed.model_dump_json(indent=4)
    print(rendered)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,180 @@
import os
import pytest
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.config_parser import (
parse_args_to_pydantic_model,
parse_file_config,
recursively_parse_config,
)
FIXTURE_DIR = "fixtures/test-cfgs"
def test_parse_file_config():
    """A config file whose top level is not a mapping is rejected with ValueError."""
    list_fixture = os.path.join(FIXTURE_DIR, "list.yaml")
    # The call raises before returning anything, so there is no result to
    # inspect; only the exception itself is checked.
    with pytest.raises(ValueError):
        parse_file_config(list_fixture)
def test_nop():
    """A config without a "config" key comes back as the single, unchanged layer."""
    plain_cfg = OmegaConf.create({"a": 1})
    layers = recursively_parse_config(plain_cfg)
    assert len(layers) == 1
    assert layers[0] == plain_cfg
def test_root():
    """Including a single file yields [file layer, CLI layer]; CLI keys win on merge."""
    root_path = os.path.join(FIXTURE_DIR, "root.yaml")

    layers = recursively_parse_config(OmegaConf.create({"config": root_path}))
    assert len(layers) == 2
    # The CLI layer has nothing left once "config" is stripped.
    assert len(layers[1]) == 0
    assert layers[0]["seed"] == -1
    # Reading b.y from the file layer is expected to raise MissingMandatoryValue.
    with pytest.raises(MissingMandatoryValue):
        assert layers[0]["b"]["y"] is not None

    # Test basic cli override
    layers = recursively_parse_config(
        OmegaConf.create({"config": root_path, "seed": 42})
    )
    assert layers[1]["seed"] == 42
    merged = OmegaConf.merge(*layers)
    assert merged["seed"] == 42
def test_one_level_include():
    """A file that itself includes another file expands to three ordered layers."""
    middle_path = os.path.join(FIXTURE_DIR, "middle.yaml")

    # Without CLI overrides the last layer is empty.
    layers = recursively_parse_config(OmegaConf.create({"config": middle_path}))
    assert len(layers) == 3
    assert layers[0]["seed"] == -1
    assert layers[1]["b"]["y"] == 10
    assert len(layers[2]) == 0
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 10

    # A CLI value for b.y lands in the last layer and wins after merging.
    layers = recursively_parse_config(
        OmegaConf.create({"config": middle_path, "b": {"y": 100}})
    )
    assert len(layers) == 3
    assert layers[0]["seed"] == -1
    assert layers[1]["b"]["y"] == 10
    assert layers[2]["b"]["y"] == 100
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 100
def test_two_level_include():
    """Two levels of nested includes expand to four layers that merge as expected."""
    top_path = os.path.join(FIXTURE_DIR, "top.yaml")
    layers = recursively_parse_config(
        OmegaConf.create({"config": top_path, "p": 500, "b": {"z": -2}})
    )

    # Per-layer checks: deepest include first, CLI layer last.
    assert len(layers) == 4
    assert layers[0]["seed"] == -1
    assert layers[1]["b"]["y"] == 10
    assert layers[2]["hello"] == "world"
    assert layers[3]["p"] == 500
    assert layers[3]["b"]["z"] == -2

    # Checks on the merged result.
    merged = OmegaConf.merge(*layers)
    assert merged["a"] == 1
    assert merged["seed"] == -1
    assert merged["b"]["x"] == 0
    assert merged["b"]["y"] == 10
    assert merged["b"]["z"] == -2
    assert merged["hello"] == "world"
def test_multiple_includes():
    """A list of include paths expands in order; later files and CLI args override."""
    include_paths = [
        os.path.join(FIXTURE_DIR, "top.yaml"),
        os.path.join(FIXTURE_DIR, "override.yaml"),
    ]
    # (extra CLI keys, expected merged value of "a"): first the override file
    # decides "a", then an explicit CLI value takes over.
    scenarios = [
        ({}, 100),
        ({"a": 1000}, 1000),
    ]

    for extra_cli, expected_a in scenarios:
        cli_cfg = OmegaConf.create(
            {
                "config": include_paths,
                "p": 500,
                "b": {"z": -2},
                **extra_cli,
            }
        )
        layers = recursively_parse_config(cli_cfg)

        # The layer layout is the same in both scenarios.
        assert len(layers) == 5
        assert layers[0]["seed"] == -1
        assert layers[1]["b"]["y"] == 10
        assert layers[2]["hello"] == "world"
        assert layers[3]["a"] == 100
        assert layers[4]["p"] == 500
        assert layers[4]["b"]["z"] == -2

        merged = OmegaConf.merge(*layers)
        assert merged["a"] == expected_a
        assert merged["seed"] == -1
        assert merged["b"]["x"] == 0
        assert merged["b"]["y"] == 10
        assert merged["b"]["z"] == -2
        assert merged["hello"] == "world"
class SubConfig(BaseModel):
    """Nested sample model used by the parser tests; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    # Defaults differ from the values asserted in the tests, so an
    # un-overridden field is easy to spot.
    x: int = -100
    y: int = -100
    z: int = -5
class SampleConfig(BaseModel):
    """Top-level sample model used by the parser tests; unknown keys are rejected."""

    model_config = ConfigDict(extra="forbid")

    a: int = -100
    seed: int = -100
    # Nested section, validated as a SubConfig.
    b: SubConfig = SubConfig()
    hello: str = ""
    p: int = -100
def test_pydantic_parse():
    """End to end: model defaults, included files, and CLI args become a SampleConfig."""
    include_paths = [
        os.path.join(FIXTURE_DIR, "top.yaml"),
        os.path.join(FIXTURE_DIR, "override.yaml"),
    ]
    cli_cfg = OmegaConf.create({"config": include_paths, "p": 500, "a": 1000})

    parsed = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)

    assert isinstance(parsed, SampleConfig)
    # Supplied directly in the CLI layer.
    assert parsed.a == 1000
    assert parsed.p == 500
    # Not given in the CLI layer.
    assert parsed.seed == -1
    assert parsed.b.x == 0
    assert parsed.b.y == 10
    # Matches the SubConfig default for z.
    assert parsed.b.z == -5
    assert parsed.hello == "world"

View file

@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.stateful import Stateful from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs, parse_args from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import ( from bytelatent.data.iterators.multiprocess_iterator import (
@ -824,7 +825,7 @@ def main():
Plus all the default values in TrainArgs dataclass. Plus all the default values in TrainArgs dataclass.
""" """
train_args = parse_args(TrainArgs) train_args = parse_args_to_pydantic_model(TrainArgs)
if train_args.debug_dynamo: if train_args.debug_dynamo:
import torch._dynamo import torch._dynamo

View file

@ -0,0 +1 @@
[1, 2, 3]

View file

@ -0,0 +1,3 @@
config: fixtures/test-cfgs/root.yaml
b:
y: 10

View file

@ -0,0 +1 @@
a: 100

View file

@ -0,0 +1,6 @@
seed: -1
a: 1
b:
x: 0
y: ???
z: ???

View file

@ -0,0 +1,3 @@
config: fixtures/test-cfgs/middle.yaml
hello: world