diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 7b76b9e..d44676d 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha
 
-from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index 95b6d8b..a6cabdc 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class CrossAttention(nn.Module):
     """
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index 353c878..09a5a19 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 
-logger = logging.getLogger()
-
 
 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py
index 2e45ea5..da03761 100644
--- a/bytelatent/transformer.py
+++ b/bytelatent/transformer.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 
-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union
 
 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias
 
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
 
+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
 
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
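For reference, every file touched by this diff converges on the same structure: a module-level logger, an optional Apex import, and a debug-level message instead of an unconditional `print` when Apex is unavailable. The following is a minimal standalone sketch of that pattern (not part of the patch itself; names follow the hunks above):

```python
import logging

import torch.nn as nn

logger = logging.getLogger()

try:
    # Prefer Apex's fused RMSNorm kernel when it is installed.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Fall back to PyTorch's built-in implementation; logging at DEBUG level
    # keeps the message out of normal output, unlike the old print().
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm
```

Note that the fallback branch calls `logging.debug` on the module rather than the newly added `logger`; since `logging.getLogger()` with no name returns the root logger, both routes end up at the same logger, so the visible behavior should be identical.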