Mirror of https://github.com/facebookresearch/blt.git

Merge 45456fa6d8 into sapling-pr-archive-EntilZha
This commit is contained in: commit 06a17a0ddc
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum
@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()


-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))

+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         h,
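The new SequenceModelWithOutput ABC gives downstream code a single place to ask any model for its output sequence length. Below is a minimal sketch of how a consumer might rely on the interface; the ToyModel class and make_target_buffer helper are hypothetical and not part of this diff.

import abc

import torch
from torch import nn


class SequenceModelWithOutput(abc.ABC):
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class ToyModel(nn.Module, SequenceModelWithOutput):
    # Mixing nn.Module with abc.ABC works because ABCMeta is a subclass of type,
    # so Python selects ABCMeta as the metaclass without a conflict.
    def __init__(self, max_seqlen: int = 128):
        super().__init__()
        self.max_seqlen = max_seqlen

    def get_output_seq_len(self) -> int:
        return self.max_seqlen


def make_target_buffer(model: SequenceModelWithOutput, batch_size: int) -> torch.Tensor:
    # Size a buffer from the interface alone, without knowing the concrete model class.
    return torch.zeros(batch_size, model.get_output_seq_len(), dtype=torch.long)


print(make_target_buffer(ToyModel(), batch_size=4).shape)  # torch.Size([4, 128])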
@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds


-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,
@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
             )
         )

+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,
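With both BaseTransformer and ByteLatentTransformer now implementing SequenceModelWithOutput, callers can be written against the shared interface instead of a concrete class. A hedged sketch of such a check (the validate_model helper below is illustrative only, not part of this diff):

from bytelatent.base_transformer import SequenceModelWithOutput


def validate_model(model: object) -> int:
    # Generic check usable for BaseTransformer, ByteLatentTransformer, or any
    # future model that implements the interface.
    assert isinstance(model, SequenceModelWithOutput), type(model)
    seq_len = model.get_output_seq_len()
    assert seq_len > 0, "models are expected to report a positive max sequence length"
    return seq_len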
@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass
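Adding get_vocab_size to the Tokenizer ABC lets model-construction code derive the vocabulary size from whichever tokenizer is configured instead of hard-coding it. A minimal sketch of such a call site; the build_output_head helper is hypothetical and the import path is assumed rather than taken from this diff.

from torch import nn

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer  # module path assumed


def build_output_head(tokenizer: Tokenizer, dim: int) -> nn.Linear:
    # The output projection is sized from the tokenizer, so swapping between
    # BltTokenizer, ByteTokenizer, SentencePieceTokenizer, or TikTokenTokenizer
    # requires no change here.
    return nn.Linear(dim, tokenizer.get_vocab_size(), bias=False)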
@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):
@@ -8,6 +8,9 @@ class ByteTokenizer(Tokenizer):
         self.eos_id = 257
         self.n_words = 258

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
         tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
         return tokens
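For reference, ByteTokenizer.encode relies on Python's bool-to-int coercion: [self.bos_id] * add_bos contributes one BOS token when add_bos is True and nothing when it is False. A small standalone rendition and worked example follow; the BOS id of 256 is an assumption, since this hunk only shows eos_id = 257 and n_words = 258.

# Hypothetical standalone rendition of ByteTokenizer.encode, for illustration only.
def encode(s: str, add_bos: bool = False, add_eos: bool = False,
           bos_id: int = 256, eos_id: int = 257) -> list[int]:
    # A bool multiplies a list as 0 or 1, so each special token appears at most once.
    return [bos_id] * add_bos + list(s.encode()) + [eos_id] * add_eos


print(encode("hi"))                              # [104, 105]
print(encode("hi", add_bos=True, add_eos=True))  # [256, 104, 105, 257]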
@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos
@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)