Mirror of https://github.com/facebookresearch/blt.git
Merge 45456fa6d8 into sapling-pr-archive-EntilZha
Commit 06a17a0ddc
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import abc
 import logging
 import os
 from enum import Enum

@@ -572,7 +573,13 @@ class TransformerBlock(nn.Module):
         self.ffn_norm.reset_parameters()


-class BaseTransformer(nn.Module):
+class SequenceModelWithOutput(abc.ABC):
+    @abc.abstractmethod
+    def get_output_seq_len(self) -> int:
+        pass
+
+
+class BaseTransformer(nn.Module, SequenceModelWithOutput):
     def __init__(self, args: BaseTransformerArgs):
         super().__init__()
         self.dim = args.dim
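The hunk above introduces SequenceModelWithOutput as a plain abc.ABC and makes BaseTransformer inherit from both nn.Module and the new interface. The sketch below reproduces that pattern in isolation to show why the mix is safe (ABCMeta is a subclass of type, so no metaclass conflict arises) and how callers can test for the interface. TinyModel and its fields are illustrative stand-ins, not code from the repository.

import abc

import torch
from torch import nn


class SequenceModelWithOutput(abc.ABC):
    # Interface for models that can report the sequence length they produce.
    @abc.abstractmethod
    def get_output_seq_len(self) -> int:
        pass


class TinyModel(nn.Module, SequenceModelWithOutput):
    # Mixing nn.Module with an abc.ABC works because ABCMeta derives from type,
    # so Python resolves ABCMeta as TinyModel's metaclass without conflict.
    def __init__(self, max_seqlen: int = 128, dim: int = 16):
        super().__init__()
        self.max_seqlen = max_seqlen
        self.proj = nn.Linear(dim, dim)

    def get_output_seq_len(self) -> int:
        return self.max_seqlen

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.proj(h)


model = TinyModel()
assert isinstance(model, SequenceModelWithOutput)
print(model.get_output_seq_len())  # 128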

@@ -593,6 +600,9 @@ class BaseTransformer(nn.Module):
         for _ in range(args.n_layers):
             self.layers.append(TransformerBlock(args))

+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         h,
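With get_output_seq_len available on BaseTransformer (and, further down, on ByteLatentTransformer), callers can program against the interface instead of a concrete model class. A minimal sketch of such a caller, assuming only the import that the next file's hunk adds; make_padded_batch is hypothetical and not part of the repository:

# Hypothetical caller: it depends only on the SequenceModelWithOutput
# interface, so the same function works for BaseTransformer,
# ByteLatentTransformer, or any other model implementing get_output_seq_len().
import torch

from bytelatent.base_transformer import SequenceModelWithOutput


def make_padded_batch(
    model: SequenceModelWithOutput, sequences: list[list[int]]
) -> torch.Tensor:
    # Pad (or truncate) every sequence to the model's reported sequence length.
    seq_len = model.get_output_seq_len()
    batch = torch.zeros(len(sequences), seq_len, dtype=torch.long)
    for i, seq in enumerate(sequences):
        seq = seq[:seq_len]
        batch[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return batch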

@@ -12,6 +12,7 @@ from typing_extensions import Self
 from bytelatent.base_transformer import (
     BaseTransformerArgs,
     InitStdFactor,
+    SequenceModelWithOutput,
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs

@@ -766,7 +767,7 @@ def compute_hash_embeddings(
     return local_encoder_embeds


-class ByteLatentTransformer(nn.Module):
+class ByteLatentTransformer(nn.Module, SequenceModelWithOutput):
     """
     The ByteLatentTransformer (BLT) is a byte-level language model architecture that processes byte sequences
     by dynamically segmenting them into patches. It uses a combination of local encoders, global transformers,

@@ -856,6 +857,9 @@ class ByteLatentTransformer(nn.Module):
             )
         )

+    def get_output_seq_len(self):
+        return self.max_seqlen
+
     def forward(
         self,
         tokens: torch.Tensor,
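ByteLatentTransformer now answers the same get_output_seq_len query, so evaluation or generation code can treat both model classes uniformly. A hedged sketch of one such caller; check_prompt_fits is hypothetical:

# Hypothetical generation-time guard built on the shared interface.
from bytelatent.base_transformer import SequenceModelWithOutput


def check_prompt_fits(
    model: SequenceModelWithOutput, prompt_tokens: list[int], max_new_tokens: int
) -> None:
    # Refuse prompts that, together with the requested continuation, would
    # exceed the sequence length the model reports it can produce.
    budget = model.get_output_seq_len()
    needed = len(prompt_tokens) + max_new_tokens
    if needed > budget:
        raise ValueError(
            f"prompt ({len(prompt_tokens)} tokens) plus max_new_tokens "
            f"({max_new_tokens}) exceeds the model's budget of {budget}"
        )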

@@ -17,3 +17,7 @@ class Tokenizer(abc.ABC):
     ) -> tuple[list[str], list[int]]:
         """Return the offsets of the tokens in the original text. Only used for evaluation."""
         pass
+
+    @abc.abstractmethod
+    def get_vocab_size(self) -> int:
+        pass
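Declaring get_vocab_size as an @abc.abstractmethod means every concrete Tokenizer subclass must now implement it, which is exactly what the remaining hunks do for BltTokenizer, ByteTokenizer, SentencePieceTokenizer, and TikTokenTokenizer. A reduced, self-contained stand-in (not the repository's full Tokenizer ABC) showing how the contract is enforced at instantiation time:

import abc


class Tokenizer(abc.ABC):
    # Reduced stand-in for the repo's Tokenizer ABC; only the new method is shown.
    @abc.abstractmethod
    def get_vocab_size(self) -> int:
        pass


class IncompleteTokenizer(Tokenizer):
    # Missing get_vocab_size: the class definition is legal, instantiation is not.
    pass


class CompleteTokenizer(Tokenizer):
    def get_vocab_size(self) -> int:
        return 258


# IncompleteTokenizer()  # would raise TypeError: can't instantiate abstract class
print(CompleteTokenizer().get_vocab_size())  # 258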

@@ -101,6 +101,9 @@ class BltTokenizer(Tokenizer):
         self.vocab_size_unit_1 = vocab_size_unit_1
         self.n_words = vocab_size_unit_1 + self.offsetting_special_char

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(
         self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
     ):

@@ -8,6 +8,9 @@ class ByteTokenizer(Tokenizer):
         self.eos_id = 257
         self.n_words = 258

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
         tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
         return tokens
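A side note on the context line in ByteTokenizer.encode above: it relies on Python treating booleans as integers in list multiplication, so the BOS/EOS ids are prepended or appended only when the corresponding flag is true. A standalone illustration (the bos_id value here is illustrative; only eos_id=257 and n_words=258 appear in the diff):

# [x] * True == [x], [x] * False == [] -- the idiom used by ByteTokenizer.encode.
bos_id, eos_id = 256, 257  # eos_id from the diff; bos_id assumed for illustration


def encode(s: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
    return [bos_id] * add_bos + list(s.encode()) + [eos_id] * add_eos


print(encode("hi"))                # [104, 105]
print(encode("hi", add_bos=True))  # [256, 104, 105]
print(encode("hi", add_eos=True))  # [104, 105, 257]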

@@ -35,6 +35,9 @@ class SentencePieceTokenizer(Tokenizer):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool | None = None, add_eos: bool | None = None):
         if add_bos is None:
             add_bos = self.add_bos

@@ -53,6 +53,9 @@ class TikTokenTokenizer(Tokenizer):
             f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
         )

+    def get_vocab_size(self) -> int:
+        return self.n_words
+
     def encode(self, s: str, add_bos: bool, add_eos: bool):
         assert isinstance(s, str)

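Taken together, the tokenizer hunks give every tokenizer a uniform get_vocab_size() accessor backed by n_words. The sketch below shows the kind of caller this enables; both helpers are hypothetical and duck-type the tokenizer rather than importing the Tokenizer ABC, whose module path is not shown in this diff:

from torch import nn


def build_output_head(tokenizer, dim: int) -> nn.Linear:
    # Size the logit projection from the tokenizer rather than a hard-coded
    # vocab size, so swapping tokenizers cannot silently desynchronize shapes.
    return nn.Linear(dim, tokenizer.get_vocab_size(), bias=False)


def check_vocab_agreement(tokenizer, embedding: nn.Embedding) -> None:
    # Fail fast if the model's token embedding and the tokenizer disagree.
    if embedding.num_embeddings != tokenizer.get_vocab_size():
        raise ValueError(
            f"model embeds {embedding.num_embeddings} tokens but the tokenizer "
            f"reports a vocab of {tokenizer.get_vocab_size()}"
        )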