# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
logger = logging.getLogger()

def patch_reduce(h, max_num_patches, reduction, patch_ids):
    """
    Reduce variable length patches to a single embedding per patch.
    Note: this works with a variable number of patches for different sequences in the batch.
    It handles variable length patches by assuming that patch_lengths will be 0 for any
    extra patches on the *right*.
    Any embeddings on the right that are not allocated to a patch
    (i.e. if sum(patch_lengths[i]) < seq_len for any i)
    are sent to a dummy patch, which is trimmed before returning.
    """
    bs, seq_len, emb_dim = h.shape

    patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])

    reduced_embs = torch.zeros(
        (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
    )
    reduced_embs = reduced_embs.scatter_reduce(
        src=h,
        dim=1,
        index=patch_ids,
        reduce=reduction,
        include_self=False,
    )
    reduced_embs = reduced_embs[:, :max_num_patches, :]
    return reduced_embs
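

# Minimal usage sketch (values are hypothetical, for illustration only):
# a batch of one sequence with 6 byte embeddings is mean-reduced into 3
# patches of lengths (2, 3, 1), where patch_ids maps each position to its
# patch index.
def _example_patch_reduce():
    h = torch.arange(6, dtype=torch.float32).view(1, 6, 1)
    patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])
    reduced = patch_reduce(h, max_num_patches=3, reduction="mean", patch_ids=patch_ids)
    # patch 0 <- mean(0, 1), patch 1 <- mean(2, 3, 4), patch 2 <- mean(5)
    assert reduced.shape == (1, 3, 1)
    assert torch.allclose(reduced.squeeze(), torch.tensor([0.5, 3.0, 5.0]))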

def concat_downsample(h, patch_lengths, patch_size):
    # The assumption in this function is that seq_len = patch_size * num_patches.
    bs, seq_len, emb_dim = h.shape
    patch_end_ids = torch.cumsum(patch_lengths, dim=1)
    patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
        patch_end_ids.device
    )
    # Is clamp ok here?
    patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
    patch_ids = patch_ids.view(bs, -1, emb_dim)
    # after gather h.shape = [batch_size, seq_len, dim]
    h = torch.gather(h, 1, patch_ids)
    h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
    return h
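

# Minimal usage sketch (shapes are hypothetical): with fixed-length patches
# of patch_size bytes, each patch embedding is the concatenation of the
# embeddings of the bytes it covers, so the feature dim grows by patch_size.
def _example_concat_downsample():
    bs, seq_len, emb_dim, patch_size = 2, 8, 16, 2
    h = torch.randn(bs, seq_len, emb_dim)
    # fixed-length patching, so seq_len == patch_size * num_patches
    patch_lengths = torch.full((bs, seq_len // patch_size), patch_size)
    out = concat_downsample(h, patch_lengths, patch_size)
    assert out.shape == (bs, seq_len // patch_size, patch_size * emb_dim)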

def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
    cat = []
    if "avg" in pooling_mode or "mean" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
    if "min" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
    if "max" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
    assert len(cat) > 0
    h = torch.cat(cat, dim=-1)
    return h
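

# Minimal usage sketch (values are hypothetical): pooling_mode is matched by
# substring, so a mode such as "min_max" concatenates the amin- and
# amax-reduced patch embeddings along the feature dimension.
def _example_pooling_downsample():
    h = torch.randn(1, 6, 8)
    patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])
    out = pooling_downsample(
        h, max_num_patches=3, pooling_mode="min_max", patch_ids=patch_ids
    )
    # two reductions (amin, amax), each of dim 8, concatenated -> dim 16
    assert out.shape == (1, 3, 16)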

def downsample(
    h,
    num_patches,
    patch_lengths=None,
    patch_ids=None,
    downsampling_by_pooling=None,
    patch_size=4,
):
    """
    Downsampling:
        a. concatenating embeddings in the patch
            Note: with dynamic patching, patch the last patch_size tokens.
        b. pooling embeddings in the patch
    """
    # input: h.shape = [batch_size, seq_len, dim]
    # output (pooled): h.shape = [batch_size, seq_len / patch_size, dim]
    # if we don't use cross attention, we pool so that the byte representation is converted to a patch representation
    if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
        # By pooling
        max_num_patches = num_patches
        assert patch_ids is not None
        h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
    else:
        # TODO: remove this condition
        # By concatenating (fixed lengths patching)
        assert patch_lengths is not None
        h = concat_downsample(h, patch_lengths, patch_size)
    return h
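

# Minimal usage sketch (values are hypothetical): downsample() dispatches to
# pooling when downsampling_by_pooling is a non-empty string, and otherwise
# falls back to fixed-length concatenation via concat_downsample().
def _example_downsample():
    h = torch.randn(1, 6, 8)
    patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])
    pooled = downsample(
        h, num_patches=3, patch_ids=patch_ids, downsampling_by_pooling="max"
    )
    assert pooled.shape == (1, 3, 8)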

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
    """
    0 0 0 1 0 0 0 1 0 0 0
    0 1 0 0 0 1 0 0 0 0 0
    -> 4 4 3 2 4 5
    """
    mask = batch == eos_id
    mask[:, -1] = True  # virtual eos at the end of each row
    # 0 0 0 1 0 0 0 1 0 0 X
    # 0 1 0 0 0 1 0 0 0 0 X
    row, col = torch.where(mask)
    # row = 0, 0, 0, 1, 1, 1
    # col = 3, 7, 10, 1, 5, 10
    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
    # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
    return [int(col[0].item() + 1)] + seqlens.tolist()
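

# Minimal usage sketch: reproduces the example from the docstring above,
# with a hypothetical eos_id of 1.
def _example_tokens_to_seqlen():
    batch = torch.tensor(
        [
            [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
            [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        ]
    )
    assert tokens_to_seqlen(batch, eos_id=1) == [4, 4, 3, 2, 4, 5]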

def create_causal_mask(
    seqlen,
    attn_impl: str,
    attn_bias_type: str | None,
    *,
    eos_id: int | None = None,
    tokens: torch.Tensor | None = None,
    sliding_window: int | None = None,
):
    if attn_impl == "xformers":
        if attn_bias_type is None:
            return fmha.attn_bias.LowerTriangularMask()
        elif attn_bias_type == "causal":
            assert sliding_window is None
            return fmha.attn_bias.LowerTriangularMask()
        elif attn_bias_type == "block_causal":
            assert sliding_window is None
            assert eos_id is not None
            assert tokens is not None
            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
                q_seqlen=tokens_to_seqlen(tokens, eos_id)
            )
        elif attn_bias_type == "local_block_causal":
            assert sliding_window is not None
            assert eos_id is not None
            assert tokens is not None
            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
                q_seqlen=tokens_to_seqlen(tokens, eos_id)
            ).make_local_attention(sliding_window)
        else:
            return fmha.attn_bias.LocalAttentionFromBottomRightMask(
                window_left=sliding_window - 1, window_right=0
            )
    elif attn_impl == "sdpa":
        BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))

        if attn_bias_type == "causal":
            return "causal"

        if BLT_SUPPRESS_ATTN_ERROR == 1:
            logger.warning(
                "SDPA attention is being used, which doesn't have specialized attention "
                "implementations for block_causal and local_block_causal attention. "
                "Allowing the model to run since BLT_SUPPRESS_ATTN_ERROR=1."
            )
            return "causal"
        else:
            raise ValueError(
                "SDPA attention is being used, which doesn't have specialized attention "
                "implementations for block_causal and local_block_causal attention. "
                "To suppress this error and run the model anyway, set the environment "
                "variable BLT_SUPPRESS_ATTN_ERROR=1."
            )
    elif attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
    elif attn_impl == "fmha":
        return None
    else:
        raise NotImplementedError(
            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
        )
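

# Minimal usage sketch (values are hypothetical): the return type depends on
# the attention backend. For "sdpa" with a plain causal bias the function
# returns the string "causal" for the caller to interpret; "xformers" returns
# an fmha attention bias object and "flex_attention" returns a BlockMask.
def _example_create_causal_mask():
    assert create_causal_mask(16, "sdpa", "causal") == "causal"
    xformers_bias = create_causal_mask(16, "xformers", "causal")
    assert isinstance(xformers_bias, fmha.attn_bias.LowerTriangularMask)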