From da2cf0217979f42c1e949d632976a7a7d9f28e63 Mon Sep 17 00:00:00 2001 From: Srini Iyer Date: Fri, 14 Feb 2025 19:15:22 +0000 Subject: [PATCH] added message for missing apex --- bytelatent/base_transformer.py | 1 + bytelatent/model/latent_transformer.py | 1 + bytelatent/model/local_models.py | 1 + 3 files changed, 3 insertions(+) diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index 1888658..eed4912 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -22,6 +22,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index b1f72a7..1e615b6 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -22,6 +22,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm logger = logging.getLogger() diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 1d8db07..1aba0e8 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -26,6 +26,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print('Apex not found. Using nn.RMSNorm') RMSNorm = nn.RMSNorm logger = logging.getLogger()