Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-22 21:12:15 +00:00)
Make apex logs less noisy
Summary:

Test Plan:
parent 82ab5930ec
commit 2f247263b9
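The hunks below replace the import-time print() notice in the Apex fallback with logging.debug and drop a couple of apparently unused imports along the way. Judging from the hunk context, the touched modules appear to be bytelatent/base_transformer.py, bytelatent/model/latent_transformer.py, bytelatent/model/local_models.py, and bytelatent/transformer.py, though the file headers were not preserved here. As a minimal sketch of the pattern each module converges on (the code mirrors the diff; the comments are added for illustration):

import logging

import torch.nn as nn

logger = logging.getLogger()

try:
    # Prefer Apex's fused RMSNorm kernel when the optional dependency is installed.
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Previously announced with print() on every import; now the notice is
    # only emitted when debug logging is enabled, so default logs stay quiet.
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm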
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID

+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()
-

 class CrossAttention(nn.Module):
     """
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()
-

 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.

-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union

 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

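Since the notice now goes to the root logger at DEBUG level, it stays hidden under the default WARNING threshold. A hypothetical way to surface it when diagnosing an environment (assuming the package is importable and logging has not been configured elsewhere) is to raise the level before the first import:

import logging

# Enable DEBUG output before importing, so the one-time
# "Apex not found. Using nn.RMSNorm" notice is actually emitted.
logging.basicConfig(level=logging.DEBUG)

import bytelatent.base_transformer  # noqa: E402 (import after logging setup)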