diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index cc9e1f3..2e45ea5 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -28,6 +28,7 @@ try: RMSNorm = FusedRMSNorm except (ImportError, ModuleNotFoundError): + print("Apex not found. Using nn.RMSNorm") RMSNorm = nn.RMSNorm