blt/bytelatent/preprocess/preprocess_entropies.py
Pedro Rodriguez c6ef4285e2 Several changes to enable entropy model training/eval
Summary:

- Make the arrow iterator able to read from jsonl files; entropies are omitted in this case
- Make the data/checkpoint code fsspec compatible
- Fix issues with all-reduce on non-bf16 values in dist_sum and norm computation.
- Minimal fixes to get eval to run; it is currently slow
- Add bpb numbers during training


Test Plan:
2025-02-04 18:03:19 +00:00

183 lines
6.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import time
import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_key(doc: dict) -> str:
    """
    Return the name of the field to use as a document's identifier.

    We need a reliable way to ensure that samples from jsonl and arrow are
    the same, but there is no unique id field, so derive the best possible
    one: the first of ``sample_id``, ``title``, ``qid``, ``paper_id``,
    ``path``, ``url``, ``id`` that is present in ``doc`` is returned.

    Raises:
        ValueError: If none of the candidate keys is present.
    """
    # Ordered from most to least preferred; the first key present wins.
    for key in ("sample_id", "title", "qid", "paper_id", "path", "url", "id"):
        if key in doc:
            return key
    raise ValueError(f"Could not find a id key from: {doc.keys()}")
def get_id_from_doc(doc: dict) -> str:
    """
    Return the identifier of ``doc`` as a string.

    The identifying field is chosen by ``get_id_key``. Its value is passed
    through ``str()`` so that ids of any original type (e.g. ints) can be
    stored in a string-typed ``sample_id`` column.

    Raises:
        ValueError: If ``doc`` contains none of the recognised id keys.
    """
    return str(doc[get_id_key(doc)])
def get_text(doc: dict):
    """
    Return the body of ``doc``, taken from ``"text"`` or, failing that,
    from ``"content"``.

    Raises:
        ValueError: If neither key is present.
    """
    for text_key in ("text", "content"):
        if text_key in doc:
            return doc[text_key]
    raise ValueError(f"Could not find a text key from: {doc.keys()}")
def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    """
    Lazily yield each JSON record from the jsonl file at ``path``.

    The file is opened through ``fs`` only when iteration starts and is
    closed when the generator is exhausted or closed.
    """
    with fs.open(path) as handle:
        yield from jsonlines.Reader(handle)
def _write_batch(
    writer,
    schema: "pa.Schema",
    id_buffer: list[str],
    text_buffer: list[str],
    entropies_buffer: list[np.ndarray],
) -> None:
    """Build one record batch from parallel per-document buffers and write it."""
    batch = pa.record_batch(
        {
            "sample_id": id_buffer,
            "text": text_buffer,
            "entropies": entropies_buffer,
        },
        schema=schema,
    )
    writer.write(batch)


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    """
    Compute per-token entropies for every document in a jsonl file.

    Documents are read from ``input_file``, tokenized with the ``"blt"``
    tokenizer, and scored with the entropy model. The results are written
    as an Arrow IPC file to ``output_file`` with the columns ``sample_id``
    (str), ``text`` (str) and ``entropies`` (list of float16). Rows are
    written in batches of 1000.

    On success an empty ``{output_file}.complete`` marker is created. On
    any failure the partial output file is deleted and the exception is
    re-raised.

    Args:
        input_file: Path/URI of the input jsonl file.
        output_file: Path/URI of the Arrow IPC file to write.
        patching_device: Device the entropy model is loaded on and run on.
        log_step: Print a progress line every ``log_step`` documents;
            ``0`` or negative disables these lines.
        entropy_model_checkpoint_dir: Entropy model checkpoint directory.
        entropy_model_state_dict_path: Entropy model state-dict path.
        bpe_tokenizer_path: BPE model passed to the tokenizer.
        dry_run: If True, return after setting up the input iterator,
            without loading the model or writing any output.
        s3_profile: Optional S3 profile forwarded to ``get_fs``.
    """
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    # Generator: the input file is not opened until iteration begins.
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
    if dry_run:
        return
    print("Loading entropy model", entropy_model_checkpoint_dir)
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )
    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()

    schema = pa.schema(
        [
            pa.field("sample_id", pa.string(), nullable=False),
            pa.field("text", pa.string(), nullable=False),
            pa.field("entropies", pa.list_(pa.float16()), nullable=False),
        ]
    )
    arrow_batch_size = 1_000
    output_fs = get_fs(output_file, s3_profile=s3_profile)
    complete_marker = f"{output_file}.complete"
    # A marker left by an earlier run would otherwise keep claiming that
    # output_file is complete while we truncate/rewrite it (or after we
    # delete it on failure). Done outside the try so that a failure here
    # cannot trigger deletion of the previous output.
    if output_fs.exists(complete_marker):
        output_fs.rm(complete_marker)

    step = 0
    patch_time = 0.0
    print("starting")
    start_time = time.time()
    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                text_buffer = []
                entropies_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores, _ = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        # Time only the entropy computation itself.
                        patch_time += time.time() - patch_start
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) == arrow_batch_size:
                            _write_batch(
                                writer,
                                schema,
                                id_buffer,
                                text_buffer,
                                entropies_buffer,
                            )
                            id_buffer, text_buffer, entropies_buffer = [], [], []
                        step += 1
                        if log_step > 0 and step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                if entropies_buffer:
                    # Flush the final, partially filled batch.
                    _write_batch(
                        writer, schema, id_buffer, text_buffer, entropies_buffer
                    )
        # Mark completion only after both the Arrow writer and the sink have
        # been closed, so the marker never points at a half-written file.
        output_fs.touch(complete_marker)
    except BaseException:
        # Deliberately broad (also covers KeyboardInterrupt): never leave a
        # partial output file behind. The exception is always re-raised.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
    print("patching time:", patch_time)


if __name__ == "__main__":
    typer.run(main)