mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-15 16:49:53 +00:00
Make it possible to specify multiple config files (#54)
Summary: Make it possible to specify multiple config files. Parsing CLI is not a special case anymore, just uses the same config inheritance method. Test Plan: Test that this iterpolates in the right order via unit tests Sample usage, loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precendence order is: - Default pydantic args - Included configs, eg `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` Summary: Test Plan:
This commit is contained in:
parent
9f29e0de18
commit
82ab5930ec
13 changed files with 286 additions and 27 deletions
180
bytelatent/test_config_parser.py
Normal file
180
bytelatent/test_config_parser.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from bytelatent.config_parser import (
|
||||
parse_args_to_pydantic_model,
|
||||
parse_file_config,
|
||||
recursively_parse_config,
|
||||
)
|
||||
|
||||
FIXTURE_DIR = "fixtures/test-cfgs"
|
||||
|
||||
|
||||
def test_parse_file_config():
|
||||
with pytest.raises(ValueError):
|
||||
cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml"))
|
||||
assert isinstance(cfg, DictConfig)
|
||||
|
||||
|
||||
def test_nop():
|
||||
cfg = OmegaConf.create({"a": 1})
|
||||
parsed_cfgs = recursively_parse_config(cfg)
|
||||
assert len(parsed_cfgs) == 1
|
||||
assert parsed_cfgs[0] == cfg
|
||||
|
||||
|
||||
def test_root():
|
||||
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")})
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 2
|
||||
assert len(parsed_cfgs[1]) == 0
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
with pytest.raises(MissingMandatoryValue):
|
||||
assert parsed_cfgs[0]["b"]["y"] is not None
|
||||
|
||||
# Test basic cli override
|
||||
cli_cfg = OmegaConf.create(
|
||||
{"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42}
|
||||
)
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert parsed_cfgs[1]["seed"] == 42
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["seed"] == 42
|
||||
|
||||
|
||||
def test_one_level_include():
|
||||
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")})
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 3
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
assert parsed_cfgs[1]["b"]["y"] == 10
|
||||
assert len(parsed_cfgs[2]) == 0
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["b"]["y"] == 10
|
||||
|
||||
cli_cfg = OmegaConf.create(
|
||||
{"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}}
|
||||
)
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 3
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
assert parsed_cfgs[1]["b"]["y"] == 10
|
||||
assert parsed_cfgs[2]["b"]["y"] == 100
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["b"]["y"] == 100
|
||||
|
||||
|
||||
def test_two_level_include():
|
||||
cli_cfg = OmegaConf.create(
|
||||
{"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}}
|
||||
)
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 4
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
assert parsed_cfgs[1]["b"]["y"] == 10
|
||||
assert parsed_cfgs[2]["hello"] == "world"
|
||||
assert parsed_cfgs[3]["p"] == 500
|
||||
assert parsed_cfgs[3]["b"]["z"] == -2
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["a"] == 1
|
||||
assert cfg["seed"] == -1
|
||||
assert cfg["b"]["x"] == 0
|
||||
assert cfg["b"]["y"] == 10
|
||||
assert cfg["b"]["z"] == -2
|
||||
assert cfg["hello"] == "world"
|
||||
|
||||
|
||||
def test_multiple_includes():
|
||||
cli_cfg = OmegaConf.create(
|
||||
{
|
||||
"config": [
|
||||
os.path.join(FIXTURE_DIR, "top.yaml"),
|
||||
os.path.join(FIXTURE_DIR, "override.yaml"),
|
||||
],
|
||||
"p": 500,
|
||||
"b": {"z": -2},
|
||||
}
|
||||
)
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 5
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
assert parsed_cfgs[1]["b"]["y"] == 10
|
||||
assert parsed_cfgs[2]["hello"] == "world"
|
||||
assert parsed_cfgs[3]["a"] == 100
|
||||
assert parsed_cfgs[4]["p"] == 500
|
||||
assert parsed_cfgs[4]["b"]["z"] == -2
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["a"] == 100
|
||||
assert cfg["seed"] == -1
|
||||
assert cfg["b"]["x"] == 0
|
||||
assert cfg["b"]["y"] == 10
|
||||
assert cfg["b"]["z"] == -2
|
||||
assert cfg["hello"] == "world"
|
||||
|
||||
cli_cfg = OmegaConf.create(
|
||||
{
|
||||
"config": [
|
||||
os.path.join(FIXTURE_DIR, "top.yaml"),
|
||||
os.path.join(FIXTURE_DIR, "override.yaml"),
|
||||
],
|
||||
"p": 500,
|
||||
"b": {"z": -2},
|
||||
"a": 1000,
|
||||
}
|
||||
)
|
||||
parsed_cfgs = recursively_parse_config(cli_cfg)
|
||||
assert len(parsed_cfgs) == 5
|
||||
assert parsed_cfgs[0]["seed"] == -1
|
||||
assert parsed_cfgs[1]["b"]["y"] == 10
|
||||
assert parsed_cfgs[2]["hello"] == "world"
|
||||
assert parsed_cfgs[3]["a"] == 100
|
||||
assert parsed_cfgs[4]["p"] == 500
|
||||
assert parsed_cfgs[4]["b"]["z"] == -2
|
||||
cfg = OmegaConf.merge(*parsed_cfgs)
|
||||
assert cfg["a"] == 1000
|
||||
assert cfg["seed"] == -1
|
||||
assert cfg["b"]["x"] == 0
|
||||
assert cfg["b"]["y"] == 10
|
||||
assert cfg["b"]["z"] == -2
|
||||
assert cfg["hello"] == "world"
|
||||
|
||||
|
||||
class SubConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
x: int = -100
|
||||
y: int = -100
|
||||
z: int = -5
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
a: int = -100
|
||||
seed: int = -100
|
||||
b: SubConfig = SubConfig()
|
||||
hello: str = ""
|
||||
p: int = -100
|
||||
|
||||
|
||||
def test_pydantic_parse():
|
||||
cli_cfg = OmegaConf.create(
|
||||
{
|
||||
"config": [
|
||||
os.path.join(FIXTURE_DIR, "top.yaml"),
|
||||
os.path.join(FIXTURE_DIR, "override.yaml"),
|
||||
],
|
||||
"p": 500,
|
||||
"a": 1000,
|
||||
}
|
||||
)
|
||||
cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)
|
||||
assert isinstance(cfg, SampleConfig)
|
||||
assert cfg.a == 1000
|
||||
assert cfg.p == 500
|
||||
assert cfg.seed == -1
|
||||
assert cfg.b.x == 0
|
||||
assert cfg.b.y == 10
|
||||
assert cfg.b.z == -5
|
||||
assert cfg.hello == "world"
|
Loading…
Add table
Add a link
Reference in a new issue