mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 21:12:15 +00:00
Merge ece82cb960
into sapling-pr-archive-EntilZha
This commit is contained in:
commit
078791996f
|
@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
|
|||
|
||||
def parse_args(args_cls):
    """Build a validated ``args_cls`` instance from defaults, files and CLI.

    The sources are combined with ``OmegaConf.merge`` in this order, each one
    merged on top of the previous:

    1. the defaults obtained from ``args_cls().model_dump()``,
    2. the file named by the optional ``config=<path>`` CLI key,
    3. every file listed in the optional ``configs=[<path>, ...]`` CLI key,
       in the order given,
    4. the remaining CLI key/value pairs.

    Args:
        args_cls: Class whose instances expose ``model_dump()`` and which
            provides ``model_validate()`` (presumably a pydantic model;
            confirm against callers).

    Returns:
        Whatever ``args_cls.model_validate`` returns for the merged
        configuration after interpolations have been resolved.
    """
    cli_args = OmegaConf.from_cli()
    file_cfgs = []

    # 'config' and 'configs' only tell us which files to load; they are
    # removed from the CLI overrides because args_cls has no such fields.
    if "config" in cli_args:
        file_cfgs.append(OmegaConf.load(cli_args["config"]))
        del cli_args["config"]

    if "configs" in cli_args:
        extra_paths = cli_args["configs"]
        # A lone ``configs=path.yaml`` arrives as a plain string; wrap it so
        # we load one file instead of iterating over its characters.
        if isinstance(extra_paths, str):
            extra_paths = [extra_paths]
        file_cfgs.extend(OmegaConf.load(path) for path in extra_paths)
        del cli_args["configs"]

    default_cfg = OmegaConf.create(args_cls().model_dump())
    cfg = OmegaConf.merge(default_cfg, *file_cfgs, cli_args)
    # throw_on_missing=True is passed so that any value still missing after
    # the merge is reported here rather than handed on to validation.
    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    return args_cls.model_validate(cfg)
|
||||
|
|
|
@ -56,13 +56,11 @@ model:
|
|||
recompute_attn: false
|
||||
custom_bwd: false
|
||||
layer_ckpt: "none"
|
||||
patch_only_encoder: false
|
||||
patch_only_decoder: false
|
||||
use_local_encoder_transformer: true
|
||||
init_use_gaussian: true
|
||||
init_use_depth: "current"
|
||||
attn_bias_type: "block_causal"
|
||||
attn_impl: "xformers"
|
||||
attn_bias_type: "block_causal"
|
||||
alpha_depth: "disabled"
|
||||
max_length: 256
|
||||
local_attention_window_len: 512
|
||||
|
|
10
bytelatent/print_config.py
Normal file
10
bytelatent/print_config.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
from bytelatent.args import TrainArgs, parse_args
|
||||
|
||||
|
||||
def main():
    """Parse the training arguments and print them as indented JSON."""
    parsed = parse_args(TrainArgs)
    rendered = parsed.model_dump_json(indent=4)
    print(rendered)


if __name__ == "__main__":
    main()
|
Loading…
Reference in a new issue