mirror of https://github.com/facebookresearch/blt.git
synced 2025-09-02 10:39:10 +00:00
parent 96d51b59d2
commit 1b67cbe022
1 changed file with 7 additions and 3 deletions
@@ -4,7 +4,7 @@ from enum import Enum, auto
 from typing import Any, Optional
 
 import torch
-from pydantic import ConfigDict, model_validator
+from pydantic import model_validator
 from torch import nn
 from torch.nn.attention.flex_attention import create_block_mask
 from typing_extensions import Self
@@ -13,7 +13,6 @@ from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
     SequenceModelWithOutput,
-    TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
 from bytelatent.model.latent_transformer import GlobalTransformer
@@ -21,6 +20,8 @@ from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModel
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
 def attention_flops_per_token(n_layers, seq_len, dim, causal):
     # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
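For context on the unchanged lines above: the linked flash-attention benchmark counts attention as two matmuls per layer (Q·Kᵀ and P·V) at 2 FLOPs per multiply-accumulate, i.e. 4 · seq_len · dim per token per layer, halved under a causal mask. An illustrative sketch of that counting follows; it is not the body of this repo's function, which the diff does not show.

def attention_fwd_flops_per_token(n_layers, seq_len, dim, causal):
    # Two matmuls per layer (Q @ K^T and P @ V), 2 FLOPs per multiply-accumulate:
    # 4 * seq_len * dim FLOPs per token per layer; a causal mask halves the work.
    return 4 * n_layers * seq_len * dim // (2 if causal else 1)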
@@ -767,7 +768,10 @@ def compute_hash_embeddings(
     return local_encoder_embeds
 
 
-class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput, PyTorchModelHubMixin,
+    repo_url="https://github.com/facebookresearch/blt",
+    pipeline_tag="text-generation",
+    license="other"):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
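For context (not part of the diff): subclassing PyTorchModelHubMixin gives the model save_pretrained, from_pretrained and push_to_hub helpers, and the repo_url, pipeline_tag and license keyword arguments become metadata on the auto-generated model card. A minimal usage sketch, assuming a hypothetical Hub repo id and module path:

# Placeholders: the module path and Hub repo id are illustrative, not taken from this diff.
from bytelatent.model.blt import ByteLatentTransformer

model = ByteLatentTransformer.from_pretrained("facebook/blt")  # downloads config + weights
model.save_pretrained("./blt-checkpoint")                      # writes them locally
# model.push_to_hub("my-username/blt-copy")                    # optional re-upload

The mixin serializes JSON-compatible constructor arguments to config.json alongside the weights, which is what lets from_pretrained rebuild the model without a separate loading script; arguments that are not JSON-serializable would need custom handling.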