diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index dd0cce6..87d7334 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -45,6 +45,7 @@ class BaseTransformerArgs(BaseModel):
 
     norm_eps: float = 1e-5
     rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
 
     init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
@@ -78,7 +79,12 @@ def repeat_kv(x: torch.Tensor, n_rep: int, dim: int) -> torch.Tensor:
     )
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    rope_use_fp32_in_outer_product: bool = False,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
 
@@ -96,6 +102,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     """
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)
+    if rope_use_fp32_in_outer_product:
+        t = t.to(torch.float32)
+
     freqs = torch.outer(t, freqs).float()
 
     cos, sin = freqs.cos(), freqs.sin()
@@ -232,22 +241,37 @@ class RotaryEmbedding(torch.nn.Module):
     RotaryEmbedding Module
     """
 
-    def __init__(self, theta: float, head_dim: int, max_seqlen: int = 1024):
+    def __init__(
+        self,
+        theta: float,
+        head_dim: int,
+        max_seqlen: int = 1024,
+        rope_use_fp32_in_outer_product: bool = False,
+    ):
         super().__init__()
 
         self.theta = theta
         self.head_dim = head_dim
         self.max_seqlen = max_seqlen
+        self.rope_use_fp32_in_outer_product = rope_use_fp32_in_outer_product
 
         self.register_buffer(
             "freqs_cis",
-            precompute_freqs_cis(dim=head_dim, end=max_seqlen, theta=theta),
+            precompute_freqs_cis(
+                dim=head_dim,
+                end=max_seqlen,
+                theta=theta,
+                rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
+            ),
             persistent=False,
         )
 
     def reset_parameters(self):
         self.freqs_cis[...] = precompute_freqs_cis(
-            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
+            dim=self.head_dim,
+            end=self.max_seqlen,
+            theta=self.theta,
+            rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product,
         )
 
     def forward(
@@ -577,6 +601,7 @@ class BaseTransformer(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
 
         self.eos_id = args.eos_id
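The behavioral change in this file is confined to precompute_freqs_cis: when rope_use_fp32_in_outer_product is set, the position indices t are cast to float32 before the torch.outer call that builds the RoPE table; the remaining hunks just thread the new flag from BaseTransformerArgs down to that call. A minimal sketch of exercising the flag directly, assuming the bytelatent package and its dependencies are importable (sizes below are arbitrary):

    from bytelatent.base_transformer import RotaryEmbedding, precompute_freqs_cis

    # Rotary embedding module with the new flag enabled.
    rope = RotaryEmbedding(
        theta=10000.0,
        head_dim=64,
        max_seqlen=4096,
        rope_use_fp32_in_outer_product=True,
    )

    # Or call the helper directly; the flag only controls the dtype of the
    # position indices fed into torch.outer.
    freqs_cis = precompute_freqs_cis(
        dim=64,
        end=4096,
        theta=10000.0,
        rope_use_fp32_in_outer_product=True,
    )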
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
index 2df5be7..fb443d7 100644
--- a/bytelatent/metrics.py
+++ b/bytelatent/metrics.py
@@ -4,7 +4,6 @@
 import json
 import logging
 from collections import namedtuple
-from dataclasses import asdict
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Union
@@ -79,8 +78,8 @@ class MetricLogger:
             and get_is_master()
         ):
             run = wandb.init(
-                config=asdict(self.args),
-                **asdict(self.args.logging.wandb),
+                config=self.args.model_dump(),
+                **self.args.logging.wandb.model_dump(),
             )
 
     def log(self, metrics: dict[str, Any]):
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index a62be23..53a3be6 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -414,7 +414,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     patch_in_forward: bool = False
 
     # Architecture and dimensions
-    dim_token: int = 256
+    dim_token: int | None = None
     dim_global: int = 512
     dim_local_decoder: int = 512
     dim_local_encoder: int = 512
@@ -523,10 +523,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     use_fsdp: bool = True
     attn_to_keep: str = "all"
 
-    # RoPE parameters
-    rope_theta: float = 10000.0
-    rope_use_fp32_in_outer_product: bool = False
-
     # Parameter mixing
     pm_size: int = 0
 
@@ -619,6 +615,7 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
@@ -661,6 +658,7 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
         sliding_window=args.local_attention_window_len,
         use_rope=args.use_rope,
         rope_theta=args.rope_theta,
+        rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         init_base_std=args.init_base_std,
         init_std_factor=args.init_std_factor,
         n_kv_heads=args.n_kv_heads,
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index c16f62e..d0e24c0 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -86,6 +86,7 @@ class LocalModelBase(nn.Module):
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
+            rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product,
         )
 
         self.pos_embeddings = None
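The metrics.py hunk above and the stool.py hunks below belong to the same migration from dataclasses to pydantic: wandb.init now receives dictionaries produced by BaseModel.model_dump() instead of dataclasses.asdict(). A self-contained illustration of that pattern, using hypothetical config classes rather than the repo's actual ones:

    from pydantic import BaseModel

    class WandbArgs(BaseModel):
        project: str = "blt"
        entity: str | None = None

    class TrainArgsSketch(BaseModel):
        dim: int = 512
        wandb: WandbArgs = WandbArgs()

    cfg = TrainArgsSketch()
    # model_dump() recursively converts nested models into plain dicts,
    # which is what dataclasses.asdict() used to provide for dataclasses.
    print(cfg.model_dump())
    # {'dim': 512, 'wandb': {'project': 'blt', 'entity': None}}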
diff --git a/bytelatent/stool.py b/bytelatent/stool.py
index 965f4cb..b156177 100644
--- a/bytelatent/stool.py
+++ b/bytelatent/stool.py
@@ -4,14 +4,15 @@
 import json
 import os
 import shutil
 import subprocess
-from dataclasses import dataclass
+from pydantic import BaseModel
 from typing import Any, Dict
 
 from omegaconf import OmegaConf
 
-@dataclass
-class StoolArgs:
+class StoolArgs(BaseModel):
+    name: str = None
+    dump_dir: str = None
     config: Any = None
     launcher: str = "sbatch"  # Can be sbatch or bash if already in salloc
     script: str = "apps.main.train"  # The script to run.
@@ -64,7 +65,7 @@ source activate {conda_env_path}
 export OMP_NUM_THREADS=1
 export LAUNCH_WITH="SBATCH"
 export DUMP_DIR={dump_dir}
-srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml
+srun {log_output} -n {tasks} -N {nodes_per_run} python -u -m {script} config=$DUMP_DIR/base_config.yaml dump_dir=$DUMP_DIR name={name}
 """
 
 
@@ -150,8 +151,8 @@ def validate_args(args) -> None:
 def launch_job(args: StoolArgs):
     # Set up args default and validate them depending on the cluster or partition requested
     validate_args(args)
-    dump_dir = args.config["dump_dir"]
-    job_name = args.config["name"]
+    job_name = args.name or args.config["name"]
+    dump_dir = os.path.join(args.dump_dir, job_name) or args.config["dump_dir"]
     print("Creating directories...")
     os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override)
     if args.override:
@@ -230,8 +231,7 @@ if __name__ == "__main__":
     Then you can pass model.dim=32 to change values in LMTransformerArgs
     or just name=tictac for top level attributes.
     """
-    raise NotImplementedError("Update this to blt code")
     args = OmegaConf.from_cli()
     args.config = OmegaConf.load(args.config)
-    args = dataclass_from_dict(StoolArgs, args)
+    args = StoolArgs.model_validate(args)
     launch_job(args)
diff --git a/bytelatent/train.py b/bytelatent/train.py
index a775e46..d7543f1 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -337,6 +337,7 @@ def train(args: TrainArgs):
 
     # log model size
 
+    logger.info(model)
     logger.info(f"Model size: {model_param_count:,} total parameters")
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
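Beyond the pydantic conversion, the stool.py change re-enables the launcher entry point (the NotImplementedError is removed) and makes the run name and dump directory explicit StoolArgs fields: launch_job now derives the job name and run directory from these fields (falling back to the loaded config), and the generated sbatch script forwards both to the training command via dump_dir=$DUMP_DIR name={name}. A rough sketch of validating the new arguments from a plain dict, with hypothetical values; this only builds the args object, since launch_job itself would try to submit a job:

    from bytelatent.stool import StoolArgs

    # StoolArgs is now a pydantic BaseModel, so a plain dict (for example the
    # result of OmegaConf.to_container on the CLI config) validates directly.
    args = StoolArgs.model_validate(
        {
            "name": "blt_debug",                 # hypothetical run name
            "dump_dir": "/checkpoint/blt",       # hypothetical base directory
            "config": {"model": {"dim": 512}},   # stands in for the loaded YAML
        }
    )
    print(args.name, args.dump_dir, args.launcher)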