blt/bytelatent/tokenizers/tiktoken_tokenizer.py
2024-12-12 15:32:30 -08:00

87 lines
3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from copy import copy
from pathlib import Path
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
try:
import tiktoken
from tiktoken.load import load_tiktoken_bpe
has_tiktoken = True
except ImportError:
has_tiktoken = False
# Pre-tokenization regex handed to tiktoken as `pat_str`. Its alternatives, in
# order: English contraction suffixes (case-insensitive); a run of letters,
# optionally preceded by one character that is not a letter, digit, CR or LF;
# a group of 1-3 digits; a run of non-space/non-letter/non-digit characters
# with an optional leading space and trailing newlines; a run of whitespace
# ending in newlines; whitespace not followed by a non-whitespace character;
# any other whitespace run.
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# Named special tokens and their slot within a 256-entry special-token range.
# TikTokenTokenizer.__init__ fills every slot of range(256) not listed here
# with a "<|reserved_special_token_N|>" placeholder and then shifts all slots
# by the number of mergeable ranks to obtain the final token ids.
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
"<|begin_of_text|>": 0,
"<|end_of_text|>": 1,
"<|fim_prefix|>": 2,
"<|fim_middle|>": 3,
"<|fim_end_fill|>": 253,
"<|fim_pad|>": 254,
"<|fim_suffix|>": 255,
}
# Maximum number of characters TikTokenTokenizer.encode passes to tiktoken in
# one piece; longer inputs are sliced into consecutive pieces of this size.
# NOTE(review): presumably guards against a tiktoken input-size limit - confirm.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# Module-level logger; used to report vocabulary size and BOS/EOS ids.
logger = logging.getLogger(__name__)
class TikTokenTokenizer(Tokenizer):
    """BPE tokenizer backed by a ``tiktoken`` encoding built from a rank file.

    The vocabulary consists of the mergeable ranks loaded from the model file
    followed by 256 special tokens: the named entries of
    ``DEFAULT_TIKTOKEN_SPECIAL_TOKENS`` plus ``<|reserved_special_token_N|>``
    placeholders for every slot those entries leave unused.

    Attributes:
        tkt_model: The underlying ``tiktoken`` encoding object.
        bos_id: Token id of ``<|begin_of_text|>``.
        eos_id: Token id of ``<|end_of_text|>``.
        n_words: Vocabulary size as reported by the encoding (``n_vocab``).
    """

    def __init__(self, model_path: str) -> None:
        """Load the BPE ranks at ``model_path`` and build the encoding.

        Args:
            model_path: Path of the tiktoken BPE rank file. Its stem is used
                as the name of the encoding.

        Note:
            The module swallows a failed ``tiktoken`` import, so constructing
            this class without ``tiktoken`` installed fails here on the first
            use of ``load_tiktoken_bpe``.
        """
        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        # Slot (0..255) of every special token: the named ones first, then a
        # reserved placeholder for each slot that is still free.
        special_slots = dict(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
        unused_slots = set(range(256)) - set(special_slots.values())
        for slot in sorted(unused_slots):
            special_slots[f"<|reserved_special_token_{slot}|>"] = slot

        # Special tokens are placed directly after the mergeable ranks in id
        # space, so every slot is shifted by the number of base tokens.
        all_special_tokens_with_ids = {
            name: slot + num_base_tokens for name, slot in special_slots.items()
        }

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=DEFAULT_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=all_special_tokens_with_ids,
        )

        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
        self.n_words: int = self.tkt_model.n_vocab

        # Lazy %-style arguments: the message is only rendered if INFO is on.
        logger.info(
            "#words: %s - BOS ID: %s - EOS ID: %s",
            self.n_words,
            self.bos_id,
            self.eos_id,
        )

    def encode(self, s: str, add_bos: bool, add_eos: bool) -> list[int]:
        """Encode ``s`` into token ids, optionally wrapped in BOS/EOS.

        The text is cut into consecutive slices of at most
        ``TIKTOKEN_MAX_ENCODE_CHARS`` characters which are encoded
        independently with ``encode_ordinary_batch`` and concatenated in
        order.

        Args:
            s: Text to encode.
            add_bos: Prepend ``bos_id`` when true.
            add_eos: Append ``eos_id`` when true.

        Returns:
            The token ids of ``s``; for an empty string only the requested
            BOS/EOS ids.
        """
        assert isinstance(s, str)

        chunks = [
            s[start : start + TIKTOKEN_MAX_ENCODE_CHARS]
            for start in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
        ]

        token_ids: list[int] = [self.bos_id] if add_bos else []
        # Extend in place: summing lists re-copies the accumulated result for
        # every chunk, which is quadratic in the number of chunks.
        for chunk_ids in self.tkt_model.encode_ordinary_batch(chunks):
            token_ids.extend(chunk_ids)
        if add_eos:
            token_ids.append(self.eos_id)
        return token_ids

    def decode(self, tokens: list[int]) -> str:
        """Decode a sequence of token ids back into text."""
        return self.tkt_model.decode(tokens)

    def get_token_offsets(
        self, text: str, tokens: list[int] | None = None
    ) -> tuple[list[str], list[int]]:
        """Return the substring of ``text`` and start offset for each token.

        Offsets are character indices into ``text``. They are derived from the
        UTF-8 bytes of each token: every byte that is not a continuation byte
        (``0x80 <= b < 0xC0``) starts a new character. A token whose first
        byte is a continuation byte begins in the middle of a character that
        started in an earlier token, so its offset points back at that
        character.

        Args:
            text: The text the tokens were produced from.
            tokens: Token ids for ``text``. When ``None`` the text is encoded
                here. NOTE(review): this path uses
                ``encode(text, allowed_special="all")`` without chunking,
                whereas ``encode`` above uses ``encode_ordinary_batch`` on
                slices - confirm the difference is intended.

        Returns:
            ``(substrs, offsets)`` where ``offsets[i]`` is the start index of
            token ``i`` and ``substrs[i]`` is ``text`` from that offset up to
            the next token's offset (to the end for the last token).
        """
        if tokens is not None:
            token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
        else:
            token_bytes = self.tkt_model.decode_tokens_bytes(
                self.tkt_model.encode(text, allowed_special="all")
            )

        char_count = 0
        offsets: list[int] = []
        for piece in token_bytes:
            # True (== 1) when the token starts inside a multi-byte character.
            starts_mid_char = 0x80 <= piece[0] < 0xC0
            offsets.append(max(0, char_count - starts_mid_char))
            # Count the characters that begin within this token.
            char_count += sum(1 for byte in piece if not 0x80 <= byte < 0xC0)

        substrs = [
            text[start:end] for start, end in zip(offsets, offsets[1:] + [None])
        ]
        return substrs, offsets