mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 05:22:16 +00:00
black
This commit is contained in:
parent
da2cf02179
commit
8be8c85a63
|
@ -22,7 +22,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
print('Apex not found. Using nn.RMSNorm')
|
print("Apex not found. Using nn.RMSNorm")
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
||||||
|
|
|
@ -22,7 +22,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
print('Apex not found. Using nn.RMSNorm')
|
print("Apex not found. Using nn.RMSNorm")
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
|
@ -26,7 +26,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
print('Apex not found. Using nn.RMSNorm')
|
print("Apex not found. Using nn.RMSNorm")
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
Loading…
Reference in a new issue