# Copyright (c) Meta Platforms, Inc. and affiliates.

import pickle
from pathlib import Path

import numpy as np

from bytelatent import ByteLatentError

LOOKUP_OFFSET = 4


def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
    """
    Wrapper that builds the lookup function applied to each n-gram.

    :param ngram_to_idx: Dictionary where keys are tuples (n-grams) and values are their indices.
    :param lookup_offset: Offset to add to each lookup result.
    :return: A function mapping an n-gram to its offset index, or 0 if the n-gram is not in the table.
    """

    def apply_lookup_table(ngram):
        """
        Function to apply to each n-gram: converts it to a tuple and looks it up in the dictionary.

        :param ngram: Array of numbers representing an n-gram.
        :return: The index associated with the n-gram tuple plus lookup_offset, or 0 if not found.
        """
        # Convert the n-gram to a tuple so it can be used as a dictionary key
        ngram_tuple = tuple(ngram)

        if ngram_tuple not in ngram_to_idx:
            return 0
        else:
            return ngram_to_idx[ngram_tuple] + lookup_offset

    return apply_lookup_table
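

# Illustrative sketch of the lookup behavior, using a toy one-entry table
# (real tables are loaded from disk by reload_tables below):
# >>> lookup = apply_lookup_table_wrapper({(104, 105): 0}, lookup_offset=4)
# >>> lookup(np.array([104, 105]))
# 4
# >>> lookup(np.array([0, 0]))  # unknown n-grams map to id 0
# 0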


def get_byte_ngrams_ids(
    byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
):
    """
    Generate n-gram ids from a 2D numpy array.

    :param byte_array: 2D array of byte values, one row per sequence.
    :param n: The length of each n-gram.
    :param ngram_to_idx: Dictionary mapping n-gram tuples to their indices.
    :param pad_value: The value used for left-padding the byte values so the output keeps the input's dimensions.
    :return: A 2D numpy array where each element is the id of an n-gram, offset by LOOKUP_OFFSET.
    """
    num_rows, num_cols = byte_array.shape

    # Left-pad each row with n - 1 pad values so every position has a full n-gram
    padded_array = np.pad(
        byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
    )

    # Use stride tricks to build a (num_rows, num_cols, n) sliding-window view
    # over the padded array without explicit looping
    strided = np.lib.stride_tricks.as_strided
    shape = (num_rows, num_cols, n)
    strides = padded_array.strides[:2] + (padded_array.strides[1],)
    ngrams = strided(padded_array, shape=shape, strides=strides)

    # Map each window to its table id (0 for unknown n-grams)
    ngram_ids = np.apply_along_axis(
        apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET), 2, ngrams
    )
    assert ngram_ids.shape == byte_array.shape
    return ngram_ids
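

# Illustrative sketch with a made-up 2-gram table: each output position holds
# the id of the n-gram ending there (offset by LOOKUP_OFFSET = 4), the first
# position pairs with the pad value, and unknown n-grams map to 0:
# >>> arr = np.array([[10, 11, 12]])
# >>> get_byte_ngrams_ids(arr, 2, {(10, 11): 0, (11, 12): 1})
# array([[0, 4, 5]])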


def reload_tables(
    ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
    """
    Reload lookup tables from a directory. Reload only the n-gram sizes in the
    dictionary and, per n-gram size, only load up to the max specified number
    of entries. Return the vocab size (number of n-grams taken plus the offset)
    per n-gram size.
    """
    idx_to_ngram_tables = {}
    ngram_to_idx_tables = {}
    vocab_sizes = {}
    for ngram, size in ngram_to_size.items():
        with open(Path(ngram_table_dir) / f"ngram-{ngram}.pickle", "rb") as f:
            # These are already sorted by count
            # Value: tuple of: count, ngram, dataset
            ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
                "counts"
            ]
        table = [ngram for ngram, _ in ngram_data][:size]
        if len(table) != size:
            raise ValueError(
                f"Ngram table for {ngram}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
            )
        ngram_to_idx = {ngram: idx for idx, ngram in enumerate(table)}
        actual_size = len(table)
        idx_to_ngram_tables[ngram] = table
        ngram_to_idx_tables[ngram] = ngram_to_idx
        vocab_sizes[ngram] = actual_size + offset
    return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
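

# Sketch of the expected result, assuming a hypothetical table directory that
# contains ngram-2.pickle and ngram-3.pickle with enough entries:
# >>> n2i, i2n, sizes = reload_tables("/path/to/tables", {2: 1000, 3: 500})
# n2i[2] maps 2-gram tuples to indices 0..999, i2n[2] is the inverse list, and
# sizes == {2: 1004, 3: 504} once LOOKUP_OFFSET is added.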


def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int] | None:
    if ngram_to_size_str is None:
        return None
    ngram_to_size = {}
    for entry in ngram_to_size_str.split(","):
        ngram, size = entry.split(":")
        ngram = int(ngram)
        size = int(size)
        ngram_to_size[ngram] = size
    return ngram_to_size
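

# Example with made-up sizes: the config string "2:40000,3:30000" parses to
# {2: 40000, 3: 30000}, and a None input passes through as None.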


class NgramProcessor:
    def __init__(
        self,
        ngram_table_dir: str | None = None,
        ngram_to_size: dict[int, int] | None = None,
    ):
        if ngram_table_dir is None or ngram_to_size is None:
            raise ByteLatentError(
                "ngram_table_dir and ngram_to_size cannot be None if enable_byte_ngrams is True"
            )
        (
            self.ngram_to_idx_tables,
            self.idx_to_ngram_tables,
            self.ngram_vocab_sizes,
        ) = reload_tables(ngram_table_dir, ngram_to_size)
        # Lowest to highest ngram
        self.ngram_sizes = sorted(list(self.ngram_to_idx_tables.keys()))
        # Although the model might not use all the ngrams, we need the tokenizer
        # to produce ngram_ids such that index zero is the 2-gram, later on in
        # src.model.megabyte.Megabyte.forward
        assert self.ngram_sizes[0] == 2

    def encode_single_ngram_table(self, data: np.ndarray, n: int):
        """
        Return the n-gram ids of the input data for a given n, as a numpy
        array of ids with shape data.shape.
        """
        return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)

    def encode_token_ngrams(self, data: np.ndarray):
        """
        Return the n-gram ids of the input data for every n in self.ngram_sizes.

        Output shape: [ids with data.shape for n in self.ngram_sizes]
        """
        return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]
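

# End-to-end sketch (hypothetical table directory and sizes; requires real
# ngram-{n}.pickle files on disk):
# >>> processor = NgramProcessor("/path/to/tables", {2: 40000, 3: 30000})
# >>> byte_ids = np.array([[72, 101, 108, 108, 111]])  # "Hello" as bytes
# >>> [g.shape for g in processor.encode_token_ngrams(byte_ids)]
# [(1, 5), (1, 5)]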