Merge f058373889 into sapling-pr-archive-EntilZha

commit d44902da97
Pedro Rodriguez, 2025-02-06 09:37:27 -08:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
6 changed files with 44 additions and 20 deletions

View file

@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
     norm_eps: float = 1e-5
     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
     freqs = torch.outer(t, freqs).float()
     cos, sin = freqs.cos(), freqs.sin()
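
Note: below is a minimal, self-contained sketch of the patched precompute_freqs_cis implied by the two hunks above. The early return of (cos, sin) and the demo call at the bottom are illustrative only; the real function keeps building the rotary cis tensor after the cos/sin step.

import torch

def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    rope_use_fp32_in_outer_product: bool = False,
):
    # Inverse frequencies for each even dimension pair, as in the hunk above.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    if rope_use_fp32_in_outer_product:
        # Keep the position indices in fp32 before the outer product so the
        # rotation angles stay accurate at large sequence positions.
        t = t.to(torch.float32)
    freqs = torch.outer(t, freqs).float()
    cos, sin = freqs.cos(), freqs.sin()
    return cos, sin  # simplified; the patched function continues past this point

cos, sin = precompute_freqs_cis(dim=64, end=2048, rope_use_fp32_in_outer_product=True)
print(cos.shape, sin.shape)  # torch.Size([2048, 32]) torch.Size([2048, 32])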
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """
-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()
         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )
     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )
     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.eos_id = args.eos_id

View file

@ -4,7 +4,6 @@
import json import json
import logging import logging
from collections import namedtuple from collections import namedtuple
from dataclasses import asdict
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Union from typing import Any, Union
@@ -79,8 +78,8 @@ class MetricLogger:
             and get_is_master()
         ):
             run = wandb.init(
-                config=asdict(self.args),
-                **asdict(self.args.logging.wandb),
+                config=self.args.model_dump(),
+                **self.args.logging.wandb.model_dump(),
             )
     def log(self, metrics: dict[str, Any]):
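
Note: the switch from dataclasses.asdict to pydantic's model_dump assumes the logger args are now pydantic models. A small sketch of the pattern; WandbConf and RunConf are hypothetical stand-ins, not the repo's actual config classes.

from pydantic import BaseModel

class WandbConf(BaseModel):
    project: str = "blt"
    name: str | None = None

class RunConf(BaseModel):
    seed: int = 42
    wandb: WandbConf = WandbConf()

args = RunConf(wandb=WandbConf(name="demo"))
# model_dump() is pydantic's counterpart of dataclasses.asdict(): a plain dict
# that can be passed as config= or splatted into wandb.init(**...).
print(args.model_dump())        # {'seed': 42, 'wandb': {'project': 'blt', 'name': 'demo'}}
print(args.wandb.model_dump())  # {'project': 'blt', 'name': 'demo'}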

View file

@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False
     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"
-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
     # Parameter mixing
     pm_size: int = 0
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
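
Note: the deleted RoPE fields are not lost. ByteLatentTransformerArgs subclasses BaseTransformerArgs, which now declares rope_theta and rope_use_fp32_in_outer_product (first file in this commit), so create_local_encoder and create_local_decoder can still read them from args. A trimmed-down sketch of that inheritance, with all other fields elided:

from pydantic import BaseModel

class BaseTransformerArgs(BaseModel):
    rope_theta: float = 10000.0
    rope_use_fp32_in_outer_product: bool = False

class ByteLatentTransformerArgs(BaseTransformerArgs):
    dim_token: int | None = None  # other fields elided

args = ByteLatentTransformerArgs()
print(args.rope_theta, args.rope_use_fp32_in_outer_product)  # 10000.0 False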

View file

@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
         self.pos_embeddings = None

View file

@@ -4,14 +4,15 @@ import json
 import os
 import shutil
 import subprocess
-from dataclasses import dataclass
+from pydantic import BaseModel
 from typing import Any, Dict
 from omegaconf import OmegaConf
-@dataclass
-class StoolArgs:
+class StoolArgs(BaseModel):
+    name: str = None
+    dump_dir: str = None
     config: Any = None
     launcher: str = "sbatch"  # Can be sbatch or bash if already in salloc
     script: str = "apps.main.train"  # The script to run.
@@ -64,7 +65,7 @@ source activate {conda_env_path}
 export OMP_NUM_THREADS=1
 export LAUNCH_WITH="SBATCH"
 export DUMP_DIR={dump_dir}
-srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
 """
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
 def launch_job(args: StoolArgs):
     # Set up args default and validate them depending on the cluster or partition requested
     validate_args(args)
-    dump_dir = args.config["dump_dir"]
-    job_name = args.config["name"]
+    job_name = args.name or args.config["name"]
+    dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
     print("Creating directories...")
     os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
     if args.override:
@@ -230,8 +231,7 @@ if __name__ == "__main__":
     Then you can pass model.dim=32 to change values in LMTransformerArgs
     or just name=tictac for top level attributes.
     """
-    raise NotImplementedError("Update this to blt code")
     args = OmegaConf.from_cli()
     args.config = OmegaConf.load(args.config)
-    args = dataclass_from_dict(StoolArgs, args)
+    args = StoolArgs.model_validate(args)
    launch_job(args)
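
Note: a rough sketch of the new launch path. OmegaConf still parses the command line, but StoolArgs is now a pydantic model, so StoolArgs.model_validate replaces the old dataclass_from_dict helper and rejects malformed fields. The class below is a simplified stand-in with only a few of the fields shown above, and with explicit Optional types.

from typing import Any
from pydantic import BaseModel, ValidationError

class StoolArgs(BaseModel):
    name: str | None = None
    dump_dir: str | None = None
    config: Any = None
    launcher: str = "sbatch"

cli = {"name": "demo_run", "dump_dir": "/tmp/dumps", "config": {"name": "demo_run"}}
args = StoolArgs.model_validate(cli)   # validates the mapping into StoolArgs
print(args.name, args.launcher)        # demo_run sbatch

try:
    StoolArgs.model_validate({"launcher": 123})  # wrong type is rejected
except ValidationError as e:
    print(e.error_count(), "validation error")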

View file

@@ -337,6 +337,7 @@ def train(args: TrainArgs):
         # log model size
+        logger.info(model)
         logger.info(f"Model size: {model_param_count:,} total parameters")
         gpu_memory_monitor = GPUMemoryMonitor("cuda")
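
Note: a toy illustration of what the added logger.info(model) line emits: nn.Module's repr is the full module tree, which complements the existing parameter-count line. The model below is a throwaway example, not the repo's model.

import logging
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
model_param_count = sum(p.numel() for p in model.parameters())

logger.info(model)  # prints the module tree, e.g. Sequential((0): Linear(...), ...)
logger.info(f"Model size: {model_param_count:,} total parameters")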