blt/bytelatent/configs/debug.yaml
Pedro Rodriguez · 6ffeb66b53 · 2025-01-17 14:23:01 -08:00
Changes for training entropy model and correcting attention in local models (#25)
Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for an entropy-model training script
- Fix failing unit tests

# Template config: change dump_dir, data.root_dir, and tokenizer.path before use
# Evals can be activated by uncommenting their config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
dump_dir: /tmp/
name: "debug"
steps: 100_000
probe_freq: null
seed: 777
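
# Optimizer: lr is the peak learning rate, reached after `warmup` steps;
# lr_min_ratio presumably floors the decay schedule at
# lr_min_ratio * lr = 4e-5; clip is the gradient-norm clipping threshold.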
optim:
  lr: 4e-04
  warmup: 500
  lr_min_ratio: 0.1
  clip: 10.0
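
# Distributed training: full_shard is FSDP's fully sharded mode
# (ZeRO-3-style sharding of params, grads, and optimizer state);
# model_dtype: bf16 runs compute in bfloat16; matmul_allow_tf32 gates
# TF32 matmuls on Ampere+ GPUs; tp_size: 1 disables tensor parallelism.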
distributed:
  fsdp_type: full_shard
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1
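
# BLT model: a global latent transformer (dim: 512) over patch
# representations, flanked by lightweight byte-level local models.
# vocab_size: 260 presumably covers the 256 byte values plus 4 special
# tokens; dim_token is the byte-embedding width.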
model:
  n_heads: 8
  dim: 512
  vocab_size: 260
  dim_token: 256
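
  # Patching: patching_mode: "space" presumably starts a new patch at
  # space-like bytes, in which case patching_threshold (used for
  # entropy-based patching) is likely inert here; patch_size: 6 is the
  # expected patch length, and data_loader_patching: true moves patch
  # computation into the data loader.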
  patch_size: 6
  tokenization_mode: "bytes"
  patching_mode: "space"
  tie_local_encoder_decoder_logits: false
  data_loader_patching: true
  max_encoder_seq_length: 12288
  pad_to_max_length: true
  patching_threshold: 3.1439168453216553
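
  # Hash n-gram embeddings (per the BLT paper): rolling byte n-grams of
  # the sizes in encoder_hash_byte_group_size are hashed by
  # encoder_hash_byte_group_nb_functions functions into a 50002-entry
  # table, and the resulting embeddings are added to the byte embeddings
  # fed to the local encoder.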
  encoder_hash_byte_group_size: [4]
  encoder_hash_byte_group_vocab: 50002
  encoder_hash_byte_group_nb_functions: 3
  encoder_enable_byte_ngrams: false
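
  # Cross-attention wiring between local and global models: the encoder's
  # patch queries are initialized by pooling byte states
  # (cross_attn_init_by_pooling), cross_attn_k: 8 is presumably the number
  # of query vectors per patch, and the window/all-layers flags control
  # where and how far the cross-attention looks.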
  cross_attn_encoder: true # assuming cross_attention is true
  cross_attn_decoder: true # assuming cross_attention is true
  cross_attn_window_encoder: 512
  cross_attn_window_decoder: 512
  dim_local_encoder: 256
  dim_local_decoder: 256
  cross_attn_k: 8
  cross_attn_nheads: 4
  cross_attn_all_layers_decoder: true
  cross_attn_all_layers_encoder: true
  cross_attn_use_flex_attention: true
  cross_attn_init_by_pooling: true
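
  # Transformer internals: SwiGLU feed-forward blocks with rotary position
  # embeddings (use_rope); the recompute_* flags and layer_ckpt control
  # activation recomputation, and init_use_gaussian / init_use_depth
  # presumably select a depth-scaled Gaussian weight init.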
  log_patch_lengths: true
  non_linearity: "swiglu"
  use_rope: true
  recompute_fc1_out: false
  recompute_fc3_out: false
  recompute_attn: false
  custom_bwd: false
  layer_ckpt: "none"
  patch_only_encoder: false
  patch_only_decoder: false
  use_local_encoder_transformer: true
  init_use_gaussian: true
  init_use_depth: "current"
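
  # Attention: block_causal is presumably causal masking confined to each
  # document block; attn_impl: "xformers" picks the xformers kernels (the
  # commit above corrects which attention the local models use);
  # local_attention_window_len caps the local models' span, and
  # max_seqlen: 12288 matches max_encoder_seq_length.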
  attn_bias_type: "block_causal"
  attn_impl: "xformers"
  alpha_depth: "disabled"
  max_length: 256
  local_attention_window_len: 512
  max_seqlen: 12288
  downsampling_by_pooling: "max"
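
# Data: ??? is OmegaConf's mandatory-value marker, so root_dir and
# preprocess_dir must be overridden (see the header comment); sources
# maps dataset name to sampling weight; seq_len: 4096 is the per-sample
# sequence length (presumably in patches, with max_encoder_seq_length
# capping the byte stream); the blt tokenizer wraps a BPE tokenizer
# whose path must also be supplied.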
data:
  root_dir: ???
  sources:
    dclm_baseline_1.0: 1.0
  batch_size: 2
  prefetch_size: 64
  seq_len: 4096
  load_async: true
  preprocess_dir: ???
  tokenizer_args:
    name: blt
    init_kwargs:
      bpe_tokenizer_path: ???

profiling:
  run: false
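
# Checkpointing: a training checkpoint is dumped every 500 steps,
# keeping the 3 most recent; eval checkpoints are written every 1000
# steps, and keep: -1 presumably means keep them all.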
checkpoint:
  dump:
    every: 500
    keep: 3
  eval:
    every: 1000
    keep: -1

logging:
  freq: 10
eval_on_gpus: 8
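
# Downstream evals (can be deactivated by commenting this block out, per
# the header comment): tasks is a comma-separated task list; the
# generator settings bound generation at 65536 tokens in bfloat16.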
eval:
  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
  generator:
    max_tokens: 65536
    dtype: bf16
  mp_size: 1
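
# Filling the required ??? fields from the command line, assuming
# lingua-style OmegaConf dotlist overrides are forwarded by the stool
# launcher (an assumption; all paths below are placeholders):
#   python -m launchers.stool config=bytelatent/configs/debug.yaml \
#     nodes=8 account=fair_amaia_cw_codegen qos=lowest \
#     dump_dir=/path/to/dump data.root_dir=/path/to/data \
#     data.preprocess_dir=/path/to/preprocessed \
#     data.tokenizer_args.init_kwargs.bpe_tokenizer_path=/path/to/bpe.model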