blt/bytelatent/test_config_parser.py
Pedro Rodriguez 82ab5930ec
Make it possible to specify multiple config files (#54)
Summary:

Make it possible to specify multiple config files.
Parsing CLI is not a special case anymore, just uses the same config inheritance method.

Test Plan:

Test that this interpolates in the right order via unit tests.

Sample usage: this loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order (lowest to highest) is:

- Default pydantic args
- Included configs, e.g. `config`
- CLI args

```
python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null

```


Summary:

Test Plan:
2025-02-18 10:42:44 -08:00

181 lines
5.2 KiB
Python

import os
import pytest
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
from pydantic import BaseModel, ConfigDict
from bytelatent.config_parser import (
parse_args_to_pydantic_model,
parse_file_config,
recursively_parse_config,
)
# Directory holding the YAML fixture files that the tests below load.
# NOTE(review): the path is relative — presumably resolved against the
# directory pytest is launched from; confirm before running from elsewhere.
FIXTURE_DIR = "fixtures/test-cfgs"
def test_parse_file_config():
    """Check that ``parse_file_config`` rejects the ``list.yaml`` fixture.

    The fixture is presumably a YAML document whose root is a list rather
    than a mapping (TODO confirm against the fixture file); loading it is
    expected to raise ``ValueError``.

    Only the call itself sits inside ``pytest.raises``. A statement placed
    after a call that is expected to raise never executes on the passing
    path, so asserting on the return value there would verify nothing.
    """
    list_yaml = os.path.join(FIXTURE_DIR, "list.yaml")
    with pytest.raises(ValueError):
        parse_file_config(list_yaml)
def test_nop():
    """A config without a ``config`` key comes back as one unchanged layer."""
    plain_cfg = OmegaConf.create({"a": 1})
    layers = recursively_parse_config(plain_cfg)
    assert len(layers) == 1
    only_layer = layers[0]
    assert only_layer == plain_cfg
def test_root():
    """Including ``root.yaml`` yields two layers, the last holding CLI keys."""
    root_path = os.path.join(FIXTURE_DIR, "root.yaml")

    # Only the include is given, so the final (CLI) layer is expected empty.
    layers = recursively_parse_config(OmegaConf.create({"config": root_path}))
    assert len(layers) == 2
    assert len(layers[1]) == 0
    file_layer = layers[0]
    assert file_layer["seed"] == -1
    # Reading b.y from the file layer is expected to raise OmegaConf's
    # missing-mandatory-value error.
    with pytest.raises(MissingMandatoryValue):
        assert file_layer["b"]["y"] is not None

    # Test basic cli override
    layers = recursively_parse_config(
        OmegaConf.create({"config": root_path, "seed": 42})
    )
    assert layers[1]["seed"] == 42
    merged = OmegaConf.merge(*layers)
    assert merged["seed"] == 42
def test_one_level_include():
    """Including ``middle.yaml`` is expected to produce three layers."""
    middle_path = os.path.join(FIXTURE_DIR, "middle.yaml")

    # Scenario 1: include only, so the last layer carries no keys.
    layers = recursively_parse_config(OmegaConf.create({"config": middle_path}))
    assert len(layers) == 3
    base_layer, mid_layer, cli_layer = layers[0], layers[1], layers[2]
    assert base_layer["seed"] == -1
    assert mid_layer["b"]["y"] == 10
    assert len(cli_layer) == 0
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 10

    # Scenario 2: b.y is also given alongside the include and must win
    # once the layers are merged.
    overrides = {"config": middle_path, "b": {"y": 100}}
    layers = recursively_parse_config(OmegaConf.create(overrides))
    assert len(layers) == 3
    base_layer, mid_layer, cli_layer = layers[0], layers[1], layers[2]
    assert base_layer["seed"] == -1
    assert mid_layer["b"]["y"] == 10
    assert cli_layer["b"]["y"] == 100
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 100
def test_two_level_include():
    """Including ``top.yaml`` is expected to produce four layers, CLI last."""
    overrides = {
        "config": os.path.join(FIXTURE_DIR, "top.yaml"),
        "p": 500,
        "b": {"z": -2},
    }
    layers = recursively_parse_config(OmegaConf.create(overrides))

    # Per-layer checks, in the order the layers are returned.
    assert len(layers) == 4
    cli_layer = layers[3]
    assert layers[0]["seed"] == -1
    assert layers[1]["b"]["y"] == 10
    assert layers[2]["hello"] == "world"
    assert cli_layer["p"] == 500
    assert cli_layer["b"]["z"] == -2

    # Merged view: every layer contributes, later layers taking priority.
    merged = OmegaConf.merge(*layers)
    merged_b = merged["b"]
    assert merged["a"] == 1
    assert merged["seed"] == -1
    assert merged_b["x"] == 0
    assert merged_b["y"] == 10
    assert merged_b["z"] == -2
    assert merged["hello"] == "world"
def test_multiple_includes():
    """A list under ``config`` is expected to expand into five layers.

    Two scenarios run back to back with identical per-layer expectations.
    They differ only in whether ``a`` is also passed alongside the includes,
    which changes the merged value of ``a`` from 100 to 1000.
    """
    include_paths = [
        os.path.join(FIXTURE_DIR, "top.yaml"),
        os.path.join(FIXTURE_DIR, "override.yaml"),
    ]
    # (extra keys appended to the CLI dict, expected merged value of "a")
    scenarios = (
        ({}, 100),
        ({"a": 1000}, 1000),
    )

    for extra_keys, expected_a in scenarios:
        cli_cfg = OmegaConf.create(
            {
                "config": list(include_paths),
                "p": 500,
                "b": {"z": -2},
                **extra_keys,
            }
        )
        layers = recursively_parse_config(cli_cfg)

        assert len(layers) == 5
        cli_layer = layers[4]
        assert layers[0]["seed"] == -1
        assert layers[1]["b"]["y"] == 10
        assert layers[2]["hello"] == "world"
        assert layers[3]["a"] == 100
        assert cli_layer["p"] == 500
        assert cli_layer["b"]["z"] == -2

        merged = OmegaConf.merge(*layers)
        merged_b = merged["b"]
        assert merged["a"] == expected_a
        assert merged["seed"] == -1
        assert merged_b["x"] == 0
        assert merged_b["y"] == 10
        assert merged_b["z"] == -2
        assert merged["hello"] == "world"
class SubConfig(BaseModel):
    """Nested pydantic model used as the ``b`` field of ``SampleConfig``."""

    # extra="forbid": pydantic rejects keys that are not declared below.
    model_config = ConfigDict(extra="forbid")
    x: int = -100
    y: int = -100
    # test_pydantic_parse expects b.z == -5, i.e. this default value.
    z: int = -5
class SampleConfig(BaseModel):
    """Top-level pydantic model that ``test_pydantic_parse`` parses into."""

    # extra="forbid": pydantic rejects keys that are not declared below.
    model_config = ConfigDict(extra="forbid")
    a: int = -100
    seed: int = -100
    # Nested section; defaults to a SubConfig built from its own defaults.
    b: SubConfig = SubConfig()
    hello: str = ""
    p: int = -100
def test_pydantic_parse():
    """Parse includes plus CLI keys into a validated ``SampleConfig``."""
    include_paths = [
        os.path.join(FIXTURE_DIR, "top.yaml"),
        os.path.join(FIXTURE_DIR, "override.yaml"),
    ]
    cli_args = OmegaConf.create({"config": include_paths, "p": 500, "a": 1000})

    parsed = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_args)
    assert isinstance(parsed, SampleConfig)

    # Keys passed directly in cli_args.
    assert parsed.a == 1000
    assert parsed.p == 500
    # Not passed in cli_args and different from the model defaults.
    assert parsed.seed == -1
    nested = parsed.b
    assert nested.x == 0
    assert nested.y == 10
    # Matches the SubConfig default.
    assert nested.z == -5
    assert parsed.hello == "world"