# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
from copy import copy
from pathlib import Path

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer

try:
    import tiktoken
    from tiktoken.load import load_tiktoken_bpe

    has_tiktoken = True
except ImportError:
    has_tiktoken = False

# Regex used by tiktoken to pre-split text into pieces before BPE merging.
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# Named special tokens, keyed by name. The values are slot indices relative to
# the end of the BPE vocabulary (they are shifted by len(mergeable_ranks) in
# TikTokenTokenizer.__init__), not absolute token ids.
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
    "<|begin_of_text|>": 0,
    "<|end_of_text|>": 1,
    "<|fim_prefix|>": 2,
    "<|fim_middle|>": 3,
    "<|fim_end_fill|>": 253,
    "<|fim_pad|>": 254,
    "<|fim_suffix|>": 255,
}

# Maximum number of characters handed to tiktoken in a single string; longer
# inputs are sliced into chunks of this size by TikTokenTokenizer.encode.
# NOTE(review): presumably bounds the cost of a single tiktoken call - confirm.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

logger = logging.getLogger(__name__)


class TikTokenTokenizer(Tokenizer):
    """BPE tokenizer backed by a ``tiktoken`` encoding loaded from a rank file.

    The vocabulary consists of the mergeable ranks loaded from ``model_path``
    followed by 256 special-token slots: the named tokens from
    ``DEFAULT_TIKTOKEN_SPECIAL_TOKENS`` plus ``<|reserved_special_token_N|>``
    for every slot ``N`` that has no name.
    """

    def __init__(self, model_path: str) -> None:
        """Load the BPE ranks at ``model_path`` and build the encoding.

        Args:
            model_path: Path to a tiktoken BPE rank file. Its stem is used as
                the encoding name.

        Raises:
            ImportError: If the optional ``tiktoken`` package is not installed.
        """
        # The module-level import is optional; fail with an actionable error
        # instead of a NameError on the first tiktoken symbol used below.
        if not has_tiktoken:
            raise ImportError(
                "tiktoken is required to use TikTokenTokenizer but could not "
                "be imported"
            )

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)

        # Start from the named special tokens and give every unnamed slot in
        # [0, 256) a reserved placeholder name.
        special_slots = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
        named_slots = set(special_slots.values())
        for slot in range(256):
            if slot not in named_slots:
                special_slots[f"<|reserved_special_token_{slot}|>"] = slot

        # Special tokens live directly after the BPE vocabulary, so shift each
        # slot index by the number of mergeable ranks to get its token id.
        all_special_tokens_with_ids = {
            name: slot + num_base_tokens for name, slot in special_slots.items()
        }

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=DEFAULT_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=all_special_tokens_with_ids,
        )

        # Token ids of the begin/end-of-text markers and total vocabulary size.
        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
        self.n_words: int = self.tkt_model.n_vocab

        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(self, s: str, add_bos: bool, add_eos: bool) -> list[int]:
        """Encode ``s`` into token ids, optionally wrapped in BOS/EOS.

        The text is encoded with ``encode_ordinary_batch``, so special-token
        strings appearing in ``s`` are treated as ordinary text.

        Args:
            s: Text to encode.
            add_bos: Prepend ``self.bos_id`` when true.
            add_eos: Append ``self.eos_id`` when true.

        Returns:
            The token ids of ``s`` (empty for an empty string), with the
            requested BOS/EOS markers.
        """
        assert isinstance(s, str)

        # Slice the input into fixed-size character chunks; each chunk is
        # tokenized independently and the results are concatenated in order.
        chunks = [
            s[start : start + TIKTOKEN_MAX_ENCODE_CHARS]
            for start in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
        ]

        token_ids: list[int] = [self.bos_id] if add_bos else []
        # extend() keeps the flattening linear; summing lists re-copies the
        # accumulated result for every chunk.
        for chunk_ids in self.tkt_model.encode_ordinary_batch(chunks):
            token_ids.extend(chunk_ids)
        if add_eos:
            token_ids.append(self.eos_id)
        return token_ids

    def decode(self, tokens: list[int]) -> str:
        """Decode a list of token ids back into text."""
        return self.tkt_model.decode(tokens)

    def get_token_offsets(
        self, text: str, tokens: list[int] | None = None
    ) -> tuple[list[str], list[int]]:
        """Return the substring of ``text`` and start offset for each token.

        Args:
            text: The text the tokens correspond to.
            tokens: Token ids to locate. When ``None``, ``text`` is encoded
                with all special tokens allowed. When given, the ids are
                assumed to be a tokenization of ``text``; this is not checked.

        Returns:
            A ``(substrs, offsets)`` pair where ``offsets[i]`` is the character
            index in ``text`` at which token ``i`` starts and ``substrs[i]`` is
            the slice of ``text`` from that offset up to the next token's
            offset (or the end of ``text`` for the last token).
        """
        if tokens is not None:
            token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
        else:
            token_bytes = self.tkt_model.decode_tokens_bytes(
                self.tkt_model.encode(text, allowed_special="all")
            )

        text_len, offsets = 0, []
        for token in token_bytes:
            # Bytes in 0x80..0xBF are UTF-8 continuation bytes. A token that
            # starts with one begins in the middle of a character, so its
            # offset points back at the character opened by an earlier token.
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            # Every non-continuation byte starts a new character.
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)

        substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
        return substrs, offsets