# Copyright (c) Meta Platforms, Inc. and affiliates.
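"""
Preprocess a JSONL dataset into an Arrow IPC file of per-document entropies.

For each document, derive a stable sample id, tokenize the text, compute
token entropies with a trained entropy model, and write (sample_id, text,
entropies) rows to `output_file` in batches. An `<output_file>.complete`
marker is touched once the output has been fully written.
"""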
import time

import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn

from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_from_doc(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no single unique id field,
    so derive the best id we can from the fields that are present.
    """
    if "sample_id" in doc:
        sample_id = doc["sample_id"]
    elif "title" in doc:
        sample_id = doc["title"]
    elif "qid" in doc:
        sample_id = doc["qid"]
    elif "paper_id" in doc:
        sample_id = doc["paper_id"]
    elif "path" in doc:
        sample_id = doc["path"]
    elif "url" in doc:
        sample_id = doc["url"]
    elif "id" in doc:
        sample_id = doc["id"]
    else:
        raise ValueError(f"Could not find an id key from: {doc.keys()}")
    return str(sample_id)
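
# Illustrative example (hypothetical document, not from any real dataset):
# keys are checked in the priority order above, so a doc with both "title"
# and "url" resolves to its title:
#   get_id_from_doc({"title": "Example", "url": "https://example.com/x"}) == "Example"
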
def get_text(doc: dict) -> str:
    if "text" in doc:
        text = doc["text"]
    elif "content" in doc:
        text = doc["content"]
    else:
        raise ValueError(f"Could not find a text key from: {doc.keys()}")
    return text


def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    with fs.open(path) as f:
        reader = jsonlines.Reader(f)
        yield from reader


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)

    # Dry run: stop before loading the model or writing any output.
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )

    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()
    step = 0
    print("starting")
    start_time = time.time()
    patch_time = 0
    # One output row per document: per-token entropies (float16 to keep the
    # file small), the derived sample id, and the raw text.
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    arrow_batch_size = 1_000

    output_fs = get_fs(output_file, s3_profile=s3_profile)

    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        # Flush a record batch once enough rows have accumulated.
                        if len(entropies_buffer) == arrow_batch_size:
                            batch = pa.record_batch(
                                {
                                    "entropies": entropies_buffer,
                                    "sample_id": id_buffer,
                                    "text": text_buffer,
                                },
                                schema,
                            )
                            writer.write(batch)
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        patch_time += time.time() - patch_start
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                # Write any remaining rows that did not fill a full batch.
                if len(entropies_buffer) > 0:
                    batch = pa.record_batch(
                        {
                            "entropies": entropies_buffer,
                            "sample_id": id_buffer,
                            "text": text_buffer,
                        },
                        schema,
                    )
                    writer.write(batch)
                    entropies_buffer = []
                    id_buffer = []
                    text_buffer = []
        # Marker file signaling that the output was fully written.
        output_fs.touch(f"{output_file}.complete")
    except:
        # Remove a partially written output so reruns start clean, then re-raise.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)


if __name__ == "__main__":
    typer.run(main)
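
# Example invocation (illustrative paths; typer turns the parameters above into
# CLI options, e.g. entropy_model_checkpoint_dir -> --entropy-model-checkpoint-dir):
#   python preprocess_entropies.py input.jsonl output.arrow \
#       --entropy-model-checkpoint-dir public_data/entropy_checkpoint \
#       --patching-device cuda
#
# Minimal sketch of reading the output back, assuming a local .arrow path:
#   import pyarrow as pa
#   with pa.OSFile("output.arrow", "rb") as f:
#       table = pa.ipc.open_file(f).read_all()
#   print(table.column("sample_id")[0], table.column("entropies")[0])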