diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index f494a15..45cb7c5 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union
 
@@ -16,7 +16,10 @@ from xformers.ops import AttentionBias, fmha
 
 from bytelatent import probe
 
-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None
 
 
 class InitStdFactor(Enum):
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index fadce45..b211858 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
 import atexit
 import contextlib
 import logging
@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
 
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )
+
 
 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
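
Usage note (not part of the patch): the gate runs at import time, so BLT_ALLOW_MISSING_FLEX_ATTENTION has to be set before bytelatent.base_transformer is imported, after which flex_attention_comp may be None and callers are expected to handle that. Since the patch parses the variable with int(os.environ.get(..., False)), an unset variable or "0" keeps the compiled kernel, any other integer string disables it, and a non-numeric value would raise ValueError. Below is a minimal caller-side sketch under these assumptions; the attend() helper and the SDPA fallback are illustrative, not code from the repository.

# Hypothetical caller-side sketch, not part of the patch.
# The env var must be set before bytelatent is imported, because the
# gate in base_transformer.py runs at module import time.
import os

os.environ.setdefault("BLT_ALLOW_MISSING_FLEX_ATTENTION", "1")

import torch.nn.functional as F

from bytelatent.base_transformer import flex_attention_comp


def attend(q, k, v, block_mask=None):
    # Use the compiled flex_attention kernel when it was built; otherwise
    # fall back to plain SDPA (assumed acceptable for this sketch).
    if flex_attention_comp is not None:
        return flex_attention_comp(q, k, v, block_mask=block_mask)
    return F.scaled_dot_product_attention(q, k, v)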