import os import pytest from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.config_parser import ( parse_args_to_pydantic_model, parse_file_config, recursively_parse_config, ) FIXTURE_DIR = "fixtures/test-cfgs" def test_parse_file_config(): with pytest.raises(ValueError): cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) assert isinstance(cfg, DictConfig) def test_nop(): cfg = OmegaConf.create({"a": 1}) parsed_cfgs = recursively_parse_config(cfg) assert len(parsed_cfgs) == 1 assert parsed_cfgs[0] == cfg def test_root(): cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 2 assert len(parsed_cfgs[1]) == 0 assert parsed_cfgs[0]["seed"] == -1 with pytest.raises(MissingMandatoryValue): assert parsed_cfgs[0]["b"]["y"] is not None # Test basic cli override cli_cfg = OmegaConf.create( {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} ) parsed_cfgs = recursively_parse_config(cli_cfg) assert parsed_cfgs[1]["seed"] == 42 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["seed"] == 42 def test_one_level_include(): cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 3 assert parsed_cfgs[0]["seed"] == -1 assert parsed_cfgs[1]["b"]["y"] == 10 assert len(parsed_cfgs[2]) == 0 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["b"]["y"] == 10 cli_cfg = OmegaConf.create( {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} ) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 3 assert parsed_cfgs[0]["seed"] == -1 assert parsed_cfgs[1]["b"]["y"] == 10 assert parsed_cfgs[2]["b"]["y"] == 100 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["b"]["y"] == 100 def test_two_level_include(): cli_cfg = OmegaConf.create( {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} ) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 4 assert parsed_cfgs[0]["seed"] == -1 assert parsed_cfgs[1]["b"]["y"] == 10 assert parsed_cfgs[2]["hello"] == "world" assert parsed_cfgs[3]["p"] == 500 assert parsed_cfgs[3]["b"]["z"] == -2 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["a"] == 1 assert cfg["seed"] == -1 assert cfg["b"]["x"] == 0 assert cfg["b"]["y"] == 10 assert cfg["b"]["z"] == -2 assert cfg["hello"] == "world" def test_multiple_includes(): cli_cfg = OmegaConf.create( { "config": [ os.path.join(FIXTURE_DIR, "top.yaml"), os.path.join(FIXTURE_DIR, "override.yaml"), ], "p": 500, "b": {"z": -2}, } ) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 5 assert parsed_cfgs[0]["seed"] == -1 assert parsed_cfgs[1]["b"]["y"] == 10 assert parsed_cfgs[2]["hello"] == "world" assert parsed_cfgs[3]["a"] == 100 assert parsed_cfgs[4]["p"] == 500 assert parsed_cfgs[4]["b"]["z"] == -2 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["a"] == 100 assert cfg["seed"] == -1 assert cfg["b"]["x"] == 0 assert cfg["b"]["y"] == 10 assert cfg["b"]["z"] == -2 assert cfg["hello"] == "world" cli_cfg = OmegaConf.create( { "config": [ os.path.join(FIXTURE_DIR, "top.yaml"), os.path.join(FIXTURE_DIR, "override.yaml"), ], "p": 500, "b": {"z": -2}, "a": 1000, } ) parsed_cfgs = recursively_parse_config(cli_cfg) assert len(parsed_cfgs) == 5 assert parsed_cfgs[0]["seed"] == -1 assert parsed_cfgs[1]["b"]["y"] == 10 assert parsed_cfgs[2]["hello"] == "world" assert parsed_cfgs[3]["a"] == 100 assert parsed_cfgs[4]["p"] == 500 assert parsed_cfgs[4]["b"]["z"] == -2 cfg = OmegaConf.merge(*parsed_cfgs) assert cfg["a"] == 1000 assert cfg["seed"] == -1 assert cfg["b"]["x"] == 0 assert cfg["b"]["y"] == 10 assert cfg["b"]["z"] == -2 assert cfg["hello"] == "world" class SubConfig(BaseModel): model_config = ConfigDict(extra="forbid") x: int = -100 y: int = -100 z: int = -5 class SampleConfig(BaseModel): model_config = ConfigDict(extra="forbid") a: int = -100 seed: int = -100 b: SubConfig = SubConfig() hello: str = "" p: int = -100 def test_pydantic_parse(): cli_cfg = OmegaConf.create( { "config": [ os.path.join(FIXTURE_DIR, "top.yaml"), os.path.join(FIXTURE_DIR, "override.yaml"), ], "p": 500, "a": 1000, } ) cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) assert isinstance(cfg, SampleConfig) assert cfg.a == 1000 assert cfg.p == 500 assert cfg.seed == -1 assert cfg.b.x == 0 assert cfg.b.y == 10 assert cfg.b.z == -5 assert cfg.hello == "world"