Make apex logs less noisy (#60)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-02-18 10:45:56 -08:00 committed by GitHub
parent 82ab5930ec
commit b0956bde99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 14 additions and 12 deletions

View file

@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from enum import Enum
from typing import Optional, Tuple, Union
@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
)
from xformers.ops import AttentionBias, fmha
from bytelatent import probe
from bytelatent.tokenizers.constants import EOS_ID
logger = logging.getLogger()
try:
from apex.normalization.fused_layer_norm import FusedRMSNorm
RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
print("Apex not found. Using nn.RMSNorm")
logging.debug("Apex not found. Using nn.RMSNorm")
RMSNorm = nn.RMSNorm
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: