Update preprocess_entropies script to blt inference + add fsspec support

Summary:

Port the preprocess_entropies script to the current BLT inference code
(new entropy-model loading signature, TokenizerArgs-based tokenizer
construction, direct jsonl iteration instead of ArrowFileIterator) and
route all file I/O through fsspec so inputs and outputs can live on
local disk or S3.

Test Plan:
Pedro Rodriguez 2025-01-13 23:26:26 +00:00
parent b0120da72f
commit d718cfa9a1
3 changed files with 74 additions and 40 deletions

bytelatent/data/patcher.py

@@ -82,16 +82,16 @@ def calculate_entropies(
             if device is not None:
                 split = split.to(device)
             assert torch.all(split >= 0) and torch.all(split < 260)
-            pred, _ = entropy_model(split)
+            pred = entropy_model(split)
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)

-    entropies = torch.cat(entropies, dim=0)
-    entropies = entropies.reshape(tokens.shape)
-    return entropies
+    concat_entropies = torch.cat(entropies, dim=0)
+    concat_entropies = concat_entropies.reshape(tokens.shape)
+    return concat_entropies


 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
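For context, the entropy(pred) call above turns per-position logits over the byte vocabulary into Shannon entropies. A minimal sketch of that computation in isolation, with illustrative shapes (the real helper in bytelatent.data.patcher may differ in details):

import torch
import torch.nn.functional as F

# Illustrative: 10 positions over a 260-entry byte vocabulary, matching
# the `split < 260` assertion in the hunk above.
logits = torch.randn(10, 260)
log_probs = F.log_softmax(logits, dim=-1)
# Shannon entropy per position: -sum_v p(v) * log p(v)
position_entropies = -(log_probs.exp() * log_probs).sum(dim=-1)  # shape [10]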

bytelatent/preprocess/preprocess_entropies.py

@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import time
-from pathlib import Path
+
+import fsspec
+import jsonlines
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn

-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.patcher import calculate_entropies
 from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find a id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader
+
+
 def main(
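The three helpers above replace the ArrowFileIterator-based input path with plain jsonl iteration over any fsspec filesystem. A quick usage sketch; the file name and record contents are hypothetical:

import fsspec

fs = fsspec.filesystem("file")  # get_fs returns an S3-backed fs for s3:// paths
for doc in jsonl_file_iterator(fs, "docs.jsonl"):
    print(get_id_from_doc(doc), get_text(doc)[:50])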
@@ -16,39 +61,32 @@ def main(
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
-    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
+    dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
+    if dry_run:
+        return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir, device=patching_device
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
-    entropy_model, _ = to_device(entropy_model, patching_device)
     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-    tokenizer = Tokenizer(
-        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
-    )
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
+    )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
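The new input_fs/get_fs plumbing is what lets the same code read local and S3 inputs. get_fs itself lives in bytelatent.data.file_util and is not part of this diff; a hedged sketch of what such a helper typically does (the real implementation may differ):

import fsspec

def get_fs_sketch(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Dispatch on the URL scheme: s3:// paths get an s3fs-backed
    # filesystem, everything else falls back to local disk.
    if path.startswith("s3://"):
        return fsspec.filesystem("s3", profile=s3_profile)
    return fsspec.filesystem("file")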
@@ -59,8 +97,10 @@ def main(
     schema = pa.schema([sample_id_field, text_field, entropy_field])
     arrow_batch_size = 1_000

+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
-        with pa.OSFile(output_file, "wb") as sink:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
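Swapping pa.OSFile for output_fs.open works because pyarrow's IPC writer accepts any file-like object. A minimal self-contained sketch of the write path; the schema and values are made up:

import fsspec
import pyarrow as pa

schema = pa.schema([pa.field("sample_id", pa.string())])
fs = fsspec.filesystem("file")
with fs.open("example.arrow", "wb") as sink:
    with pa.ipc.new_file(sink, schema) as writer:
        writer.write_batch(pa.record_batch([pa.array(["a", "b"])], schema=schema))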
@@ -72,17 +112,9 @@ def main(
                 task = progress.add_task(
                     "[green]Calculating entropies...", total=None
                 )
-                for doc in iterator:
-                    if "text" in doc:
-                        text = doc["text"]
-                    elif "content" in doc:
-                        text = doc["content"]
-                    else:
-                        raise ValueError(
-                            f"Could not find a text key from: {doc.keys()}"
-                        )
+                for doc in input_doc_iterator:
+                    sample_id = get_id_from_doc(doc)
+                    text = get_text(doc)
                     tokens = torch.tensor(tokenizer.encode(text))
                     patch_start = time.time()
                     scores = calculate_entropies(
@@ -128,9 +160,10 @@ def main(
                             entropies_buffer = []
                             id_buffer = []
                             text_buffer = []
-        Path(f"{output_file}.complete").touch()
+        output_fs.touch(f"{output_file}.complete")
     except:
-        Path(output_file).unlink(missing_ok=True)
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)

requirements.txt

@@ -21,3 +21,4 @@ submitit
 typer
 rich
 fsspec[full]
+orjson