mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 13:32:14 +00:00
Merge ece82cb960
into sapling-pr-archive-EntilZha
This commit is contained in:
commit
078791996f
|
@ -40,12 +40,23 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
|
||||||
|
|
||||||
def parse_args(args_cls):
    """Build a validated args object from CLI flags and optional config file(s).

    Merge precedence (lowest to highest):
      1. defaults from ``args_cls()`` (a pydantic model),
      2. each YAML file given via ``config=...`` and ``configs=[...]``, in order,
      3. remaining CLI overrides.

    Args:
        args_cls: A pydantic model class providing ``model_dump`` and
            ``model_validate``.

    Returns:
        A validated instance of ``args_cls``.

    Raises:
        omegaconf.errors.MissingMandatoryValue: if any merged value is still
            MISSING (``throw_on_missing=True``).
        pydantic.ValidationError: if the merged config fails validation.
    """
    cli_args = OmegaConf.from_cli()
    file_cfgs = []
    # Pop 'config'/'configs' from the CLI namespace: the underlying
    # dataclass/pydantic model does not have these attributes.
    if "config" in cli_args:
        file_cfg = OmegaConf.load(cli_args["config"])
        del cli_args["config"]
        file_cfgs.append(file_cfg)

    if "configs" in cli_args:
        for c in cli_args["configs"]:
            extra_file_cfg = OmegaConf.load(c)
            file_cfgs.append(extra_file_cfg)
        del cli_args["configs"]

    default_cfg = OmegaConf.create(args_cls().model_dump())
    # Later entries win on conflict: defaults < config files (in order) < CLI.
    to_merge = [default_cfg]
    to_merge.extend(file_cfgs)
    to_merge.append(cli_args)
    cfg = OmegaConf.merge(*to_merge)
    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
    pydantic_args = args_cls.model_validate(cfg)
    return pydantic_args
|
||||||
|
|
|
@ -56,13 +56,11 @@ model:
  recompute_attn: false
  custom_bwd: false
  layer_ckpt: "none"
  use_local_encoder_transformer: true
  init_use_gaussian: true
  init_use_depth: "current"
  attn_impl: "xformers"
  attn_bias_type: "block_causal"
  alpha_depth: "disabled"
  max_length: 256
  local_attention_window_len: 512
|
10
bytelatent/print_config.py
Normal file
10
bytelatent/print_config.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
from bytelatent.args import TrainArgs, parse_args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
train_args = parse_args(TrainArgs)
|
||||||
|
print(train_args.model_dump_json(indent=4))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in a new issue