Mirror of https://github.com/facebookresearch/blt.git, synced 2025-09-10 22:34:37 +00:00
Changes for training entropy model and correcting attention in local models (#25)
Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Prepare for adding an entropy model training script
- Fix failing unit tests

Test Plan:
parent caec8d2621
commit 6ffeb66b53

15 changed files with 349 additions and 138 deletions
```diff
@@ -1,8 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+
 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 from xformers.ops import fmha
 
+logger = logging.getLogger()
+
 
 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """
```
```diff
@@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
     return q_idx >= kv_idx
 
 
-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
+def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
+    """
+    0 0 0 1 0 0 0 1 0 0 0
+    0 1 0 0 0 1 0 0 0 0 0
+    -> 4 4 3 2 4 5
+    """
+    mask = batch == eos_id
+    mask[:, -1] = True  # virtual eos at the end of each row
+
+    # 0 0 0 1 0 0 0 1 0 0 X
+    # 0 1 0 0 0 1 0 0 0 0 X
+    row, col = torch.where(mask)
+
+    # row = 0, 0, 0, 1, 1, 1
+    # col = 3, 7, 10, 1, 5, 10
+    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
+    # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
+    return [int(col[0].item() + 1)] + seqlens.tolist()
```
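To sanity-check the new helper, the snippet below (not part of the commit) re-runs `tokens_to_seqlen` on the exact rows from its docstring; the body is copied from the diff above so the check runs standalone with plain PyTorch:

```python
import torch


def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
    # Body copied verbatim from the diff above.
    mask = batch == eos_id
    mask[:, -1] = True  # virtual eos at the end of each row
    row, col = torch.where(mask)
    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
    return [int(col[0].item() + 1)] + seqlens.tolist()


# The two rows from the docstring; eos_id == 1 marks sequence boundaries.
batch = torch.tensor(
    [
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ]
)
print(tokens_to_seqlen(batch, eos_id=1))  # [4, 4, 3, 2, 4, 5]
```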
The hunk continues with the replacement `create_causal_mask`, which now dispatches on both `attn_impl` and the new `attn_bias_type` argument (the rendered diff is cut off at the `fmha` branch):

```diff
+
+
+def create_causal_mask(
+    seqlen,
+    attn_impl: str,
+    attn_bias_type: str | None,
+    *,
+    eos_id: int | None = None,
+    tokens: torch.Tensor | None = None,
+    sliding_window: int | None = None,
+):
+    if attn_impl == "xformers":
+        if attn_bias_type is None:
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "causal":
+            assert sliding_window is None
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "block_causal":
+            assert sliding_window is None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            )
+        elif attn_bias_type == "local_block_causal":
+            assert sliding_window is not None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            ).make_local_attention(sliding_window)
+        else:
+            return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=sliding_window - 1, window_right=0
+            )
     elif attn_impl == "sdpa":
-        return "causal"
+        BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+        if attn_bias_type == "causal":
+            return "causal"
+
+        if BLT_SUPPRESS_ATTN_ERROR == 1:
+            logging.warning(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
+            )
+            return "causal"
+        else:
+            raise ValueError(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+            )
     elif attn_impl == "flex_attention":
         return create_block_mask(causal_mask, None, None, seqlen, seqlen)
     elif attn_impl == "fmha":
```
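As a usage sketch, not from the commit: on the xformers path the new bias types map onto `fmha.attn_bias` constructors. The example below assumes xformers is installed and `tokens_to_seqlen` (as above) is in scope; the token values and `eos_id=1` are made up for illustration:

```python
import torch
from xformers.ops import fmha

# Hypothetical batch: eos_id == 1 appears at positions 3 and 6, and a virtual
# eos is added at the end, so the row splits into documents of length 4, 3, 1.
tokens = torch.tensor([[5, 9, 2, 1, 7, 3, 1, 8]])
seqlens = tokens_to_seqlen(tokens, eos_id=1)  # [4, 3, 1]

# What the "block_causal" branch builds: causal attention per document.
block_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=seqlens)

# The "local_block_causal" branch further restricts each query to a sliding
# window inside its document:
local_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
    q_seqlen=seqlens
).make_local_attention(2)
```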
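The SDPA branch no longer silently falls back to plain causal attention when a block-causal bias is requested. A sketch of the new escape hatch (not from the commit), assuming `create_causal_mask` from the diff is in scope:

```python
import os

# By default, requesting block_causal with SDPA now raises:
#   create_causal_mask(256, "sdpa", "block_causal")
#   -> ValueError: SDPA attention being used, which doesn't have specialized ...

# The env var is read inside the function at call time, so setting it just
# before the call opts back in to the old approximate behavior:
os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
mask = create_causal_mask(256, "sdpa", "block_causal")
# -> logs a warning and returns the string "causal"
```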