mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-23 13:32:14 +00:00
added message for missing apex
This commit is contained in:
parent
d3bf3a1383
commit
da2cf02179
|
@ -22,6 +22,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
print('Apex not found. Using nn.RMSNorm')
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
|
||||||
|
|
|
@ -22,6 +22,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
print('Apex not found. Using nn.RMSNorm')
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
|
@ -26,6 +26,7 @@ try:
|
||||||
|
|
||||||
RMSNorm = FusedRMSNorm
|
RMSNorm = FusedRMSNorm
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
print('Apex not found. Using nn.RMSNorm')
|
||||||
RMSNorm = nn.RMSNorm
|
RMSNorm = nn.RMSNorm
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
Loading…
Reference in a new issue