blt/bytelatent/preprocess/preprocess_entropies.py

175 lines
6.1 KiB
Python
Raw Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
import fsspec
import jsonlines
2024-12-12 23:32:30 +00:00
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> str:
    """Derive a string id for a document.

    Samples read from JSONL and from Arrow must be matched up, but the inputs
    have no single unique-id field, so the first key present from a fixed
    priority list is used as the best available id.

    Args:
        doc: One decoded JSON document.

    Returns:
        The value of the highest-priority id key that is present, passed
        through ``str``.

    Raises:
        ValueError: If none of the candidate id keys is present in ``doc``.
    """
    # Priority order matters: a doc with both "title" and "id" is keyed by title.
    for key in ("sample_id", "title", "qid", "paper_id", "path", "url", "id"):
        if key in doc:
            return str(doc[key])
    raise ValueError(f"Could not find a id key from: {doc.keys()}")
def get_text(doc: dict):
    """Return the document body.

    The ``"text"`` field is preferred; ``"content"`` is the fallback.

    Raises:
        ValueError: If the document has neither field.
    """
    for field_name in ("text", "content"):
        if field_name in doc:
            return doc[field_name]
    raise ValueError(f"Could not find a text key from: {doc.keys()}")
def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    """Lazily yield each decoded record of the JSON-lines file at ``path``.

    This is a generator. The file is opened through ``fs`` only when
    iteration starts, and it is closed when the generator is exhausted or
    closed.
    """
    with fs.open(path) as handle:
        yield from jsonlines.Reader(handle)
2024-12-12 23:32:30 +00:00
def _write_record_batch(writer, schema, sample_ids, texts, entropies) -> None:
    """Build one Arrow record batch from parallel per-document lists and write it.

    ``sample_ids``, ``texts`` and ``entropies`` are expected to have equal
    length (one entry per document). The dict keys below must match the field
    names of ``schema``.
    """
    batch = pa.record_batch(
        {
            "entropies": entropies,
            "sample_id": sample_ids,
            "text": texts,
        },
        schema=schema,
    )
    writer.write(batch)


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
    patching_batch_size: int = 32,
    arrow_batch_size: int = 1_000,
):
    """Compute per-token entropies for every document of a JSONL file.

    Each input document is tokenized with the "blt" tokenizer and scored with
    the entropy model. The results are written to ``output_file`` as an Arrow
    IPC file with the following non-nullable columns:

    - ``sample_id`` (string)
    - ``text`` (string)
    - ``entropies`` (list<float16>)

    On success an empty ``<output_file>.complete`` marker is created. On any
    failure the partial output is deleted and the exception is re-raised.

    Args:
        input_file: JSONL file to read; resolved through ``get_fs``.
        output_file: Destination Arrow IPC file; resolved through ``get_fs``.
        patching_device: Device the entropy model is loaded on and run with.
        log_step: Print a progress line every ``log_step`` documents
            (disabled when <= 0).
        entropy_model_checkpoint_dir: Passed to ``load_entropy_model``.
        entropy_model_state_dict_path: Passed to ``load_entropy_model``.
        bpe_tokenizer_path: BPE model used to build the "blt" tokenizer.
        dry_run: Return right after creating the input iterator. The iterator
            is a lazy generator, so the input file is not actually read.
        s3_profile: Forwarded to ``get_fs`` for both input and output.
        patching_batch_size: Batch size forwarded to ``calculate_entropies``.
        arrow_batch_size: Number of documents per Arrow record batch.
    """
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )

    print("Creating patcher")
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()

    step = 0
    print("starting")
    start_time = time.time()
    # Time spent computing entropies only (tokenization and Arrow writes excluded).
    patch_time = 0.0
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    output_fs = get_fs(output_file, s3_profile=s3_profile)

    complete_marker = f"{output_file}.complete"
    # Opening the output with "wb" truncates any previous result. A marker left
    # by an earlier run would no longer be truthful, and it would survive a
    # failed rerun (which deletes the output). So drop it before writing.
    if output_fs.exists(complete_marker):
        output_fs.rm(complete_marker)

    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                text_buffer = []
                entropies_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        # The timer stops after the numpy conversion, which
                        # presumably forces any pending device work on `scores`
                        # to finish. TODO confirm calculate_entropies' return type.
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        patch_time += time.time() - patch_start
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) >= arrow_batch_size:
                            _write_record_batch(
                                writer,
                                schema,
                                id_buffer,
                                text_buffer,
                                entropies_buffer,
                            )
                            id_buffer, text_buffer, entropies_buffer = [], [], []
                        step += 1
                        if log_step > 0 and step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if entropies_buffer:
                        # Flush the last, partially filled batch.
                        _write_record_batch(
                            writer,
                            schema,
                            id_buffer,
                            text_buffer,
                            entropies_buffer,
                        )
        # Reached only after the writer and sink contexts have exited cleanly.
        output_fs.touch(complete_marker)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a truncated file behind that
        # could be mistaken for a finished output. Re-raise the original error.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
    print("entropy time:", patch_time)
if __name__ == "__main__":
    # Script entry point: Typer builds the command-line interface from main()'s
    # signature and invokes it.
    typer.run(main)