Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 13:32:14 +00:00)
Merge f058373889 into sapling-pr-archive-EntilZha
Commit d44902da97
@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5

     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False

     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED

@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()

     cos, sin = freqs.cos(), freqs.sin()

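The two hunks above give precompute_freqs_cis an opt-in rope_use_fp32_in_outer_product flag that casts the position index to float32 before torch.outer. Below is a minimal standalone sketch of that path; the helper name and the truncated return value are illustrative, not the repository's exact code.

import torch

def rope_angles(
    dim: int,
    end: int,
    theta: float = 10000.0,
    rope_use_fp32_in_outer_product: bool = False,
) -> torch.Tensor:
    # Per-channel inverse frequencies, as in the diff above.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    if rope_use_fp32_in_outer_product:
        # Force the position index into float32 before the outer product.
        t = t.to(torch.float32)
    return torch.outer(t, freqs).float()  # (end, dim // 2) angle table

angles = rope_angles(dim=64, end=8, rope_use_fp32_in_outer_product=True)
print(angles.shape)  # torch.Size([8, 32])
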
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """

-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()

         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product

         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )

     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )

     def forward(

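With the constructor shown above, callers opt in per module. A hypothetical usage sketch, with invented values, assuming the RotaryEmbedding class from this diff is importable:

rope = RotaryEmbedding(
    theta=10000.0,
    head_dim=64,
    max_seqlen=2048,
    rope_use_fp32_in_outer_product=True,  # threaded into precompute_freqs_cis
)
rope.reset_parameters()  # recomputes freqs_cis with the same flag
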
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.eos_id = args.eos_id

@@ -4,7 +4,6 @@
 import json
 import logging
 from collections import namedtuple
-from dataclasses import asdict
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Union

@@ -79,8 +78,8 @@ class MetricLogger:
             and get_is_master()
         ):
             run = wandb.init(
-                config=asdict(self.args),
-                **asdict(self.args.logging.wandb),
+                config=self.args.model_dump(),
+                **self.args.logging.wandb.model_dump(),
             )

     def log(self, metrics: dict[str, Any]):

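The wandb hunk swaps dataclasses.asdict for pydantic v2's model_dump(), matching the import removal earlier in the file and the move of the arg classes to BaseModel. A small illustration of the equivalence; the WandbArgs fields below are invented for the example:

from pydantic import BaseModel

class WandbArgs(BaseModel):
    project: str = "blt"
    entity: str | None = None

wandb_args = WandbArgs()
# dataclass version:   asdict(wandb_args)
# pydantic v2 version:
print(wandb_args.model_dump())  # {'project': 'blt', 'entity': None}
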
@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False

     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512

@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"

-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0

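The RoPE fields can be dropped here because ByteLatentTransformerArgs subclasses BaseTransformerArgs, which now declares them (first hunk of this diff), so the attributes remain available through inheritance. A minimal sketch, with unrelated fields elided:

from pydantic import BaseModel

class BaseTransformerArgs(BaseModel):
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

class ByteLatentTransformerArgs(BaseTransformerArgs):
    pm_size: int = 0

args = ByteLatentTransformerArgs()
print(args.rope_theta, args.rope_use_fp32_in_outer_product)  # 10000.0 False
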
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,

@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
        sliding_window=args.local_attention_window_len,
        use_rope=args.use_rope,
        rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
        init_base_std=args.init_base_std,
        init_std_factor=args.init_std_factor,
        n_kv_heads=args.n_kv_heads,

@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.pos_embeddings = None

@@ -4,14 +4,15 @@ import json
 import os
 import shutil
 import subprocess
-from dataclasses import dataclass
+from pydantic import BaseModel
 from typing import Any, Dict

 from omegaconf import OmegaConf


-@dataclass
-class StoolArgs:
+class StoolArgs(BaseModel):
+    name: str = None
+    dump_dir: str = None
     config: Any = None
     launcher: str = "sbatch" # Can be sbatch or bash if already in salloc
     script: str = "apps.main.train" # The script to run.

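StoolArgs moves from a dataclass to a pydantic BaseModel and gains name and dump_dir; together with the StoolArgs.model_validate(args) call further down, construction from a plain mapping looks roughly like this (only the fields visible in the diff are shown, the rest of the class is elided):

from typing import Any
from pydantic import BaseModel

class StoolArgs(BaseModel):
    name: str = None       # defaults are not validated by pydantic, so None is accepted
    dump_dir: str = None
    config: Any = None
    launcher: str = "sbatch"
    script: str = "apps.main.train"

args = StoolArgs.model_validate({"name": "demo", "dump_dir": "/tmp/blt", "launcher": "bash"})
print(args.name, args.launcher)  # demo bash
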
@@ -64,7 +65,7 @@ source activate {conda_env_path}
 export OMP_NUM_THREADS=1
 export LAUNCH_WITH="SBATCH"
 export DUMP_DIR={dump_dir}
-srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
 """

@@ -150,8 +151,8 @@ def validate_args(args) -> None:
 def launch_job(args: StoolArgs):
     # Set up args default and validate them depending on the cluster or partition requested
     validate_args(args)
-    dump_dir = args.config["dump_dir"]
-    job_name = args.config["name"]
+    job_name = args.name or args.config["name"]
+    dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
     print("Creating directories...")
     os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
     if args.override:

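An illustrative walk-through of the new name/dump_dir resolution in launch_job, with invented values; note that a non-empty os.path.join(...) result is truthy, so the config fallback only applies when the join yields an empty string:

import os

cli_name = None                                    # StoolArgs.name, not passed here
cli_dump_dir = "/checkpoints/blt"                  # StoolArgs.dump_dir from the CLI
config = {"name": "baseline", "dump_dir": "/old/location"}

job_name = cli_name or config["name"]              # -> "baseline"
dump_dir = os.path.join(cli_dump_dir, job_name) or config["dump_dir"]
print(job_name, dump_dir)                          # baseline /checkpoints/blt/baseline
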
@@ -230,8 +231,7 @@ if __name__ == "__main__":
     Then you can pass model.dim=32 to change values in LMTransformerArgs
     or just name=tictac for top level attributes.
     """
-    raise NotImplementedError("Update this to blt code")
     args = OmegaConf.from_cli()
     args.config = OmegaConf.load(args.config)
-    args = dataclass_from_dict(StoolArgs, args)
+    args = StoolArgs.model_validate(args)
     launch_job(args)

@@ -337,6 +337,7 @@ def train(args: TrainArgs):

     # log model size

+    logger.info(model)
     logger.info(f"Model size: {model_param_count:,} total parameters")

     gpu_memory_monitor = GPUMemoryMonitor("cuda")