diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 1888658..eed4912 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -22,6 +22,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index b1f72a7..1e615b6 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -22,6 +22,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm logger = logging.getLogger() diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 1d8db07..1aba0e8 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -26,6 +26,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm logger = logging.getLogger()