Changes for training entropy model and correcting attention in local models (#25)

Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention implementation is used in local models (see the config sketch below)
- Prepare for adding an entropy model training script
- Fix failing unit tests
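
For concreteness, the new attention arguments can be set on `BaseTransformerArgs` roughly as follows (a hedged sketch: the field names come from the diff below; the import path is assumed):

```python
# Sketch only: field names are from the diff below; the module path is assumed.
from bytelatent.base_transformer import BaseTransformerArgs

args = BaseTransformerArgs(
    dim=512,
    n_layers=8,
    n_heads=8,
    attn_impl="xformers",   # renamed from "fmha" in this change
    attn_bias_type=None,    # new optional bias selector
)
```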

Test Plan:
Pedro Rodriguez 2025-01-17 14:23:01 -08:00, committed by GitHub
parent caec8d2621
commit 6ffeb66b53
15 changed files with 349 additions and 138 deletions

@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Optional, Tuple, Union

 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
@@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
+from bytelatent.tokenizers.constants import EOS_ID

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
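
For context, the gate above compiles `flex_attention` unless `BLT_ALLOW_MISSING_FLEX_ATTENTION` is set to a non-zero value; a minimal sketch of opting out on machines without flex attention (the variable name is taken from the code above):

```python
# Sketch: must run before importing the module above, since the gate
# is evaluated at import time.
import os

os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"  # skip torch.compile
```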
@@ -30,13 +31,14 @@ class InitStdFactor(Enum):

 class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
     dim: int = 512
     n_layers: int = 8
-    head_dim: Optional[int] = None
-    n_heads: Optional[int] = None
-    n_kv_heads: Optional[int] = None
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None

-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None

     multiple_of: int = 256
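
The new `model_config = ConfigDict(extra="forbid")` makes pydantic v2 reject unknown fields instead of silently ignoring them, which catches config typos early; a small sketch of the effect (import path assumed):

```python
# Sketch: with extra="forbid", a misspelled config key now fails loudly.
from pydantic import ValidationError

from bytelatent.base_transformer import BaseTransformerArgs  # path assumed

try:
    BaseTransformerArgs(dim=512, n_layer=8)  # typo: should be n_layers
except ValidationError as err:
    print(err)  # "Extra inputs are not permitted"
```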
@@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
     rope_theta: float = 10000.0

-    init_base_std: Optional[float] = None
+    init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED

     max_seqlen: int = 1024

+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID
+

 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
@@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
         torch.nn.init.ones_(self.weight)  # type: ignore


+def _reshape_for_attn_bias(
+    attn_bias: AttentionBias | None,
+    *tensors: torch.Tensor,
+) -> list[torch.Tensor]:
+    to_transform = list(tensors)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
+        # could be `view` instead of reshape during training, but for inference
+        # have to reshape due to strides mismatch
+        to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
+    return to_transform
+
+
 class Attention(nn.Module):
     def __init__(
         self,
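
The new `_reshape_for_attn_bias` helper flattens batched tensors into a single packed row when a `BlockDiagonalCausalMask` is used, since xformers expects packed sequences as one batch-1 tensor; a hedged sketch of its behavior (assumes the helper is in scope and xformers is installed):

```python
# Sketch: a BlockDiagonalCausalMask describes several documents packed into
# one row, so (B, S, H, D) tensors are flattened to (1, B*S, H, D).
import torch
from xformers.ops import fmha

bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens([3, 5])  # two docs
q = torch.randn(2, 4, 8, 16)  # B=2, S=4 -> packed length 8 = 3 + 5
(q_flat,) = _reshape_for_attn_bias(bias, q)
print(q_flat.shape)  # torch.Size([1, 8, 8, 16])
```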
@@ -371,9 +390,12 @@ class Attention(nn.Module):
             output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

-        elif attn_impl == "fmha":
+        elif attn_impl == "xformers":
             assert mask is None or isinstance(mask, AttentionBias)
+            query_shape = xq.shape
+            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
             output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            output = output.view(query_shape)
             # This uses B S H D instead of B H S D of pytorch

         elif attn_impl == "sdpa":
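
The layout comment is the crux of the correction: xformers' `memory_efficient_attention` consumes and produces `(B, S, H, D)`, while torch's `scaled_dot_product_attention` uses `(B, H, S, D)`. A hedged equivalence check (assumes a CUDA device, since most xformers kernels are GPU-only):

```python
# Sketch: the two backends agree once the layouts are matched.
import torch
import torch.nn.functional as F
from xformers.ops import fmha

q = torch.randn(2, 16, 4, 32, device="cuda", dtype=torch.float16)  # B S H D
k, v = torch.randn_like(q), torch.randn_like(q)

out_xf = fmha.memory_efficient_attention(q, k, v)  # B S H D in and out
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # -> B H S D
).transpose(1, 2)  # back to B S H D
torch.testing.assert_close(out_xf, out_sdpa, atol=1e-2, rtol=1e-2)
```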
@@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
         mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
         attn_impl: str = "sdpa",
     ) -> torch.Tensor:
-        h = x + self.attention(
+        attn_out = self.attention(
             self.attention_norm(x),
             freq_cis,
             tok_idx=tok_idx,
             mask=mask,
             attn_impl=attn_impl,
         )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        h = x + attn_out
+        h_norm = self.ffn_norm(h)
+        out = h + self.feed_forward(h_norm)
         return out

     def init_weights(self, init_std=None, factor=1.0):
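
The block rewrite is behavior-preserving: it only names the intermediate activations (attention output, post-attention residual, normed hidden state), presumably so they can be inspected or probed. A self-contained check with stand-in sublayers:

```python
# Sketch: old fused form and new named-intermediate form compute the same thing.
import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 16)
attention = torch.nn.Linear(16, 16)      # stand-ins for the real sublayers
feed_forward = torch.nn.Linear(16, 16)
attention_norm = torch.nn.LayerNorm(16)  # the real code uses RMSNorm
ffn_norm = torch.nn.LayerNorm(16)

h_old = x + attention(attention_norm(x))         # old, fused
out_old = h_old + feed_forward(ffn_norm(h_old))

attn_out = attention(attention_norm(x))          # new, step by step
h = x + attn_out
h_norm = ffn_norm(h)
out_new = h + feed_forward(h_norm)

torch.testing.assert_close(out_old, out_new)
```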
@@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.dim = args.dim
         self.init_base_std = args.init_base_std
+        self.attn_impl = args.attn_impl
+        self.attn_bias_type = args.attn_bias_type
         self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.max_seqlen = args.max_seqlen
         self.rope_embeddings = RotaryEmbedding(
@@ -552,6 +578,7 @@ class BaseTransformer(nn.Module):
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
         )
+        self.eos_id = args.eos_id

         self.layers = nn.ModuleList()
         for _ in range(args.n_layers):
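
Storing `eos_id` on `BaseTransformer` looks like groundwork for the entropy training script: EOS positions let a packed byte stream be split back into documents for block-causal attention. A heavily hedged sketch of that idea (`eos_id` mirrors the diff; the helper and its use are hypothetical, not part of this change):

```python
# Hypothetical sketch: derive per-document lengths from EOS positions in a
# packed batch, then build a block-diagonal causal bias for xformers.
import torch
from xformers.ops import fmha

def block_causal_bias_from_eos(tokens: torch.Tensor, eos_id: int):
    seqlens = []
    for row in tokens:  # tokens: (B, S) packed token ids
        prev = -1
        for e in (row == eos_id).nonzero().flatten().tolist():
            seqlens.append(e - prev)
            prev = e
        if prev < len(row) - 1:  # trailing partial document
            seqlens.append(len(row) - 1 - prev)
    return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(seqlens)

# e.g. tokens = [[5, 9, 7, 7, 9, 4, 2, 8]] with eos_id=9 -> seqlens [2, 3, 3]
```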