blt/bytelatent/preprocess/preprocess_entropies.py
Pedro Rodriguez 9c3c997cae
Allow ArrowIterator to read from json
Summary:

Currently, the arrow iterator can only read Arrow files. However, the pyarrow library can read
other formats, including jsonlines. This change allows the same ArrowIterator to read from jsonlines,
so we can read from the original source data and simply omit the entropy column when doing so.

Test Plan:

Run train script until dataloader starts
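
A hypothetical direct invocation of this script (the paths are placeholders; typer derives the CLI from the signature of main below):

    python -m bytelatent.preprocess.preprocess_entropies input.jsonl output.arrow --dry-run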
2025-02-06 17:44:36 +00:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
import time

import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn

from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_key(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no single unique id field,
    so we derive the best available one.
    """
    if "sample_id" in doc:
        return "sample_id"
    elif "title" in doc:
        return "title"
    elif "qid" in doc:
        return "qid"
    elif "paper_id" in doc:
        return "paper_id"
    elif "path" in doc:
        return "path"
    elif "url" in doc:
        return "url"
    elif "id" in doc:
        return "id"
    else:
        raise ValueError(f"Could not find an id key from: {doc.keys()}")


def get_id_from_doc(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no single unique id field,
    so we derive the best available one.
    """
    return str(doc[get_id_key(doc)])


def get_text(doc: dict) -> str:
    if "text" in doc:
        text = doc["text"]
    elif "content" in doc:
        text = doc["content"]
    else:
        raise ValueError(f"Could not find a text key from: {doc.keys()}")
    return text


def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    with fs.open(path) as f:
        reader = jsonlines.Reader(f)
        yield from reader
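

# main() streams documents from input_file, computes per-token entropies with
# the entropy model, and writes (sample_id, text, entropies) rows to an Arrow
# IPC file at output_file.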
def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )
print("Creating patcher")
patching_batch_size = 32
print("Creating tokenizer")
tokenizer_args = TokenizerArgs(
name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
)
tokenizer = tokenizer_args.build()
step = 0
print("starting")
start_time = time.time()
patch_time = 0
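    # Output schema: one row per input document, with the per-token entropies
    # stored as a float16 list column next to the sample id and raw text.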
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    arrow_batch_size = 1_000
    output_fs = get_fs(output_file, s3_profile=s3_profile)
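    # The writes below are all-or-nothing: on any failure the partial output
    # file is deleted, and the "{output_file}.complete" marker is only created
    # once everything has been written.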
    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
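                    # One document per iteration: derive a stable id, tokenize,
                    # score with the entropy model, and buffer the row.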
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores, _ = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        # Time only the entropy computation, not the Arrow
                        # bookkeeping below.
                        patch_time += time.time() - patch_start
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
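                        # Flush a full record batch so buffer memory stays
                        # bounded regardless of input size.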
                        if len(entropies_buffer) == arrow_batch_size:
                            batch = pa.record_batch(
                                {
                                    "entropies": entropies_buffer,
                                    "sample_id": id_buffer,
                                    "text": text_buffer,
                                },
                                schema,
                            )
                            writer.write(batch)
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if len(entropies_buffer) > 0:
                        # Write the final, partially filled batch
                        batch = pa.record_batch(
                            {
                                "entropies": entropies_buffer,
                                "sample_id": id_buffer,
                                "text": text_buffer,
                            },
                            schema,
                        )
                        writer.write(batch)
                        entropies_buffer = []
                        id_buffer = []
                        text_buffer = []
        output_fs.touch(f"{output_file}.complete")
    except BaseException:
        # Remove the partial output (even on KeyboardInterrupt) so an
        # incomplete file is never mistaken for a finished one.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps:", step)
    print("patch time:", patch_time)
    print("done in:", elapsed)


if __name__ == "__main__":
    typer.run(main)