Mirror of https://github.com/facebookresearch/blt.git (synced 2025-01-18 08:27:45 +00:00)

commit d718cfa9a1 (parent b0120da72f)
Update preprocess_entropies script to blt inference + add fsspec support

Summary:
Test Plan:

@@ -82,16 +82,16 @@ def calculate_entropies(
         if device is not None:
             split = split.to(device)
         assert torch.all(split >= 0) and torch.all(split < 260)
-        pred, _ = entropy_model(split)
+        pred = entropy_model(split)
         pred = pred.reshape(-1, pred.shape[-1])[
             : split.numel() - pad_size, :
         ]  # [batch_size * seq_len, vocab]
         pred_entropies = entropy(pred)
         entropies.append(pred_entropies)

-    entropies = torch.cat(entropies, dim=0)
-    entropies = entropies.reshape(tokens.shape)
-    return entropies
+    concat_entropies = torch.cat(entropies, dim=0)
+    concat_entropies = concat_entropies.reshape(tokens.shape)
+    return concat_entropies


 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
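
Note on this hunk: the entropy model evidently returns a single logits tensor now rather than a (logits, cache) tuple, hence "pred, _ =" becoming "pred =", and the final concatenation no longer shadows the entropies list it reads from. The entropy() helper itself is not shown in this diff; a minimal sketch of the standard computation it performs on [N, vocab] logits, as an illustration rather than the repo's exact code:

import torch
import torch.nn.functional as F

def entropy(logits: torch.Tensor) -> torch.Tensor:
    # Shannon entropy (in nats) of each row's softmax distribution,
    # computed stably via log_softmax: H = -sum_v p_v * log(p_v).
    log_probs = F.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)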

@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import time
 from pathlib import Path

+import fsspec
+import jsonlines
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn

-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.patcher import calculate_entropies
 from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find a id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader
+
+
 def main(
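
Note on the new imports above: get_fs (from bytelatent.data.file_util) is what lets the same script read local paths and s3:// URLs, and jsonl_file_iterator then streams documents through whatever filesystem it returns. get_fs's body is not part of this diff; a plausible stand-in for the dispatch it performs, with s3_profile passed through to s3fs (an assumption, not the repo's actual code):

import fsspec

def get_fs_sketch(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Route s3:// URLs to s3fs, optionally under a named AWS profile;
    # everything else falls back to the local filesystem. The real
    # helper may handle more schemes.
    if path.startswith("s3://"):
        return fsspec.filesystem("s3", profile=s3_profile)
    return fsspec.filesystem("file")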

@@ -16,39 +61,32 @@ def main(
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
-    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
     dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)

     if dry_run:
         return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir, device=patching_device
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
-    entropy_model, _ = to_device(entropy_model, patching_device)

     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-    tokenizer = Tokenizer(
-        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
     )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
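
Since the script presumably drives main through typer (imported above), the new keyword arguments become CLI flags with no extra wiring. A self-contained sketch of that pattern with a trimmed signature; the real entry-point registration is outside this diff:

import typer

def main(
    input_file: str,
    output_file: str,
    patching_device: str = "cuda",
    bpe_tokenizer_path: str = "public_data/tokenizer.model",
    dry_run: bool = False,
    s3_profile: str | None = None,
):
    # typer exposes this as: INPUT_FILE OUTPUT_FILE --patching-device ...
    # --bpe-tokenizer-path ... --dry-run --s3-profile ...
    print(input_file, output_file, patching_device, dry_run, s3_profile)

if __name__ == "__main__":
    typer.run(main)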

@@ -59,8 +97,10 @@ def main(
     schema = pa.schema([sample_id_field, text_field, entropy_field])
     arrow_batch_size = 1_000

+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
-        with pa.OSFile(output_file, "wb") as sink:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
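
Swapping pa.OSFile for output_fs.open is what makes the Arrow output backend-agnostic: pa.ipc.new_file only needs a writable binary file object, which fsspec supplies for local and s3 paths alike. A self-contained sketch of that write path with a toy two-column schema (the script's real schema carries sample_id, text, and entropies; the path below is a placeholder):

import fsspec
import pyarrow as pa

schema = pa.schema([("sample_id", pa.string()), ("entropies", pa.list_(pa.float32()))])
fs = fsspec.filesystem("file")  # an "s3" filesystem works identically here
with fs.open("/tmp/example.arrow", "wb") as sink:  # placeholder path
    with pa.ipc.new_file(sink, schema) as writer:
        batch = pa.record_batch(
            [pa.array(["doc-0"]), pa.array([[0.1, 2.3]], type=pa.list_(pa.float32()))],
            schema=schema,
        )
        writer.write_batch(batch)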

@@ -72,17 +112,9 @@ def main(
                 task = progress.add_task(
                     "[green]Calculating entropies...", total=None
                 )
-                for doc in iterator:
+                for doc in input_doc_iterator:
                     sample_id = get_id_from_doc(doc)
-
-                    if "text" in doc:
-                        text = doc["text"]
-                    elif "content" in doc:
-                        text = doc["content"]
-                    else:
-                        raise ValueError(
-                            f"Could not find a text key from: {doc.keys()}"
-                        )
+                    text = get_text(doc)
                     tokens = torch.tensor(tokenizer.encode(text))
                     patch_start = time.time()
                     scores = calculate_entropies(
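
tokenizer.encode here emits byte-level ids, which is what the assert in the first hunk (0 <= id < 260) is checking: 256 byte values plus a handful of special tokens. A toy illustration of that id layout, with assumed offsets rather than the repo's actual tokenizer:

class ToyByteTokenizer:
    # Assumed special-token ids and byte offset; the real values may differ.
    BOS_ID, EOS_ID, OFFSET = 1, 2, 4

    def encode(self, text: str) -> list[int]:
        # 256 possible byte values shifted past the special ids keeps all ids < 260.
        return [self.BOS_ID] + [b + self.OFFSET for b in text.encode("utf-8")] + [self.EOS_ID]

print(ToyByteTokenizer().encode("hi"))  # [1, 108, 109, 2]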

@@ -128,9 +160,10 @@ def main(
                         entropies_buffer = []
                         id_buffer = []
                         text_buffer = []
-        Path(f"{output_file}.complete").touch()
+        output_fs.touch(f"{output_file}.complete")
     except:
-        Path(output_file).unlink(missing_ok=True)
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)
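
Both the success marker and the failure cleanup move from pathlib onto the fsspec filesystem so they work for remote outputs too; the bare except still re-raises, so no error is swallowed. The pattern in isolation, with a placeholder local path: touch an empty ".complete" sentinel only after the file is fully written, and remove partial output on any failure so a downstream consumer never reads a truncated file:

import fsspec

fs = fsspec.filesystem("file")
out = "/tmp/entropies.arrow"  # placeholder
try:
    with fs.open(out, "wb") as f:
        f.write(b"...")  # stand-in for the Arrow IPC payload
    fs.touch(out + ".complete")  # success sentinel for downstream jobs to check
except Exception:
    if fs.exists(out):
        fs.rm(out)  # drop the partial file before re-raising
    raise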

@@ -21,3 +21,4 @@ submitit
 typer
 rich
+fsspec[full]
 orjson