blt/bytelatent/preprocess/preprocess_entropies.py
Pedro Rodriguez 1da3dd9315
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run
Update preprocess_entropies script to blt inference + add fsspec support (#23)
Summary:

Test Plan:
2025-01-13 15:28:14 -08:00

175 lines
6.1 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> str:
    """Derive a stable identifier string for a document.

    We need a reliable way to ensure that samples from jsonl and arrow are
    the same, but there is no unique id field, so the best available key is
    used instead. Candidate keys are tried in a fixed priority order and the
    first one present in ``doc`` wins.

    Args:
        doc: A decoded document record.

    Returns:
        The value of the first matching key, converted with ``str``.

    Raises:
        ValueError: If none of the candidate keys is present in ``doc``.
    """
    # Priority order matters: earlier keys shadow later ones when a record
    # carries several of them.
    candidate_keys = ("sample_id", "title", "qid", "paper_id", "path", "url", "id")
    for key in candidate_keys:
        if key in doc:
            return str(doc[key])
    raise ValueError(f"Could not find a id key from: {doc.keys()}")
def get_text(doc: dict):
    """Return the document body stored under ``"text"`` or ``"content"``.

    ``"text"`` takes precedence when both keys are present.

    Raises:
        ValueError: If ``doc`` has neither key.
    """
    for field in ("text", "content"):
        if field in doc:
            return doc[field]
    raise ValueError(f"Could not find a text key from: {doc.keys()}")
def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    """Lazily yield every record of the JSON-lines file at ``path`` on ``fs``.

    This is a generator: the file is opened only once iteration starts, and
    the ``with`` block closes it when the generator finishes or is closed.
    """
    with fs.open(path) as handle:
        for record in jsonlines.Reader(handle):
            yield record
def _write_record_batch(writer, schema, sample_ids, texts, entropies) -> None:
    """Build one Arrow record batch from the parallel buffers and write it."""
    batch = pa.record_batch(
        {
            "entropies": entropies,
            "sample_id": sample_ids,
            "text": texts,
        },
        schema,
    )
    writer.write(batch)


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
    patching_batch_size: int = 32,
    arrow_batch_size: int = 1_000,
):
    """Compute per-document entropies for a JSONL file and store them as Arrow.

    Every record of ``input_file`` is tokenized, scored with
    ``calculate_entropies`` and appended to an Arrow IPC file with the columns
    ``sample_id``, ``text`` and ``entropies`` (a list of float16 values). When
    the whole input has been written, an empty ``<output_file>.complete``
    marker is created. If anything fails, the partially written output file is
    removed and the error is re-raised.

    Args:
        input_file: Path of the JSON-lines input, resolved through ``get_fs``.
        output_file: Path of the Arrow file to create, resolved through
            ``get_fs``.
        patching_device: Device passed to the entropy model loader and to
            ``calculate_entropies``.
        log_step: Print a progress line every ``log_step`` documents; a value
            of zero or less disables these lines.
        entropy_model_checkpoint_dir: Checkpoint directory handed to
            ``load_entropy_model``.
        entropy_model_state_dict_path: State-dict path handed to
            ``load_entropy_model``.
        bpe_tokenizer_path: Tokenizer model path given to the ``blt``
            tokenizer.
        dry_run: When true, return right after resolving the input, before
            any model is loaded or any output is written.
        s3_profile: Optional profile name forwarded to ``get_fs``.
        patching_batch_size: Batch size forwarded to ``calculate_entropies``.
        arrow_batch_size: Number of documents buffered per Arrow record batch.

    Raises:
        ValueError: If ``arrow_batch_size`` is smaller than 1.
    """
    if arrow_batch_size < 1:
        raise ValueError(f"arrow_batch_size must be >= 1, got {arrow_batch_size}")

    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    # The iterator is lazy: the input is not opened until the loop below runs.
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
    if dry_run:
        return

    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )
    print("Creating patcher")
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()

    step = 0
    print("starting")
    start_time = time.time()

    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])

    output_fs = get_fs(output_file, s3_profile=s3_profile)
    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        # Flush in fixed-size batches so memory use stays
                        # bounded regardless of the input size.
                        if len(entropies_buffer) >= arrow_batch_size:
                            _write_record_batch(
                                writer,
                                schema,
                                id_buffer,
                                text_buffer,
                                entropies_buffer,
                            )
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        step += 1
                        # Guard against log_step <= 0, which would otherwise
                        # raise ZeroDivisionError on the modulo.
                        if log_step > 0 and step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if entropies_buffer:
                        # Write the final, partially filled batch.
                        _write_record_batch(
                            writer,
                            schema,
                            id_buffer,
                            text_buffer,
                            entropies_buffer,
                        )
        # Only mark the output complete after the Arrow file has been closed.
        output_fs.touch(f"{output_file}.complete")
    except BaseException:
        # Deliberately broad (covers KeyboardInterrupt too): never leave a
        # truncated Arrow file behind. The original error is always re-raised.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise

    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
# Script entry point: when this file is executed directly, hand `main` to
# typer so it is run as a command-line program.
if __name__ == "__main__":
    typer.run(main)