blt/bytelatent/data/ngram_processor.py

147 lines
5.5 KiB
Python
Raw Permalink Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pickle
from pathlib import Path
import numpy as np
from bytelatent import ByteLatentError
# Default offset added to every n-gram table index when producing n-gram ids
# (get_byte_ngrams_ids) and to each table's size when reporting vocab sizes
# (reload_tables). Id 0 is returned for n-grams missing from the table, so with
# this default, ids 1..LOOKUP_OFFSET-1 are never produced by this module
# (presumably reserved for special tokens elsewhere — TODO confirm).
LOOKUP_OFFSET = 4
def apply_lookup_table_wrapper(ngram_to_idx: dict[tuple, int], lookup_offset=1):
    """
    Build a function that maps a single n-gram to its id in ``ngram_to_idx``.

    :param ngram_to_idx: Dictionary whose keys are n-gram tuples and whose values are
        their zero-based indices in the n-gram table.
    :param lookup_offset: Offset added to the table index of every n-gram that is found.
    :return: A one-argument function ``apply_lookup_table(ngram)`` that returns
        ``ngram_to_idx[tuple(ngram)] + lookup_offset`` when the n-gram is in the table,
        and ``0`` otherwise.
    """

    def apply_lookup_table(ngram):
        """
        Look up one n-gram and return its offset id.

        :param ngram: Sequence of values (e.g. a 1D numpy array) forming an n-gram.
        :return: The table index plus ``lookup_offset``, or ``0`` if the n-gram is not
            in the table.
        """
        # numpy arrays are unhashable, so convert to a tuple before the dict lookup.
        idx = ngram_to_idx.get(tuple(ngram))
        if idx is None:
            return 0
        return idx + lookup_offset

    return apply_lookup_table
def get_byte_ngrams_ids(
    byte_array: np.ndarray, n: int, ngram_to_idx: dict[tuple, int], pad_value=0
):
    """
    Map every position of a 2D array to the id of the n-gram ending at that position.

    Each row is left-padded with ``n - 1`` copies of ``pad_value`` so that the first
    columns also get a full-length n-gram (made partly of padding). The n-gram at
    ``(row, col)`` is ``padded[row, col:col + n]``, i.e. the original values
    ``byte_array[row, col - n + 1:col + 1]`` preceded by padding where needed.

    :param byte_array: 2D numpy array of shape ``(num_rows, num_cols)``.
    :param n: The length of each n-gram; must be at least 1.
    :param ngram_to_idx: Dictionary mapping n-gram tuples to their table index.
    :param pad_value: The value used for padding of the byte values to maintain the
        same dimensions for the n-grams.
    :return: An integer numpy array with the same shape as ``byte_array`` where each
        element is the table index of its n-gram plus LOOKUP_OFFSET, or 0 if the
        n-gram is not in ``ngram_to_idx``.
    :raises ValueError: If ``n < 1`` or ``byte_array`` is not 2D.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    # Unpacking also rejects non-2D input with a ValueError.
    num_rows, num_cols = byte_array.shape
    if byte_array.size == 0:
        # Nothing to look up; iterating over a zero-length axis is not supported by
        # numpy's per-row helpers, so return an empty id array of the right shape.
        return np.zeros((num_rows, num_cols), dtype=int)
    padded_array = np.pad(
        byte_array, ((0, 0), (n - 1, 0)), mode="constant", constant_values=pad_value
    )
    # Read-only view of shape (num_rows, num_cols, n): window [r, c] is
    # padded_array[r, c:c + n]. Safer than a hand-built as_strided view.
    ngrams = np.lib.stride_tricks.sliding_window_view(padded_array, n, axis=1)
    lookup = apply_lookup_table_wrapper(ngram_to_idx, lookup_offset=LOOKUP_OFFSET)
    ngram_ids = np.array([lookup(ngram) for ngram in ngrams.reshape(-1, n)], dtype=int)
    return ngram_ids.reshape(num_rows, num_cols)
def reload_tables(
    ngram_table_dir: str, ngram_to_size: dict[int, int], offset: int = LOOKUP_OFFSET
) -> tuple[dict[int, dict[tuple, int]], dict[int, list], dict[int, int]]:
    """
    Reload lookup tables from a directory. Reload only the ngrams in the dictionary and
    per ngram, only load up to the max specified size.

    For every ``ngram -> size`` entry in ``ngram_to_size``, read
    ``<ngram_table_dir>/ngram-<ngram>.pickle`` and keep its first ``size`` n-grams.

    :param ngram_table_dir: Directory containing the ``ngram-<n>.pickle`` files.
    :param ngram_to_size: Mapping from n-gram length to the number of n-grams to keep.
    :param offset: Added to each table size when reporting vocab sizes.
    :return: A tuple ``(ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes)``, each
        keyed by n-gram length: ``ngram_to_idx_tables[n]`` maps an n-gram tuple to its
        index, ``idx_to_ngram_tables[n]`` is the list of kept n-grams in index order,
        and ``vocab_sizes[n]`` is the number of kept n-grams plus ``offset``.
    :raises ValueError: If a table file holds fewer n-grams than requested.
    """
    table_dir = Path(ngram_table_dir)
    idx_to_ngram_tables = {}
    ngram_to_idx_tables = {}
    vocab_sizes = {}
    for ngram_len, size in ngram_to_size.items():
        # NOTE(review): pickle.load can execute arbitrary code; only point this at
        # trusted table directories.
        with open(table_dir / f"ngram-{ngram_len}.pickle", "rb") as f:
            # Entries are expected to be sorted by count already (most frequent first);
            # each entry is unpacked below as a pair whose first item is the n-gram
            # tuple — exact layout of the second item TODO confirm with the writer.
            ngram_data: list[tuple[tuple, tuple[int, int, str]]] = pickle.load(f)[
                "counts"
            ]
        # Slice first so only the kept entries are materialized.
        table = [entry_ngram for entry_ngram, _ in ngram_data[:size]]
        if len(table) != size:
            raise ValueError(
                f"Ngram table for {ngram_len}-gram is not large enough to get {size} ngrams, max size is {len(ngram_data)}"
            )
        idx_to_ngram_tables[ngram_len] = table
        ngram_to_idx_tables[ngram_len] = {
            entry_ngram: idx for idx, entry_ngram in enumerate(table)
        }
        vocab_sizes[ngram_len] = len(table) + offset
    return ngram_to_idx_tables, idx_to_ngram_tables, vocab_sizes
def parse_ngram_to_size(ngram_to_size_str: str | None) -> dict[int, int]:
if ngram_to_size_str is None:
return None
ngram_to_size = {}
for entry in ngram_to_size_str.split(","):
ngram, size = entry.split(":")
ngram = int(ngram)
size = int(size)
ngram_to_size[ngram] = size
return ngram_to_size
class NgramProcessor:
    """
    Turn 2D arrays of byte values into per-position n-gram ids using loaded tables.

    One lookup table per n-gram length in ``ngram_to_size`` is loaded at construction
    time via ``reload_tables``.
    """

    def __init__(
        self,
        ngram_table_dir: str | None = None,
        ngram_to_size: dict[int, int] | None = None,
    ):
        """
        :param ngram_table_dir: Directory containing the ``ngram-<n>.pickle`` tables.
        :param ngram_to_size: Mapping from n-gram length to the number of n-grams to load.
        :raises ByteLatentError: If either argument is None, or if the smallest loaded
            n-gram length is not 2 (including when no lengths are given).
        """
        if ngram_table_dir is None or ngram_to_size is None:
            raise ByteLatentError(
                "ngram_table_dir and ngram_to_size cannot be none if enable_byte_ngrams is True"
            )
        (
            self.ngram_to_idx_tables,
            self.idx_to_ngram_tables,
            self.ngram_vocab_sizes,
        ) = reload_tables(ngram_table_dir, ngram_to_size)
        # Lowest to highest ngram
        self.ngram_sizes = sorted(self.ngram_to_idx_tables)
        # Although the model might not use all the ngrams, we need the tokenizer
        # to produce ngram_ids such that index zero is the 2-gram, later on in
        # src.model.megabyte.Megabyte.forward
        # Raise instead of assert so the check survives `python -O`, and so an empty
        # ngram_to_size gives a clear error rather than an IndexError.
        if not self.ngram_sizes or self.ngram_sizes[0] != 2:
            raise ByteLatentError(
                f"ngram_to_size must have 2 as its smallest ngram size, got {self.ngram_sizes}"
            )

    def encode_single_ngram_table(self, data: np.ndarray, n: int):
        """
        Return the n-gram ids of ``data`` for a single n-gram length.

        :param data: 2D numpy array of byte values.
        :param n: N-gram length; must be a key of ``self.ngram_to_idx_tables``
            (KeyError otherwise).
        :return: Numpy array of ids with shape ``data.shape``.
        """
        return get_byte_ngrams_ids(data, n, self.ngram_to_idx_tables[n], pad_value=0)

    def encode_token_ngrams(self, data: np.ndarray):
        """
        Return the n-gram ids of ``data`` for every loaded n-gram length.

        :param data: 2D numpy array of byte values.
        :return: List of id arrays, each of shape ``data.shape``, ordered from the
            smallest to the largest length in ``self.ngram_sizes``.
        """
        return [self.encode_single_ngram_table(data, n) for n in self.ngram_sizes]