Mirror of https://github.com/facebookresearch/blt.git (synced 2025-01-18 16:37:46 +00:00)
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from copy import copy
from pathlib import Path

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
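
# tiktoken is an optional dependency; has_tiktoken lets callers check for it
# before constructing TikTokenTokenizer.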
try:
    import tiktoken
    from tiktoken.load import load_tiktoken_bpe

    has_tiktoken = True
except ImportError:
    has_tiktoken = False
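
# Pre-tokenization split regex: contractions, letter runs, 1-3 digit groups,
# punctuation, and whitespace (the Llama 3-style tiktoken pattern).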
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
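
# IDs below index into a 256-slot special-token block that __init__ appends
# after the BPE ranks; unassigned slots become reserved placeholders.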
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
    "<|begin_of_text|>": 0,
    "<|end_of_text|>": 1,
    "<|fim_prefix|>": 2,
    "<|fim_middle|>": 3,
    "<|fim_end_fill|>": 253,
    "<|fim_pad|>": 254,
    "<|fim_suffix|>": 255,
}
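
# encode() splits inputs longer than this into chunks before batch-encoding,
# which bounds the amount of text each regex pass has to scan.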
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

logger = logging.getLogger(__name__)

class TikTokenTokenizer(Tokenizer):
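    """Tokenizer backed by a tiktoken BPE model with BLT's special-token layout."""
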
    def __init__(self, model_path: str) -> None:
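        # Load the BPE ranks, fill the unused slots of the 256-token special
        # block with reserved placeholders, then shift every special-token ID
        # past the BPE vocabulary.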
        mergeable_ranks = load_tiktoken_bpe(model_path)
        all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
        missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values())
        for id in missing_ids:
            all_special_tokens_with_ids[f"<|reserved_special_token_{id}|>"] = id
        for name in all_special_tokens_with_ids:
            all_special_tokens_with_ids[name] += len(mergeable_ranks)
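
        # Assemble the tiktoken Encoding from the loaded ranks, the default
        # split pattern, and the shifted special-token table.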
        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=DEFAULT_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=all_special_tokens_with_ids,
        )
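
        # Cache the sequence-marker IDs and the total vocab size
        # (BPE ranks + 256 special slots).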
        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
        self.n_words: int = self.tkt_model.n_vocab

        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(self, s: str, add_bos: bool, add_eos: bool):
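        # Split very long inputs into fixed-size chunks, encode them as a
        # batch, and flatten. add_bos/add_eos are bools, so multiplying the
        # single-element lists by them conditionally adds the markers.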
        assert isinstance(s, str)

        subs = []
        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
        return (
            [self.bos_id] * add_bos
            + sum(self.tkt_model.encode_ordinary_batch(subs), start=[])
            + [self.eos_id] * add_eos
        )

    def decode(self, tokens: list[int]):
        return self.tkt_model.decode(tokens)

    def get_token_offsets(
        self, text: str, tokens: list[int] | None = None
    ) -> tuple[list[str], list[int]]:
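        # Recover the starting character offset of each token in `text`.
        # Bytes in 0x80-0xBF are UTF-8 continuation bytes: they don't start a
        # character, so they're excluded when counting characters, and a token
        # beginning mid-character is assigned the previous character's offset.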
        if tokens is not None:
            token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
        else:
            token_bytes = self.tkt_model.decode_tokens_bytes(
                self.tkt_model.encode(text, allowed_special="all")
            )

        text_len, offsets = 0, []
        for token in token_bytes:
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
        substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
        return substrs, offsets
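

# --- Usage sketch (illustration only; not part of the mirrored file). ---
# Assumes tiktoken is installed and that a tiktoken-format BPE rank file
# (e.g. a Llama 3 tokenizer.model) exists at the hypothetical path below.
if has_tiktoken:
    tok = TikTokenTokenizer("tokenizer.model")  # hypothetical local path
    ids = tok.encode("Byte Latent Transformer", add_bos=True, add_eos=True)
    assert ids[0] == tok.bos_id and ids[-1] == tok.eos_id
    print(tok.decode(ids[1:-1]))  # round-trips the input text
    substrs, offsets = tok.get_token_offsets("Byte Latent Transformer")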