Merge 2f247263b9 into sapling-pr-archive-EntilZha

Pedro Rodriguez 2025-02-18 10:43:12 -08:00 committed by GitHub
commit 4b57d05c3b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 14 additions and 12 deletions

View file

@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha

+from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID

+logger = logging.getLogger()

 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
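All four touched files converge on the same few lines. For reference, a minimal, self-contained sketch of that shared pattern (Apex is optional; nn.RMSNorm assumes a recent PyTorch release):

# Sketch of the shared fallback: prefer Apex's fused RMSNorm when it is
# installed, otherwise fall back to torch.nn.RMSNorm, and report the fallback
# through the logging module instead of print().
import logging

from torch import nn

logger = logging.getLogger()

try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm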

View file

@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()

 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()

 class CrossAttention(nn.Module):
     """

View file

@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

+logger = logging.getLogger()

 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()

 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
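The first hunk in this file only removes BaseModel from the pydantic import; ConfigDict(extra="forbid") keeps rejecting unknown fields on the args classes. A stand-alone illustration of that pydantic v2 behaviour (ExampleArgs is a made-up stand-in, not a class from the repo):

from pydantic import BaseModel, ConfigDict, ValidationError


class ExampleArgs(BaseModel):
    # Same setting as LocalModelArgs: unknown keys fail validation instead of
    # being silently ignored.
    model_config = ConfigDict(extra="forbid")

    dim: int = 512


ExampleArgs(dim=256)  # ok
try:
    ExampleArgs(dim=256, typo_field=1)  # raises: extra fields are forbidden
except ValidationError as err:
    print(err)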

View file

@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union

 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()

 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm
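Because the notice is now logging.debug rather than print, it stays silent under the root logger's default WARNING level. A minimal sketch of how a caller could surface it (the exact module imported here is only illustrative; any of the touched modules would do):

import logging

# Configure the root logger before importing the model code so the
# "Apex not found. Using nn.RMSNorm" debug message is actually emitted.
logging.basicConfig(level=logging.DEBUG)

import bytelatent.base_transformer  # noqa: E402  (import after logging setup)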