Changes for training entropy model and correcting attention in local models (#25)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Preparation for being able to have an entropy train script
- Fix failing unit tests

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-01-17 14:23:01 -08:00 committed by GitHub
parent caec8d2621
commit 6ffeb66b53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 349 additions and 138 deletions

View file

@ -1,8 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
logger = logging.getLogger()
def patch_reduce(h, max_num_patches, reduction, patch_ids):
"""
@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def create_causal_mask(seqlen, attn_impl, sliding_window):
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
"""
0 0 0 1 0 0 0 1 0 0 0
0 1 0 0 0 1 0 0 0 0 0
-> 4 4 3 2 4 5
"""
mask = batch == eos_id
mask[:, -1] = True # virtual eos at the end of each row
# 0 0 0 1 0 0 0 1 0 0 X
# 0 1 0 0 0 1 0 0 0 0 X
row, col = torch.where(mask)
# row = 0, 0, 0, 1, 1, 1
# col = 3, 7, 10, 1, 5, 10
seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
# seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
return [int(col[0].item() + 1)] + seqlens.tolist()
def create_causal_mask(
seqlen,
attn_impl: str,
attn_bias_type: str | None,
*,
eos_id: int | None = None,
tokens: torch.Tensor | None = None,
sliding_window: int | None = None,
):
if attn_impl == "xformers":
if attn_bias_type is None:
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "causal":
assert sliding_window is None
return fmha.attn_bias.LowerTriangularMask()
elif attn_bias_type == "block_causal":
assert sliding_window is None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
)
elif attn_bias_type == "local_block_causal":
assert sliding_window is not None
assert eos_id is not None
assert tokens is not None
return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=tokens_to_seqlen(tokens, eos_id)
).make_local_attention(sliding_window)
else:
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "sdpa":
return "causal"
BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
if attn_bias_type == "causal":
return "causal"
if BLT_SUPPRESS_ATTN_ERROR == 1:
logging.warning(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
)
return "causal"
else:
raise ValueError(
"SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
)
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
elif attn_impl == "fmha":