mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-19 08:57:45 +00:00
142 lines
5.2 KiB
Python
142 lines
5.2 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
import time
|
||
|
from pathlib import Path
|
||
|
|
||
|
import numpy as np
|
||
|
import pyarrow as pa
|
||
|
import torch
|
||
|
import typer
|
||
|
from rich.progress import Progress, TextColumn
|
||
|
|
||
|
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||
|
|
||
|
|
||
|
def _write_batch(writer, schema, sample_ids, texts, entropies):
    """Write one buffered group of documents to ``writer`` as a record batch.

    Args:
        writer: Open Arrow IPC file writer (created with ``pa.ipc.new_file``).
        schema: Schema with the ``sample_id``, ``text`` and ``entropies`` fields.
        sample_ids: Buffered sample ids, one per document.
        texts: Buffered document texts, aligned with ``sample_ids``.
        entropies: Buffered float16 entropy arrays, aligned with ``sample_ids``.
    """
    batch = pa.record_batch(
        {
            "entropies": entropies,
            "sample_id": sample_ids,
            "text": texts,
        },
        schema,
    )
    writer.write(batch)


def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    log_step: int = 10_000,
    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
    dry_run: bool = False,
):
    """Compute per-document entropies and store them in an Arrow IPC file.

    NOTE: this entry point is currently disabled. It raises
    ``NotImplementedError`` before doing any work; the code after the raise is
    the previous implementation, kept as the starting point for the port. It
    calls helpers that this file does not import (``load_entropy_model``,
    ``to_device``, ``Tokenizer``, ``get_id_from_doc``, ``calculate_entropies``),
    so the guard must stay until those are wired up again.

    Args:
        input_file: Path handed to ``ArrowFileIterator`` as ``file_path``.
        output_file: Path of the Arrow IPC file to write. On success an empty
            ``<output_file>.complete`` marker is created next to it; on any
            failure the partially written output file is removed.
        patching_device: Device passed to the entropy model loader and to
            ``calculate_entropies``.
        log_step: Print a progress line every ``log_step`` documents.
        entropy_model_checkpoint_dir: Directory the entropy model is loaded from.
        dry_run: If true, return after printing the configuration, before the
            model is loaded or any output is written.

    Raises:
        NotImplementedError: Always, until the function is ported.
        ValueError: (legacy path) if a document has neither a ``"text"`` nor a
            ``"content"`` key.
    """
    # TODO: Modify this to work with the new code
    raise NotImplementedError()

    # ------------------------------------------------------------------
    # Legacy implementation below: unreachable until the raise is removed.
    # ------------------------------------------------------------------
    iterator = ArrowFileIterator(
        file_path=input_file,
        worker_id=0,
        num_workers=1,
    )
    tokenization_mode = "bytes"
    print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
    print("Loading entropy model", entropy_model_checkpoint_dir)
    if dry_run:
        return
    entropy_model = load_entropy_model(
        entropy_model_checkpoint_dir, device=patching_device
    )
    entropy_model, _ = to_device(entropy_model, patching_device)
    print("Creating patcher")
    patching_batch_size = 32
    print("Creating tokenizer")
    tokenizer = Tokenizer(
        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
        tokenization_mode=tokenization_mode,
        # BYTE_UNITS
        vocab_size_unit_1=256,
        bos=True,
        eos=True,
        bpe_delim=False,
        # This isn't used, just stores a reference for other calls we don't use
        patcher=None,
    )
    step = 0
    print("starting")
    start_time = time.time()
    # Accumulated time spent per document from entropy computation onwards;
    # currently collected but not reported.
    patch_time = 0
    entropy_field = pa.field("entropies", pa.list_(pa.float16()), nullable=False)
    sample_id_field = pa.field("sample_id", pa.string(), nullable=False)
    text_field = pa.field("text", pa.string(), nullable=False)
    schema = pa.schema([sample_id_field, text_field, entropy_field])
    # Number of documents buffered in memory before a record batch is written.
    arrow_batch_size = 1_000

    try:
        with pa.OSFile(output_file, "wb") as sink:
            with pa.ipc.new_file(sink, schema) as writer:
                id_buffer = []
                entropies_buffer = []
                text_buffer = []
                with Progress(
                    *Progress.get_default_columns(),
                    TextColumn("Completed: {task.completed}"),
                ) as progress:
                    # total=None: the number of documents is not known upfront.
                    task = progress.add_task(
                        "[green]Calculating entropies...", total=None
                    )
                    for doc in iterator:
                        sample_id = get_id_from_doc(doc)

                        if "text" in doc:
                            text = doc["text"]
                        elif "content" in doc:
                            text = doc["content"]
                        else:
                            raise ValueError(
                                f"Could not find a text key from: {doc.keys()}"
                            )
                        tokens = torch.tensor(tokenizer.encode(text))
                        patch_start = time.time()
                        scores = calculate_entropies(
                            tokens,
                            entropy_model,
                            patching_batch_size,
                            patching_device,
                        )
                        entropies_buffer.append(
                            np.array(scores.tolist(), dtype=np.float16)
                        )
                        id_buffer.append(sample_id)
                        text_buffer.append(text)
                        if len(entropies_buffer) == arrow_batch_size:
                            _write_batch(
                                writer,
                                schema,
                                id_buffer,
                                text_buffer,
                                entropies_buffer,
                            )
                            entropies_buffer = []
                            id_buffer = []
                            text_buffer = []
                        patch_time += time.time() - patch_start
                        step += 1
                        if step % log_step == 0:
                            print("Completed steps:", step)
                        progress.update(task, advance=1)
                    if entropies_buffer:
                        # Write last things
                        _write_batch(
                            writer,
                            schema,
                            id_buffer,
                            text_buffer,
                            entropies_buffer,
                        )
                        entropies_buffer = []
                        id_buffer = []
                        text_buffer = []
        # Only mark the output complete once the writer and the sink have been
        # closed, i.e. the Arrow file has been fully finalized.
        Path(f"{output_file}.complete").touch()
    except BaseException:
        # Deliberately broad (covers KeyboardInterrupt as well): never leave a
        # truncated output file behind, then propagate the original error.
        Path(output_file).unlink(missing_ok=True)
        raise
    elapsed = time.time() - start_time
    print("steps", step)
    print("done in:", elapsed)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Script entry point: hand `main` to Typer so it runs as a command-line
    # program.
    typer.run(main)
|