# Summary: Currently, the arrow iterator can only read arrow files. However,
# the pyarrow library can read other formats, including jsonlines. This change
# allows the same ArrowIterator to read from jsonlines, so we can read from the
# original source data and simply omit the entropy column when doing so.
#
# Test Plan: Run the train script until the dataloader starts.
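#
# Illustrative sketch (not part of this script): pyarrow can parse
# newline-delimited JSON directly, which is what makes reading the original
# jsonl source data possible. The path below is hypothetical:
#   from pyarrow import json as pa_json
#   table = pa_json.read_json("data/sample.jsonl")
#   # `table` simply lacks the "entropies" column that this script adds.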
# Copyright (c) Meta Platforms, Inc. and affiliates.
import time

import fsspec
import jsonlines
import numpy as np
import pyarrow as pa
import torch
import typer
from rich.progress import Progress, TextColumn

from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


def get_id_key(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible one.
    """
    if "sample_id" in doc:
        return "sample_id"
    elif "title" in doc:
        return "title"
    elif "qid" in doc:
        return "qid"
    elif "paper_id" in doc:
        return "paper_id"
    elif "path" in doc:
        return "path"
    elif "url" in doc:
        return "url"
    elif "id" in doc:
        return "id"
    else:
        raise ValueError(f"Could not find an id key from: {doc.keys()}")


def get_id_from_doc(doc: dict) -> str:
    """
    We need a reliable way to ensure that samples from jsonl
    and arrow are the same, but there is no unique id field,
    so derive the best possible one.
    """
    return str(doc[get_id_key(doc)])
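

# For example (hypothetical doc): {"title": "On Entropy", "text": "..."} has no
# "sample_id" key, so get_id_key picks "title" and get_id_from_doc returns
# "On Entropy".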


def get_text(doc: dict) -> str:
    if "text" in doc:
        text = doc["text"]
    elif "content" in doc:
        text = doc["content"]
    else:
        raise ValueError(f"Could not find a text key from: {doc.keys()}")
    return text


def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
    with fs.open(path) as f:
        reader = jsonlines.Reader(f)
        yield from reader
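

# Usage sketch (path and profile are illustrative): the iterator works for
# local or S3-backed files alike, since get_fs returns an fsspec filesystem:
#   fs = get_fs("data/sample.jsonl", s3_profile=None)
#   for doc in jsonl_file_iterator(fs, "data/sample.jsonl"):
#       print(get_id_from_doc(doc))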


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    input_fs = get_fs(input_file, s3_profile=s3_profile)
    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)

    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir,
        entropy_model_state_dict_path,
        device=patching_device,
    )

    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer_args = TokenizerArgs(
        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
    )
    tokenizer = tokenizer_args.build()
    step = 0
    print("starting")
    start_time = time.time()
    patch_time = 0
    # Output schema: one row per document with its id, raw text, and
    # per-token entropies (stored as float16 to keep the file small).
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    arrow_batch_size = 1_000

    output_fs = get_fs(output_file, s3_profile=s3_profile)

    try:
        with output_fs.open(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in input_doc_iterator:
                        sample_id = get_id_from_doc(doc)
                        text = get_text(doc)
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores, _ = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) == arrow_batch_size:
                            batch = pa.record_batch(
                                {
                                    "entropies": entropies_buffer,
                                    "sample_id": id_buffer,
                                    "text": text_buffer,
                                },
                                schema,
                            )
                            writer.write(batch)
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        patch_time += time.time() - patch_start
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if len(entropies_buffer) > 0:
                        # Flush the final partial batch.
                        batch = pa.record_batch(
                            {
                                "entropies": entropies_buffer,
                                "sample_id": id_buffer,
                                "text": text_buffer,
                            },
                            schema,
                        )
                        writer.write(batch)
                        entropies_buffer = []
                        id_buffer = []
                        text_buffer = []
        # Marker file signaling that the output was written completely.
        output_fs.touch(f"{output_file}.complete")
    except:
        # On any failure, remove the partial output so downstream readers
        # never see a truncated arrow file.
        if output_fs.exists(output_file):
            output_fs.rm(output_file)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
    # patch_time is accumulated above; report it so the measurement isn't lost.
    print("patch time:", patch_time)


if __name__ == "__main__":
    typer.run(main)
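

# Usage sketch (script name and option values are illustrative; typer builds
# the CLI from main's signature, so options are the dash-case parameter names):
#   python preprocess_entropies.py input.jsonl output.arrow \
#       --entropy-model-checkpoint-dir public_data/entropy_checkpoint \
#       --bpe-tokenizer-path public_data/tokenizer.model \
#       --dry-run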