Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00

Commit 4b57d05c3b: Merge 2f247263b9 into sapling-pr-archive-EntilZha
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
 import os
 from enum import Enum
 from typing import Optional, Tuple, Union
@@ -14,15 +15,16 @@ from torch.nn.attention.flex_attention import (
 )
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID

+logger = logging.getLogger()

 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
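Every file touched by this commit applies the same two edits: a module-level logger obtained via logging.getLogger(), and logging.debug in place of print when Apex's fused RMSNorm is unavailable. Reassembled as a standalone sketch (assuming only the apex import path the diff itself shows, and a torch version that provides nn.RMSNorm):

# Optional-dependency fallback as it reads after this change: prefer Apex's
# FusedRMSNorm when it is installed, otherwise fall back to torch's nn.RMSNorm,
# reporting the fallback through logging rather than an unconditional print.
import logging

import torch.nn as nn

logger = logging.getLogger()

try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    # Quiet by default; shown only when the application enables DEBUG logging.
    logging.debug("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm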
@@ -17,16 +17,15 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()
-

 class CrossAttention(nn.Module):
     """
@@ -6,7 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 import torch
 import torch.nn
 import torch.nn as nn
-from pydantic import BaseModel, ConfigDict
+from pydantic import ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias
@@ -21,16 +21,15 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

+logger = logging.getLogger()
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

-logger = logging.getLogger()
-

 class LocalModelArgs(BaseTransformerArgs):
     model_config = ConfigDict(extra="forbid")
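The hunk above keeps LocalModelArgs configured with ConfigDict(extra="forbid") while the direct BaseModel import is dropped from this module. As a minimal, hypothetical illustration of what extra="forbid" does on a Pydantic v2 model (ExampleArgs and its fields are invented for this sketch, not taken from the repository):

from pydantic import BaseModel, ConfigDict, ValidationError


class ExampleArgs(BaseModel):
    # Reject unknown keys instead of silently ignoring them, so a typo in a
    # config (e.g. "dimension" instead of "dim") fails at construction time.
    model_config = ConfigDict(extra="forbid")

    dim: int = 512
    n_layers: int = 8


ExampleArgs(dim=256)  # accepted
try:
    ExampleArgs(dimension=256)  # unknown field -> ValidationError
except ValidationError as err:
    print(err)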
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.

-from dataclasses import dataclass
+import logging
 from typing import Optional, Tuple, Union

 import torch
@@ -14,7 +14,7 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
 )
 from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-from xformers.ops import AttentionBias, fmha
+from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
     BaseTransformer,
@@ -23,12 +23,14 @@ from bytelatent.base_transformer import (
 )
 from bytelatent.model.utils import create_causal_mask

+logger = logging.getLogger()
+
 try:
     from apex.normalization.fused_layer_norm import FusedRMSNorm

     RMSNorm = FusedRMSNorm
 except (ImportError, ModuleNotFoundError):
-    print("Apex not found. Using nn.RMSNorm")
+    logging.debug("Apex not found. Using nn.RMSNorm")
     RMSNorm = nn.RMSNorm

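One practical consequence of swapping print for logging.debug across these files is that the Apex fallback notice is silent under Python's default WARNING log level. A short usage sketch of how a caller would surface it again (logging.basicConfig is standard library, not part of this diff):

import logging

# The root logger defaults to the WARNING level, so logging.debug(...) calls
# such as the Apex fallback notice are dropped unless debug logging is enabled.
logging.basicConfig(level=logging.DEBUG)

logging.debug("Apex not found. Using nn.RMSNorm")  # now emitted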