mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 21:12:15 +00:00
74 lines
2.4 KiB
Python
74 lines
2.4 KiB
Python
|
import copy
|
||
|
from typing import Type, TypeVar
|
||
|
|
||
|
import omegaconf
|
||
|
from omegaconf import DictConfig, OmegaConf
|
||
|
from pydantic import BaseModel
|
||
|
|
||
|
|
||
|
def parse_file_config(path: str) -> DictConfig:
    """Load the config file at ``path`` and require a mapping at its top level.

    Args:
        path: Path passed unchanged to ``OmegaConf.load``.

    Returns:
        The loaded ``DictConfig``.

    Raises:
        ValueError: If ``OmegaConf.load`` returns anything other than a
            ``DictConfig`` (e.g. the file's top level is a list).
    """
    loaded = OmegaConf.load(path)
    if isinstance(loaded, DictConfig):
        return loaded
    raise ValueError(
        f"File paths must parse to DictConfig, but it was: {type(loaded)}"
    )
|
||
|
|
||
|
|
||
|
def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
    """Expand ``config`` include directives into an ordered list of configs.

    If ``cfg`` has no ``config`` key, it is returned as the only element.
    Otherwise ``config`` must be a path string or a list of path strings.
    Each referenced file is loaded with ``parse_file_config`` and expanded
    recursively, because included files may carry their own ``config`` key.

    The result lists the expanded includes first, depth-first, in the order
    they are referenced. It ends with a copy of ``cfg`` that has ``config``
    removed. When the list is merged left to right (as
    ``parse_args_with_default`` does with ``OmegaConf.merge``), ``cfg``'s own
    values therefore take precedence over the files it includes.

    Args:
        cfg: Config that may contain a ``config`` include directive. It is not
            modified; a deep copy is taken before ``config`` is removed.

    Returns:
        Ordered list of configs. It is never empty.

    Raises:
        ValueError: If ``config`` is neither a string nor a list of strings.
            Also raised if a file includes itself, directly or transitively;
            previously that recursed until ``RecursionError``.
    """
    return _recursively_parse_config(cfg, ())


def _recursively_parse_config(
    cfg: DictConfig, include_chain: tuple[str, ...]
) -> list[DictConfig]:
    """Recursive worker for ``recursively_parse_config``.

    ``include_chain`` holds the resolved paths of the files currently being
    expanded, from outermost to innermost. A path that reappears on this chain
    would recurse forever, because file contents do not change between loads.
    Sibling and diamond includes are still allowed, since only the current
    chain is checked.
    """
    import os  # local import keeps this block self-contained

    if "config" not in cfg:
        return [cfg]

    # Work on a copy so the caller's config keeps its "config" key.
    cfg = copy.deepcopy(cfg)
    config_arg = cfg["config"]
    del cfg["config"]

    # Normalize a single path to a one-element sequence so both forms share one loop.
    if isinstance(config_arg, str):
        paths = [config_arg]
    elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
        paths = config_arg
    else:
        raise ValueError(
            f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
        )

    included: list[DictConfig] = []
    for path in paths:
        if not isinstance(path, str):
            raise ValueError(
                f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
            )
        # Resolve so that different spellings of the same file (./a.yaml vs a.yaml)
        # are recognized as the same include when checking for cycles.
        resolved = os.path.realpath(path)
        if resolved in include_chain:
            cycle = " -> ".join((*include_chain, resolved))
            raise ValueError(f"Cyclic config include detected: {cycle}")
        included.extend(
            _recursively_parse_config(
                parse_file_config(path), (*include_chain, resolved)
            )
        )
    return included + [cfg]
|
||
|
|
||
|
|
||
|
def parse_args_with_default(
    *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
) -> dict:
    """Merge defaults, included config files and CLI values into a plain dict.

    Configs are merged with ``OmegaConf.merge`` in this order, later entries
    overriding earlier ones:

    1. ``default_cfg``, if given.
    2. Files pulled in by ``config`` directives, expanded by
       ``recursively_parse_config``.
    3. The remaining ``cli_args`` values.

    The merged config is converted with ``OmegaConf.to_container`` using
    ``resolve=True`` and ``throw_on_missing=True``. Interpolations are
    therefore resolved, and missing values make OmegaConf raise.

    Args:
        default_cfg: Lowest-priority config, typically the model defaults.
        cli_args: Parsed CLI arguments. When ``None``, they are read with
            ``OmegaConf.from_cli()``.

    Returns:
        The merged configuration as a plain container.

    Raises:
        TypeError: If ``cli_args`` is not a ``DictConfig``.
        ValueError: Propagated from ``recursively_parse_config`` for a
            malformed ``config`` directive.
    """
    if cli_args is None:
        cli_args = OmegaConf.from_cli()
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not isinstance(cli_args, DictConfig):
        raise TypeError(f"CLI Args must be a DictConfig, not {type(cli_args)}")

    ordered_cfgs = recursively_parse_config(cli_args)
    if default_cfg is not None:
        # Defaults come first so every other source overrides them.
        ordered_cfgs = [default_cfg, *ordered_cfgs]
    cfg = OmegaConf.merge(*ordered_cfgs)
    return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
||
|
|
||
|
|
||
|
# Any pydantic model class; the parser returns an instance of the class it is given.
T = TypeVar("T", bound=BaseModel)


def parse_args_to_pydantic_model(
    args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
    """Build an ``args_cls`` instance from its defaults plus file and CLI config.

    ``args_cls()`` is called with no arguments to obtain the defaults, so every
    field of the model must have a default value. Those defaults are the
    lowest-priority layer passed to ``parse_args_with_default``. The merged
    result is validated with ``args_cls.model_validate``, and its validation
    errors propagate to the caller.

    Args:
        args_cls: The pydantic model class to populate.
        cli_args: Optional pre-parsed CLI config. When ``None``, it is read
            from the command line by ``parse_args_with_default``.

    Returns:
        A validated instance of ``args_cls``.
    """
    defaults = OmegaConf.create(args_cls().model_dump())
    merged = parse_args_with_default(default_cfg=defaults, cli_args=cli_args)
    return args_cls.model_validate(merged)
|