From d718cfa9a105b37cd8c20ed47081f5d7bc4f0f9d Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Mon, 13 Jan 2025 23:26:26 +0000 Subject: [PATCH] Update preprocess_entropies script to blt inference + add fsspec support Summary: Test Plan: --- bytelatent/data/patcher.py | 8 +- bytelatent/preprocess/preprocess_entropies.py | 105 ++++++++++++------ requirements.txt | 1 + 3 files changed, 74 insertions(+), 40 deletions(-) diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py index ede8b06..f8477a3 100644 --- a/bytelatent/data/patcher.py +++ b/bytelatent/data/patcher.py @@ -82,16 +82,16 @@ def calculate_entropies( if device is not None: split = split.to(device) assert torch.all(split >= 0) and torch.all(split < 260) - pred, _ = entropy_model(split) + pred = entropy_model(split) pred = pred.reshape(-1, pred.shape[-1])[ : split.numel() - pad_size, : ] # [batch_size * seq_len, vocab] pred_entropies = entropy(pred) entropies.append(pred_entropies) - entropies = torch.cat(entropies, dim=0) - entropies = entropies.reshape(tokens.shape) - return entropies + concat_entropies = torch.cat(entropies, dim=0) + concat_entropies = concat_entropies.reshape(tokens.shape) + return concat_entropies def patch_start_mask_from_entropy_with_monotonicity(entropies, t): diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 20d1e0c..1c19a5a 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -1,14 +1,59 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. import time -from pathlib import Path +import fsspec +import jsonlines import numpy as np import pyarrow as pa import torch import typer from rich.progress import Progress, TextColumn -from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator +from bytelatent.data.file_util import get_fs +from bytelatent.data.patcher import calculate_entropies +from bytelatent.entropy_model import load_entropy_model +from bytelatent.tokenizers.build_tokenizer import TokenizerArgs + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + if "sample_id" in doc: + sample_id = doc["sample_id"] + elif "title" in doc: + sample_id = doc["title"] + elif "qid" in doc: + sample_id = doc["qid"] + elif "paper_id" in doc: + sample_id = doc["paper_id"] + elif "path" in doc: + sample_id = doc["path"] + elif "url" in doc: + sample_id = doc["url"] + elif "id" in doc: + sample_id = doc["id"] + else: + raise ValueError(f"Could not find a id key from: {doc.keys()}") + return str(sample_id) + + +def get_text(doc: dict): + if "text" in doc: + text = doc["text"] + elif "content" in doc: + text = doc["content"] + else: + raise ValueError(f"Could not find a text key from: {doc.keys()}") + return text + + +def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str): + with fs.open(path) as f: + reader = jsonlines.Reader(f) + yield from reader def main( @@ -16,39 +61,32 @@ def main( output_file: str, patching_device: str = "cuda", log_step: int = 10_000, - entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir", + entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint", + entropy_model_state_dict_path: str = "public_data/entropy_model.pth", + bpe_tokenizer_path: str = "public_data/tokenizer.model", dry_run: bool = False, + s3_profile: str | None = None, ): - # TODO: Modify this to work with the new code - raise NotImplementedError() - iterator = ArrowFileIterator( - file_path=input_file, - worker_id=0, - num_workers=1, - ) - tokenization_mode = "bytes" print(f"Preprocessing entropies, input: {input_file}, output: {output_file}") print("Loading entropy model", entropy_model_checkpoint_dir) + input_fs = get_fs(input_file, s3_profile=s3_profile) + input_doc_iterator = jsonl_file_iterator(input_fs, input_file) + if dry_run: return entropy_model = load_entropy_model( - entropy_model_checkpoint_dir, device=patching_device + entropy_model_checkpoint_dir, + entropy_model_state_dict_path, + device=patching_device, ) - entropy_model, _ = to_device(entropy_model, patching_device) + print("Creating patcher") patching_batch_size = 32 print("Creating tokenizer") - tokenizer = Tokenizer( - model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model", - tokenization_mode=tokenization_mode, - # BYTE_UNITS - vocab_size_unit_1=256, - bos=True, - eos=True, - bpe_delim=False, - # This isn't used, just stores a reference for other calls we don't use - patcher=None, + tokenizer_args = TokenizerArgs( + name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path} ) + tokenizer = tokenizer_args.build() step = 0 print("starting") start_time = time.time() @@ -59,8 +97,10 @@ def main( schema = pa.schema([sample_id_field, text_field, entropy_field]) arrow_batch_size = 1_000 + output_fs = get_fs(output_file, s3_profile=s3_profile) + try: - with pa.OSFile(output_file, "wb") as sink: + with output_fs.open(output_file, "wb") as sink: with pa.ipc.new_file(sink, schema) as writer: id_buffer = [] entropies_buffer = [] @@ -72,17 +112,9 @@ def main( task = progress.add_task( "[green]Calculating entropies...", total=None ) - for doc in iterator: + for doc in input_doc_iterator: sample_id = get_id_from_doc(doc) - - if "text" in doc: - text = doc["text"] - elif "content" in doc: - text = doc["content"] - else: - raise ValueError( - f"Could not find a text key from: {doc.keys()}" - ) + text = get_text(doc) tokens = torch.tensor(tokenizer.encode(text)) patch_start = time.time() scores = calculate_entropies( @@ -128,9 +160,10 @@ def main( entropies_buffer = [] id_buffer = [] text_buffer = [] - Path(f"{output_file}.complete").touch() + output_fs.touch(f"{output_file}.complete") except: - Path(output_file).unlink(missing_ok=True) + if output_fs.exists(output_file): + output_fs.rm(output_file) raise elapsed = time.time() - start_time print("steps", step) diff --git a/requirements.txt b/requirements.txt index 59192cd..8490556 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,3 +21,4 @@ submitit typer rich fsspec[full] +orjson