mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 22:34:37 +00:00
Make it possible to specify multiple config files (#54)
Summary: Make it possible to specify multiple config files. Parsing CLI is not a special case anymore, just uses the same config inheritance method. Test Plan: Test that this iterpolates in the right order via unit tests Sample usage, loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precendence order is: - Default pydantic args - Included configs, eg `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` Summary: Test Plan:
This commit is contained in:
parent
9f29e0de18
commit
82ab5930ec
13 changed files with 286 additions and 27 deletions
73
bytelatent/config_parser.py
Normal file
73
bytelatent/config_parser.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
import copy
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import omegaconf
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def parse_file_config(path: str) -> DictConfig:
|
||||
file_cfg = OmegaConf.load(path)
|
||||
if not isinstance(file_cfg, DictConfig):
|
||||
raise ValueError(
|
||||
f"File paths must parse to DictConfig, but it was: {type(file_cfg)}"
|
||||
)
|
||||
return file_cfg
|
||||
|
||||
|
||||
def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
|
||||
if "config" not in cfg:
|
||||
return [cfg]
|
||||
|
||||
ordered_cfgs = []
|
||||
cfg = copy.deepcopy(cfg)
|
||||
config_arg = cfg["config"]
|
||||
del cfg["config"]
|
||||
ordered_cfgs.append(cfg)
|
||||
|
||||
if isinstance(config_arg, str):
|
||||
file_cfg = parse_file_config(config_arg)
|
||||
sub_configs = recursively_parse_config(file_cfg)
|
||||
ordered_cfgs = sub_configs + ordered_cfgs
|
||||
elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
|
||||
sub_configs = []
|
||||
for c in config_arg:
|
||||
if not isinstance(c, str):
|
||||
raise ValueError(
|
||||
f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
|
||||
)
|
||||
config_to_parse = parse_file_config(c)
|
||||
sub_configs.extend(recursively_parse_config(config_to_parse))
|
||||
ordered_cfgs = sub_configs + ordered_cfgs
|
||||
else:
|
||||
raise ValueError(
|
||||
f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
|
||||
)
|
||||
return ordered_cfgs
|
||||
|
||||
|
||||
def parse_args_with_default(
|
||||
*, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
|
||||
):
|
||||
if cli_args is None:
|
||||
cli_args = OmegaConf.from_cli()
|
||||
assert isinstance(
|
||||
cli_args, DictConfig
|
||||
), f"CLI Args must be a DictConfig, not {type(cli_args)}"
|
||||
ordered_cfgs = recursively_parse_config(cli_args)
|
||||
if default_cfg is not None:
|
||||
ordered_cfgs.insert(0, default_cfg)
|
||||
cfg = OmegaConf.merge(*ordered_cfgs)
|
||||
return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def parse_args_to_pydantic_model(
|
||||
args_cls: Type[T], cli_args: DictConfig | None = None
|
||||
) -> T:
|
||||
default_cfg = OmegaConf.create(args_cls().model_dump())
|
||||
parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args)
|
||||
pydantic_args = args_cls.model_validate(parsed_cfg)
|
||||
return pydantic_args
|
Loading…
Add table
Add a link
Reference in a new issue