Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-15 00:29:43 +00:00
Changes for training entropy model and correcting attention in local models (#25)
Summary:
- Refactor the local model configs to be separate and clearer
- Add attention arguments and correct which attention implementation is used in the local models
- Prepare for adding an entropy-model training script
- Fix failing unit tests

Test Plan:
Parent: caec8d2621
Commit: 6ffeb66b53
15 changed files with 349 additions and 138 deletions
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Optional, Tuple, Union
 
 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
@@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
 from xformers.ops import AttentionBias, fmha
 
 from bytelatent import probe
+from bytelatent.tokenizers.constants import EOS_ID
 
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
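
A side note on the BLT_ALLOW_MISSING_FLEX_ATTENTION guard that appears as context in the hunk above: flex_attention is compiled unless the environment variable is set to a non-zero value. The snippet below only mirrors that expression to show how the flag resolves; nothing in it is repo-specific beyond the variable name.

# Sketch of how the guard in the hunk above evaluates; illustrative only.
import os

os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"
skip_compile = int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) != 0
print(skip_compile)  # True -> flex_attention would not be torch.compile'd
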
@@ -30,13 +31,14 @@ class InitStdFactor(Enum):
 
 
 class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
     dim: int = 512
     n_layers: int = 8
-    head_dim: Optional[int] = None
-    n_heads: Optional[int] = None
-    n_kv_heads: Optional[int] = None
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None
 
-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None
 
     multiple_of: int = 256
 
@@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
 
     rope_theta: float = 10000.0
 
-    init_base_std: Optional[float] = None
+    init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED
 
     max_seqlen: int = 1024
 
+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID
+
 
 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
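
To illustrate the new config surface added in the two hunks above, here is a minimal, self-contained sketch. It is not code from the repo: it redeclares only a subset of the BaseTransformerArgs fields under a different name, uses a literal 2 in place of EOS_ID, and the "causal" value for attn_bias_type is an assumption. The point is that extra="forbid" now rejects misspelled or stale config keys.

from pydantic import BaseModel, ConfigDict, ValidationError

class ArgsSketch(BaseModel):
    # Subset of BaseTransformerArgs, for illustration only.
    model_config = ConfigDict(extra="forbid")
    dim: int = 512
    attn_impl: str | None = "sdpa"
    attn_bias_type: str | None = None
    eos_id: int | None = 2  # stand-in for EOS_ID

args = ArgsSketch(dim=256, attn_impl="xformers", attn_bias_type="causal")
print(args.attn_impl)  # "xformers"

try:
    ArgsSketch(dimension=256)  # misspelled field name
except ValidationError:
    print("extra='forbid' rejects unknown keys")
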
@@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
         torch.nn.init.ones_(self.weight)  # type: ignore
 
 
+def _reshape_for_attn_bias(
+    attn_bias: AttentionBias | None,
+    *tensors: torch.Tensor,
+) -> list[torch.Tensor]:
+    to_transform = list(tensors)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
+        # could be `view` instead of reshape during training, but for inference
+        # have to reshape due to strides mismatch
+        to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
+    return to_transform
+
+
 class Attention(nn.Module):
     def __init__(
         self,
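
The helper added above only rewrites tensor shapes. As a rough illustration of what that reshape does (assuming the usual xformers convention that block-diagonal biases expect all sequences packed into a single batch row), a (B, S, H, D) tensor becomes (1, B*S, H, D):

import torch

# Hypothetical shapes, just to show the effect of t.reshape(1, -1, *t.shape[2:]).
B, S, H, D = 4, 128, 8, 64
xq = torch.randn(B, S, H, D)
flat = xq.reshape(1, -1, *xq.shape[2:])
print(flat.shape)  # torch.Size([1, 512, 8, 64]), i.e. (1, B*S, H, D)
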
@@ -371,9 +390,12 @@ class Attention(nn.Module):
             output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D
 
-        elif attn_impl == "fmha":
+        elif attn_impl == "xformers":
             assert mask is None or isinstance(mask, AttentionBias)
+            query_shape = xq.shape
+            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
             output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            output = output.view(query_shape)
             # This uses B S H D instead of B H S D of pytorch
 
         elif attn_impl == "sdpa":
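
The comment kept in this hunk ("This uses B S H D instead of B H S D of pytorch") is the key detail when comparing attention backends. Below is a small sketch of the layout difference using only torch's own scaled_dot_product_attention; the xformers call is not reproduced here, and the shapes are arbitrary examples.

import torch
import torch.nn.functional as F

B, S, H, D = 2, 16, 4, 32
xq, xk, xv = (torch.randn(B, S, H, D) for _ in range(3))

# torch SDPA wants (B, H, S, D), so transpose in and back out; xformers'
# memory_efficient_attention consumes (B, S, H, D) directly.
out = F.scaled_dot_product_attention(
    xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2), is_causal=True
).transpose(1, 2)
print(out.shape)  # torch.Size([2, 16, 4, 32])
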
@@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
         mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
         attn_impl: str = "sdpa",
     ) -> torch.Tensor:
-        h = x + self.attention(
+        attn_out = self.attention(
             self.attention_norm(x),
             freq_cis,
             tok_idx=tok_idx,
             mask=mask,
             attn_impl=attn_impl,
         )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        h = x + attn_out
+        h_norm = self.ffn_norm(h)
+        out = h + self.feed_forward(h_norm)
         return out
 
     def init_weights(self, init_std=None, factor=1.0):
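
The TransformerBlock change above is a pure refactor of the pre-norm residual computation into named intermediates; the math is unchanged. A toy version with stand-in modules (nothing here is the repo's code) showing the same structure:

import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.attention = nn.Identity()        # stand-in for Attention
        self.feed_forward = nn.Linear(dim, dim)
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attention(self.attention_norm(x))
        h = x + attn_out                       # residual around attention
        h_norm = self.ffn_norm(h)
        out = h + self.feed_forward(h_norm)    # residual around the FFN
        return out

print(ToyBlock()(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
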
@@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.dim = args.dim
         self.init_base_std = args.init_base_std
+        self.attn_impl = args.attn_impl
+        self.attn_bias_type = args.attn_bias_type
         self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.max_seqlen = args.max_seqlen
         self.rope_embeddings = RotaryEmbedding(
@@ -552,6 +578,7 @@
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
         )
+        self.eos_id = args.eos_id
 
         self.layers = nn.ModuleList()
         for _ in range(args.n_layers):