blt/bytelatent/preprocess/preprocess_entropies.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
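"""Precompute per-token entropies for an Arrow dataset.

Reads documents from an input Arrow file, scores each document's byte
tokens with a pretrained entropy model, and writes (sample_id, text,
entropies) rows to an output Arrow IPC file, touching a
"<output_file>.complete" marker on success. Currently unimplemented:
main() raises NotImplementedError pending a port to the new code (see
the TODO below).
"""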
import time
from pathlib import Path

import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn

from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator

# NOTE: the two names below are used in main() but were not imported in this
# snapshot; these module paths are assumptions based on the rest of the
# bytelatent codebase. get_id_from_doc, to_device, and Tokenizer are also
# referenced below but could not be resolved from this file alone, so they
# are left unimported.
from bytelatent.data.patcher import calculate_entropies  # assumed location
from bytelatent.entropy_model import load_entropy_model  # assumed location


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
    dry_run: bool = False,
):
    # TODO: Modify this to work with the new code
    # NOTE: everything below this raise is unreachable until the TODO is
    # resolved; it is kept as a reference implementation.
    raise NotImplementedError()
    # Stream documents out of the input Arrow file.
    iterator = ArrowFileIterator(
        file_path=input_file,
        worker_id=0,
        num_workers=1,
    )
    tokenization_mode = "bytes"
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir, device=patching_device
    )
    entropy_model, _ = to_device(entropy_model, patching_device)
print("Creating patcher")
patching_batch_size = 32
print("Creating tokenizer")
tokenizer = Tokenizer(
model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
tokenization_mode=tokenization_mode,
# BYTE_UNITS
vocab_size_unit_1=256,
bos=True,
eos=True,
bpe_delim=False,
# This isn't used, just stores a reference for other calls we don't use
patcher=None,
)
    step = 0
    print("starting")
    start_time = time.time()
    patch_time = 0
    # Output schema: one row per document with its id, raw text, and the
    # per-token entropies stored as a float16 list.
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    arrow_batch_size = 1_000
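    # The writer below produces an Arrow IPC file. A minimal reader sketch
    # for consuming the finished output, assuming only pyarrow, would be:
    #
    #     with pa.OSFile(output_file, "rb") as source:
    #         table = pa.ipc.open_file(source).read_all()
    #     entropies = table.column("entropies").to_pylist()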
    try:
        with pa.OSFile(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in iterator:
                        sample_id = get_id_from_doc(doc)
                        # Documents keep their text under either "text" or
                        # "content" depending on the source dataset.
                        if "text" in doc:
                            text = doc["text"]
                        elif "content" in doc:
                            text = doc["content"]
                        else:
                            raise ValueError(
                                f"Could not find a text key from: {doc.keys()}"
                            )
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        # Accumulate only the time spent computing entropies,
                        # not buffering or Arrow writes.
                        patch_time += time.time() - patch_start
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) == arrow_batch_size:
                            # Flush a full batch of rows to the Arrow file.
                            batch = pa.record_batch(
                                {
                                    "entropies": entropies_buffer,
                                    "sample_id": id_buffer,
                                    "text": text_buffer,
                                },
                                schema,
                            )
                            writer.write(batch)
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if len(entropies_buffer) > 0:
                        # Flush any remaining rows that didn't fill a batch.
                        batch = pa.record_batch(
                            {
                                "entropies": entropies_buffer,
                                "sample_id": id_buffer,
                                "text": text_buffer,
                            },
                            schema,
                        )
                        writer.write(batch)
                        entropies_buffer = []
                        id_buffer = []
                        text_buffer = []
        # Mark the output as complete only after both writers close cleanly.
        Path(f"{output_file}.complete").touch()
    except BaseException:
        # Delete the partial output on any failure, then re-raise.
        Path(output_file).unlink(missing_ok=True)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)


if __name__ == "__main__":
    typer.run(main)
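
# Usage sketch: typer builds the CLI from main()'s signature, so an invocation
# would look roughly like this (paths and checkpoint dir are placeholders):
#
#   python -m bytelatent.preprocess.preprocess_entropies \
#       input.arrow output.arrow \
#       --patching-device cuda \
#       --entropy-model-checkpoint-dir entropy_checkpoint_dir \
#       --dry-run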