Update preprocess_entropies script to blt inference + add fsspec support

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-01-13 23:26:26 +00:00
parent b0120da72f
commit d718cfa9a1
3 changed files with 74 additions and 40 deletions

View file

@ -82,16 +82,16 @@ def calculate_entropies(
if device is not None: if device is not None:
split = split.to(device) split = split.to(device)
assert torch.all(split >= 0) and torch.all(split < 260) assert torch.all(split >= 0) and torch.all(split < 260)
pred, _ = entropy_model(split) pred = entropy_model(split)
pred = pred.reshape(-1, pred.shape[-1])[ pred = pred.reshape(-1, pred.shape[-1])[
: split.numel() - pad_size, : : split.numel() - pad_size, :
] # [batch_size * seq_len, vocab] ] # [batch_size * seq_len, vocab]
pred_entropies = entropy(pred) pred_entropies = entropy(pred)
entropies.append(pred_entropies) entropies.append(pred_entropies)
entropies = torch.cat(entropies, dim=0) concat_entropies = torch.cat(entropies, dim=0)
entropies = entropies.reshape(tokens.shape) concat_entropies = concat_entropies.reshape(tokens.shape)
return entropies return concat_entropies
def patch_start_mask_from_entropy_with_monotonicity(entropies, t): def patch_start_mask_from_entropy_with_monotonicity(entropies, t):

View file

@ -1,14 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
import time import time
from pathlib import Path
import fsspec
import jsonlines
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
import torch import torch
import typer import typer
from rich.progress import Progress, TextColumn from rich.progress import Progress, TextColumn
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
def get_id_from_doc(doc: dict) -> int:
"""
We need a reliable way to ensure that samples from jsonl
and arrow are the same, but there is no unique id field,
so derive the best possible
"""
if "sample_id" in doc:
sample_id = doc["sample_id"]
elif "title" in doc:
sample_id = doc["title"]
elif "qid" in doc:
sample_id = doc["qid"]
elif "paper_id" in doc:
sample_id = doc["paper_id"]
elif "path" in doc:
sample_id = doc["path"]
elif "url" in doc:
sample_id = doc["url"]
elif "id" in doc:
sample_id = doc["id"]
else:
raise ValueError(f"Could not find a id key from: {doc.keys()}")
return str(sample_id)
def get_text(doc: dict):
if "text" in doc:
text = doc["text"]
elif "content" in doc:
text = doc["content"]
else:
raise ValueError(f"Could not find a text key from: {doc.keys()}")
return text
def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
with fs.open(path) as f:
reader = jsonlines.Reader(f)
yield from reader
def main( def main(
@ -16,39 +61,32 @@ def main(
output_file: str, output_file: str,
patching_device: str = "cuda", patching_device: str = "cuda",
log_step: int = 10_000, log_step: int = 10_000,
entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
bpe_tokenizer_path: str = "public_data/tokenizer.model",
dry_run: bool = False, dry_run: bool = False,
s3_profile: str | None = None,
): ):
# TODO: Modify this to work with the new code
raise NotImplementedError()
iterator = ArrowFileIterator(
file_path=input_file,
worker_id=0,
num_workers=1,
)
tokenization_mode = "bytes"
print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
print("Loading entropy model", entropy_model_checkpoint_dir) print("Loading entropy model", entropy_model_checkpoint_dir)
input_fs = get_fs(input_file, s3_profile=s3_profile)
input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
if dry_run: if dry_run:
return return
entropy_model = load_entropy_model( entropy_model = load_entropy_model(
entropy_model_checkpoint_dir, device=patching_device entropy_model_checkpoint_dir,
entropy_model_state_dict_path,
device=patching_device,
) )
entropy_model, _ = to_device(entropy_model, patching_device)
print("Creating patcher") print("Creating patcher")
patching_batch_size = 32 patching_batch_size = 32
print("Creating tokenizer") print("Creating tokenizer")
tokenizer = Tokenizer( tokenizer_args = TokenizerArgs(
model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
tokenization_mode=tokenization_mode,
# BYTE_UNITS
vocab_size_unit_1=256,
bos=True,
eos=True,
bpe_delim=False,
# This isn't used, just stores a reference for other calls we don't use
patcher=None,
) )
tokenizer = tokenizer_args.build()
step = 0 step = 0
print("starting") print("starting")
start_time = time.time() start_time = time.time()
@ -59,8 +97,10 @@ def main(
schema = pa.schema([sample_id_field, text_field, entropy_field]) schema = pa.schema([sample_id_field, text_field, entropy_field])
arrow_batch_size = 1_000 arrow_batch_size = 1_000
output_fs = get_fs(output_file, s3_profile=s3_profile)
try: try:
with pa.OSFile(output_file, "wb") as sink: with output_fs.open(output_file, "wb") as sink:
with pa.ipc.new_file(sink, schema) as writer: with pa.ipc.new_file(sink, schema) as writer:
id_buffer = [] id_buffer = []
entropies_buffer = [] entropies_buffer = []
@ -72,17 +112,9 @@ def main(
task = progress.add_task( task = progress.add_task(
"[green]Calculating entropies...", total=None "[green]Calculating entropies...", total=None
) )
for doc in iterator: for doc in input_doc_iterator:
sample_id = get_id_from_doc(doc) sample_id = get_id_from_doc(doc)
text = get_text(doc)
if "text" in doc:
text = doc["text"]
elif "content" in doc:
text = doc["content"]
else:
raise ValueError(
f"Could not find a text key from: {doc.keys()}"
)
tokens = torch.tensor(tokenizer.encode(text)) tokens = torch.tensor(tokenizer.encode(text))
patch_start = time.time() patch_start = time.time()
scores = calculate_entropies( scores = calculate_entropies(
@ -128,9 +160,10 @@ def main(
entropies_buffer = [] entropies_buffer = []
id_buffer = [] id_buffer = []
text_buffer = [] text_buffer = []
Path(f"{output_file}.complete").touch() output_fs.touch(f"{output_file}.complete")
except: except:
Path(output_file).unlink(missing_ok=True) if output_fs.exists(output_file):
output_fs.rm(output_file)
raise raise
elapsed = time.time() - start_time elapsed = time.time() - start_time
print("steps", step) print("steps", step)

View file

@ -21,3 +21,4 @@ submitit
typer typer
rich rich
fsspec[full] fsspec[full]
orjson