From caec8d26210b23158169a28828cb763c9591659c Mon Sep 17 00:00:00 2001
From: Ink
Date: Tue, 14 Jan 2025 11:32:07 -0600
Subject: [PATCH] allow flex-attention to be disabled (#19)

* allow flex-attention to silently fail

* allow flex-attn to be disabled via an env var
---
 bytelatent/base_transformer.py | 7 +++++--
 bytelatent/distributed.py      | 7 +++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index f494a15..45cb7c5 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
+import os
 from enum import Enum
 from typing import Optional, Tuple, Union
 
@@ -16,7 +16,10 @@ from xformers.ops import AttentionBias, fmha
 
 from bytelatent import probe
 
-flex_attention_comp = torch.compile(flex_attention)
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    flex_attention_comp = torch.compile(flex_attention)
+else:
+    flex_attention_comp = None
 
 
 class InitStdFactor(Enum):
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index fadce45..b211858 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-
 import atexit
 import contextlib
 import logging
@@ -48,9 +47,13 @@ default_no_recompute_ops = {
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops.c10d_functional.reduce_scatter_tensor.default,
     torch.ops.xformers_flash.flash_fwd.default,
-    torch.ops.xformers.efficient_attention_forward_cutlass.default,
 }
 
+if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
+    default_no_recompute_ops.add(
+        torch.ops.xformers.efficient_attention_forward_cutlass.default
+    )
+
 
 class DistributedArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
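
Usage note: the patch gates flex attention behind the BLT_ALLOW_MISSING_FLEX_ATTENTION environment variable. When the variable is unset (or "0"), flex attention is compiled exactly as before; when it is set to a nonzero integer, flex_attention_comp is left as None. Below is a minimal sketch of the intended pattern, assuming call sites check for None and fall back to another attention path; the attend() helper and the scaled_dot_product_attention fallback are illustrative only and are not part of the patch.

# Sketch only: mirrors the env-var gating added in base_transformer.py.
# Run with the variable set to opt out of flex attention, e.g.
#   BLT_ALLOW_MISSING_FLEX_ATTENTION=1 python <training entry point>
import os

import torch
from torch.nn.attention.flex_attention import flex_attention

# Unset or "0" -> compile flex attention (the default); any other integer -> skip.
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
    flex_attention_comp = torch.compile(flex_attention)
else:
    flex_attention_comp = None


def attend(q, k, v, block_mask=None):
    # Hypothetical call site: fall back to SDPA when flex attention is disabled.
    if flex_attention_comp is not None:
        return flex_attention_comp(q, k, v, block_mask=block_mask)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)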