Mirror of https://github.com/facebookresearch/blt.git, synced 2025-01-18 16:37:46 +00:00
allow flex-attention to be disabled (#19)
* allow flex-attention to silently fail
* allow flex-attn to be disabled via an env var
parent 1da3dd9315
commit caec8d2621
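The guard added below reads BLT_ALLOW_MISSING_FLEX_ATTENTION with int(os.environ.get(...)), so the variable must hold an integer string. A minimal usage sketch (the import-order note follows from the guard being module-level code; the rest is illustrative and not part of this commit):

import os

# Must be set before the bytelatent modules are imported, because the guard
# runs at import time. "1" (or any non-zero integer string) disables the
# compiled flex-attention kernel; unset or "0" keeps it. A value such as
# "true" would raise ValueError inside int().
os.environ["BLT_ALLOW_MISSING_FLEX_ATTENTION"] = "1"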
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union
 
@@ -16,7 +16,10 @@ from xformers.ops import AttentionBias, fmha
 
 from bytelatent import probe
 
-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None
 
 
 class InitStdFactor(Enum):
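With the variable set, flex_attention_comp is None after import, so attention call sites need a non-flex path. A hypothetical call-site guard follows; the fallback to scaled_dot_product_attention is an assumption for illustration, not something shown in this diff:

import os

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

# Same module-level guard as in the diff: compile flex-attention only when the
# environment variable is unset or "0".
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
    flex_attention_comp = torch.compile(flex_attention)
else:
    flex_attention_comp = None


def attend(q, k, v, block_mask=None):
    # Hypothetical fallback (assumed, not taken from the repo): use the
    # compiled flex-attention kernel when available, otherwise plain SDPA.
    if flex_attention_comp is not None:
        return flex_attention_comp(q, k, v, block_mask=block_mask)
    return F.scaled_dot_product_attention(q, k, v)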
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
 import atexit
 import contextlib
 import logging
 
@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
 
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )
+
 
 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
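The second file's change conditionally adds the xformers cutlass op to default_no_recompute_ops, the set of ops whose outputs are saved rather than recomputed under selective activation checkpointing. A minimal sketch of how such a set is typically wired into PyTorch's selective checkpointing, as an illustration only (assumes a PyTorch version providing torch.utils.checkpoint.create_selective_checkpoint_contexts, roughly 2.4+; this is not code from the repository):

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Ops whose outputs are saved instead of recomputed; this plays the same role
# as default_no_recompute_ops above (the contents here are illustrative).
no_recompute_ops = [
    torch.ops.aten.mm.default,
]


def block(x, w):
    # Stand-in for a transformer block.
    return torch.relu(x @ w) @ w


x = torch.randn(8, 64, requires_grad=True)
w = torch.randn(64, 64, requires_grad=True)

# Checkpoint the block, but keep the listed ops' activations for backward.
context_fn = partial(create_selective_checkpoint_contexts, no_recompute_ops)
out = checkpoint(block, x, w, use_reentrant=False, context_fn=context_fn)
out.sum().backward()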