Merge ece82cb960 into sapling-pr-archive-EntilZha
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

This commit is contained in:
Pedro Rodriguez 2025-02-12 11:25:18 -08:00 committed by GitHub
commit 078791996f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 26 additions and 7 deletions

View file

@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
def parse_args(args_cls):
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config
file_cfgs = []
if "config" in cli_args:
file_cfg = OmegaConf.load(cli_args["config"])
del cli_args["config"]
file_cfgs.append(file_cfg)
if "configs" in cli_args:
for c in cli_args["configs"]:
extra_file_cfg = OmegaConf.load(c)
file_cfgs.append(extra_file_cfg)
del cli_args["configs"]
default_cfg = OmegaConf.create(args_cls().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
to_merge = [default_cfg]
to_merge.extend(file_cfgs)
to_merge.append(cli_args)
cfg = OmegaConf.merge(*to_merge)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
pydantic_args = args_cls.model_validate(cfg)
return pydantic_args

View file

@ -56,13 +56,11 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
attn_bias_type: "block_causal"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512

View file

@ -0,0 +1,10 @@
from bytelatent.args import TrainArgs, parse_args
def main():
train_args = parse_args(TrainArgs)
print(train_args.model_dump_json(indent=4))
if __name__ == "__main__":
main()