Mirror of https://github.com/facebookresearch/blt.git
Commit ed6300375f: Merge bec0164820 into sapling-pr-archive-EntilZha
````diff
@@ -63,11 +63,11 @@ Now launch a debug job to check if everything works. **The provided configuratio
 ```bash
 # stool stands for SLURM tool !
-python -m bytelatent.stool script=bytelatent.train config=apps/bytelatent/configs/debug.yaml nodes=1 partition=<partition>
+python -m bytelatent.stool script=bytelatent.train config=bytelatent/configs/debug.yaml nodes=1 partition=<partition>
 # if you want to launch locally you can use torchrun
-torchrun --nproc-per-node 8 -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+torchrun --nproc-per-node 8 -m bytelatent.train config=bytelatent/configs/debug.yaml
 # or you can also launch on 1 GPU
-python -m bytelatent.train config=apps/bytelatent/configs/debug.yaml
+python -m bytelatent.train config=bytelatent/configs/debug.yaml
 ```
 
 When using `stool`, if a job crashes, it can be relaunched using sbatch:
 
````
```diff
@@ -17,6 +17,14 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent import probe
 from bytelatent.tokenizers.constants import EOS_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
 else:
```
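This try/except block is the heart of the change and is repeated in every file the diff touches: use Apex's fused RMSNorm kernel when the optional `apex` package is installed, and fall back to PyTorch's built-in implementation otherwise. A self-contained sketch of the pattern (assuming PyTorch >= 2.4, where `nn.RMSNorm` is available; the usage lines at the bottom are illustrative, not taken from the diff):

```python
# Sketch of the optional-dependency fallback used above: prefer Apex's fused
# RMSNorm kernel, otherwise use torch.nn.RMSNorm (available in PyTorch >= 2.4).
import torch
from torch import nn

try:
    from apex.normalization.fused_layer_norm import FusedRMSNorm

    RMSNorm = FusedRMSNorm
except (ImportError, ModuleNotFoundError):
    print("Apex not found. Using nn.RMSNorm")
    RMSNorm = nn.RMSNorm

# Both classes accept (normalized_shape, eps=...), so downstream code such as
# RMSNorm(dim, eps=norm_eps) works unchanged whichever branch was taken.
norm = RMSNorm(512, eps=1e-6)
x = torch.randn(2, 16, 512)
print(norm(x).shape)  # torch.Size([2, 16, 512])
```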
```diff
@@ -294,37 +302,6 @@ class RotaryEmbedding(torch.nn.Module):
         return self.freqs_cis[0:seqlen]
 
 
-class RMSNorm(nn.Module):
-    """
-    Initialize the RMSNorm normalization layer.
-
-    Args:
-        dim (int): The dimension of the input tensor.
-        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-
-    Attributes:
-        eps (float): A small value added to the denominator for numerical stability.
-        weight (nn.Parameter): Learnable scaling parameter.
-
-    """
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x: torch.Tensor):
-        x = probe.log_stats(x, "resid")
-        output = self._norm(x.float())
-        return (output * self.weight.float()).type_as(x)
-
-    def reset_parameters(self):
-        torch.nn.init.ones_(self.weight)  # type: ignore
-
-
 def _reshape_for_attn_bias(
     attn_bias: AttentionBias | None,
     *tensors: torch.Tensor,
```
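The deleted class is the standard RMSNorm: scale `x` by the reciprocal root-mean-square over its last dimension, then by a learned weight, with the computation done in fp32 and a `probe.log_stats` hook on the input. Apart from the probe hook, that is what `nn.RMSNorm` already computes; a quick sanity-check sketch (illustrative, not part of the diff):

```python
# Illustrative check that the removed hand-written RMSNorm and torch's
# nn.RMSNorm compute the same thing (the probe.log_stats call aside).
import torch
from torch import nn

def manual_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as the deleted class: x * rsqrt(mean(x^2) + eps), scaled by weight.
    out = x.float() * torch.rsqrt((x.float() * x.float()).mean(-1, keepdim=True) + eps)
    return (out * weight.float()).type_as(x)

dim = 8
x = torch.randn(4, dim)
ref = nn.RMSNorm(dim, eps=1e-6)  # weight initialised to ones, like the old reset_parameters
print(torch.allclose(ref(x), manual_rms_norm(x, ref.weight), atol=1e-6))  # True
```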
```diff
@@ -12,12 +12,19 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     flex_attention_comp,
     repeat_kv,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
```
```diff
@@ -44,7 +51,7 @@ class CrossAttention(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.heads_per_group = self.n_heads // self.n_kv_heads
 
-        self.cross_attn_norm_q = RMSNorm(dim, eps=norm_eps)
+        self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps)
         self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps)
 
         self.wq = nn.Linear(
```
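The `heads_per_group` bookkeeping in this constructor is the usual grouped-query-attention layout: each key/value head is shared by `n_heads // n_kv_heads` query heads, which is why K/V get repeated across the head dimension before the attention matmul (`repeat_kv` appears in the import hunk just above). A shapes-only sketch with made-up sizes, not the module's actual forward pass:

```python
# Shapes-only illustration of heads_per_group in grouped-query attention:
# n_kv_heads key/value heads are shared across n_heads query heads, so K/V are
# repeated heads_per_group times to line up with the query heads.
import torch

n_heads, n_kv_heads, head_dim, seq_len = 8, 2, 64, 16
heads_per_group = n_heads // n_kv_heads  # 4 query heads share each kv head

q = torch.randn(1, n_heads, seq_len, head_dim)
kv = torch.randn(1, n_kv_heads, seq_len, head_dim)
kv_expanded = kv.repeat_interleave(heads_per_group, dim=1)

print(kv_expanded.shape)  # torch.Size([1, 8, 16, 64]) -- now matches q's head count
```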
```diff
@@ -14,7 +14,6 @@ from xformers.ops import AttentionBias
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
-    RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
@@ -22,6 +21,14 @@ from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 logger = logging.getLogger()
 
 
```
```diff
@@ -19,11 +19,18 @@ from xformers.ops import AttentionBias, fmha
 from bytelatent.base_transformer import (
     BaseTransformer,
     BaseTransformerArgs,
-    RMSNorm,
     cross_entropy,
 )
 from bytelatent.model.utils import create_causal_mask
 
+try:
+    from apex.normalization.fused_layer_norm import FusedRMSNorm
+
+    RMSNorm = FusedRMSNorm
+except (ImportError, ModuleNotFoundError):
+    print("Apex not found. Using nn.RMSNorm")
+    RMSNorm = nn.RMSNorm
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
```
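The excerpt cuts off right after the `attention_flops_per_token` signature and the comment pointing at the flash-attention benchmark. For orientation only, here is a sketch of the per-token estimate that benchmark uses (an assumption about what the truncated body computes, not the actual code):

```python
# Sketch of the attention FLOP estimate from the cited flash-attention benchmark:
# roughly 4 * seq_len * dim multiply-accumulate FLOPs per token per layer for the
# QK^T and PV matmuls, halved under a causal mask, times 3.5 for forward + backward.
# This illustrates the referenced formula, not the truncated function body.
def attention_flops_per_token_sketch(n_layers: int, seq_len: int, dim: int, causal: bool) -> float:
    per_layer_fwd = 4 * seq_len * dim // (2 if causal else 1)
    return 3.5 * n_layers * per_layer_fwd  # 1x forward + 2.5x backward

print(attention_flops_per_token_sketch(n_layers=12, seq_len=4096, dim=768, causal=True))
```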