mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 14:27:49 +00:00
parent
82ab5930ec
commit
b0956bde99
4 changed files with 14 additions and 12 deletions
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Union
|
||||
|
@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
|
|||
)
|
||||
from xformers.ops import AttentionBias, fmha
|
||||
|
||||
from bytelatent import probe
|
||||
from bytelatent.tokenizers.constants import EOS_ID
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
try:
|
||||
from apex.normalization.fused_layer_norm import FusedRMSNorm
|
||||
|
||||
RMSNorm = FusedRMSNorm
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
print("Apex not found. Using nn.RMSNorm")
|
||||
logging.debug("Apex not found. Using nn.RMSNorm")
|
||||
RMSNorm = nn.RMSNorm
|
||||
|
||||
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue