# Copyright (c) Meta Platforms, Inc. and affiliates.
"""Byte-level tokenizer for BLT: maps text to UTF-8 byte ids (shifted by OFFSET to
reserve room for special tokens), with an optional mode that inserts a BPE
delimiter token before the bytes of each SentencePiece piece."""
import re

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
from bytelatent.tokenizers.constants import (
    BOE_ID,
    BOS_ID,
    BPE_ID,
    BYTE_UNITS,
    EOS_ID,
    OFFSET,
    PAD_ID,
)
from bytelatent.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer


def convert_to_bytes(s):
    # Check if the piece is a byte-fallback token of the form <0x00>.
    if re.match(r"<0x[0-9a-fA-F]+>", s):
        return bytes.fromhex(s[3:-1])
    else:
        return bytes(s, "utf-8", errors="ignore")


def text2bytes_bpe_delims(
    text: str,
    *,
    bpe_tokenizer,
    bpe_id: int,
    offsetting_special_char: int,
    add_bos: bool,
    add_eos: bool,
):
    cur_bpe = bpe_tokenizer.encode(text, add_bos=add_bos, add_eos=add_eos)

    # Merge the leading space tokens into a single piece.
    leading_space_tokens = []
    other_bpe_tokens = []
    leading = True
    for token in cur_bpe:
        bpe_str = bpe_tokenizer.sp_model.id_to_piece(token)
        if leading and all(c == "▁" for c in bpe_str):
            leading_space_tokens.append(bpe_str)
        else:
            leading = False
            other_bpe_tokens.append(bpe_str)
    cur_bpe_strs = ["".join(leading_space_tokens)] + other_bpe_tokens

    # Remove the '▁' characters.
    bpe_strs = []
    for i, bpe_str in enumerate(cur_bpe_strs):
        if (
            len(bpe_strs) <= 1
            and all(c == " " for s in bpe_strs for c in s)
            and not all(c == "▁" for c in bpe_str)
        ):
            # Remove the leading space marker on the first non-space token.
            bpe_str = bpe_str.replace("▁", "")
        elif i == 0 and all(c == "▁" for c in bpe_str):
            # Restore the original leading spaces of the text.
            bpe_str = " " * (len(text) - len(text.lstrip(" ")))
        else:
            bpe_str = bpe_str.replace("▁", " ")
        if len(bpe_str) > 0:
            bpe_strs.append(bpe_str)

    # Convert BPE pieces to byte values, inserting a BPE delimiter before each piece.
    ex_seq = []
    for s in bpe_strs:
        byte_chunk = convert_to_bytes(s)
        proc_chunk = [int(unit) for unit in byte_chunk]
        ex_seq.extend([bpe_id - offsetting_special_char] + proc_chunk)
    return ex_seq


class BltTokenizer(Tokenizer):
    def __init__(
        self,
        *,
        vocab_size_unit_1: int = BYTE_UNITS,
        bpe_delim: bool = False,
        bpe_tokenizer_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
        add_bos: bool = True,
        add_eos: bool = True,
    ):
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.vocab_size_unit_1 = vocab_size_unit_1
        self.boe_id = BOE_ID
        self.bos_id = BOS_ID
        self.eos_id = EOS_ID
        self.pad_id = PAD_ID
        self.bpe_id = BPE_ID
        self.bpe_tokenizer_path = bpe_tokenizer_path
        if bpe_delim:
            self.bpe_tokenizer = SentencePieceTokenizer(
                model_path=self.bpe_tokenizer_path
            )
        else:
            self.bpe_tokenizer = None
        self.bpe_delim = bpe_delim
        self.offsetting_special_char = OFFSET
        self.n_words = vocab_size_unit_1 + self.offsetting_special_char

    def encode(
        self, text: str, add_bos: bool | None = None, add_eos: bool | None = None
    ):
        if add_bos is None:
            add_bos = self.add_bos
        if add_eos is None:
            add_eos = self.add_eos

        if self.bpe_delim:
            tokens = text2bytes_bpe_delims(
                text,
                bpe_tokenizer=self.bpe_tokenizer,
                bpe_id=self.bpe_id,
                offsetting_special_char=self.offsetting_special_char,
                add_bos=False,
                add_eos=False,
            )
        else:
            tokens = bytes(text, encoding="utf-8", errors="ignore")

        # Offset byte values so ids below OFFSET stay reserved for special tokens.
        tokens = [int(unit) + self.offsetting_special_char for unit in tokens]

        if add_bos:
            tokens.insert(0, self.bos_id)
        if add_eos:
            tokens.append(self.eos_id)
        return tokens

    def decode(self, tokens: list[int], cut_at_eos: bool = False):
        if cut_at_eos:
            for k, t in enumerate(tokens):
                if t == self.eos_id:
                    tokens = tokens[: k + 1]
                    break
        # Drop special tokens (ids below the offset) and decode the remaining bytes.
        return bytes(
            [
                tok - self.offsetting_special_char
                for tok in tokens
                if tok - self.offsetting_special_char >= 0
            ]
        ).decode("utf-8", errors="ignore")

    def get_token_offsets(self, text: str, tokens: list[int] | None = None):
        # TODO: Figure out what this does
        raise NotImplementedError()
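

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# the byte-level round trip. It assumes the bytelatent constants follow the
# usual convention (BOS/EOS/PAD ids below OFFSET, byte values shifted upward)
# and only uses encode()/decode() as defined above. The __main__ guard keeps
# it out of normal imports; run the file directly to try it.
if __name__ == "__main__":
    tokenizer = BltTokenizer(bpe_delim=False, add_bos=True, add_eos=True)
    ids = tokenizer.encode("hello world")
    # ids should look like [BOS_ID, byte + OFFSET for each UTF-8 byte, EOS_ID].
    print(ids)
    # decode() filters out ids below OFFSET (BOS/EOS/PAD), so the text round-trips.
    print(repr(tokenizer.decode(ids)))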