blt/bytelatent/tokenizers/byte_tokenizer.py
2024-12-12 15:32:30 -08:00

36 lines
1.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import codecs

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
class ByteTokenizer(Tokenizer):
def __init__(self):
self.bos_id = 256
self.eos_id = 257
self.n_words = 258
def encode(self, s: str, add_bos: bool = False, add_eos: bool = False):
tokens = [self.bos_id] * add_bos + list(s.encode()) + [self.eos_id] * add_eos
return tokens
def decode(self, tokens: list[int]):
byte_tokens = bytes([t for t in tokens if t < 256])
return byte_tokens.decode("utf-8", errors="backslashreplace")
def get_token_offsets(
self, text: str, tokens: list[int] | None = None
) -> tuple[list[str], list[int]]:
if tokens is None:
tokens = self.encode(text)
decoded_chars, offsets = [], []
byte_pos = 0
for token in tokens:
if token < 256:
char = bytes([token]).decode("utf-8", errors="ignore")
if char:
decoded_chars.append(char)
offsets.append(byte_pos)
byte_pos += len(char.encode("utf-8"))
return decoded_chars, offsets