allow flex-attention to be disabled (#19)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

* allow flex-attention to silently fail

* allow flex-attn to be disabled via an env var
This commit is contained in:
Ink 2025-01-14 11:32:07 -06:00 committed by GitHub
parent 1da3dd9315
commit caec8d2621
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View file

@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -16,7 +16,10 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None
 class InitStdFactor(Enum):

View file

@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import atexit
 import contextlib
 import logging
@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )
 class DistributedArgs(BaseModel):
 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")