Make apex logs less noisy

Summary:
Replace the print() fallback message emitted when Apex's FusedRMSNorm is not installed with a logging.debug() call, and define a module-level logger ahead of the guarded import in the affected files.

Test Plan:
Pedro Rodriguez 2025-02-14 23:45:11 +00:00
parent f94babc94e
commit a3e0647d03
4 changed files with 14 additions and 12 deletions
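
The same guarded-import pattern is applied in each of the four files below. A minimal consolidated sketch of the resulting code (illustrative, not copied verbatim from any single file; assumes a PyTorch build that ships torch.nn.RMSNorm):

import logging

import torch.nn as nn

logger = logging.getLogger()

try:
    # Prefer Apex's fused kernel when the optional dependency is installed.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Debug level keeps import-time output quiet; the old print() ran unconditionally.
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm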

File 1 of 4:

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:

File 2 of 4:

@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
-logger = logging.getLogger()
 class CrossAttention(nn.Module):
     """

File 3 of 4:

@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
-logger = logging.getLogger()
 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")

File 4 of 4:

@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union
 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask
+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm
     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
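
Because the fallback now logs at DEBUG on the root logger, it stays silent under Python's default WARNING threshold. A hedged usage note for callers who still want to see the message (this configuration is not part of the commit):

import logging

# Opt back in to the "Apex not found" message by raising root-logger verbosity,
# e.g. in a debugging session or a training entry point.
logging.basicConfig(level=logging.DEBUG)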