Mirror of https://github.com/facebookresearch/blt.git, synced 2025-01-18 16:37:46 +00:00
Update preprocess_entropies script to blt inference + add fsspec support
Summary:

Test Plan:
parent: b0120da72f
commit: d718cfa9a1
bytelatent/data/patcher.py

@@ -82,16 +82,16 @@ def calculate_entropies(
             if device is not None:
                 split = split.to(device)
             assert torch.all(split >= 0) and torch.all(split < 260)
-            pred, _ = entropy_model(split)
+            pred = entropy_model(split)
             pred = pred.reshape(-1, pred.shape[-1])[
                 : split.numel() - pad_size, :
             ]  # [batch_size * seq_len, vocab]
             pred_entropies = entropy(pred)
             entropies.append(pred_entropies)
 
-        entropies = torch.cat(entropies, dim=0)
-        entropies = entropies.reshape(tokens.shape)
-    return entropies
+        concat_entropies = torch.cat(entropies, dim=0)
+        concat_entropies = concat_entropies.reshape(tokens.shape)
+    return concat_entropies
 
 
 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
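This hunk switches the entropy model call to take the logits directly (the old forward returned a tuple) and renames the accumulator so the per-split list is no longer shadowed. For reference, the entropy(pred) helper used here reduces those next-byte logits to one entropy value per position; the sketch below shows that computation for logits of shape [N, vocab]. It is illustrative only, not the repository's exact implementation.

import torch
import torch.nn.functional as F

def entropy(logits: torch.Tensor) -> torch.Tensor:
    # Entropy of the predicted next-byte distribution at each position:
    # H = -sum_v p(v) * log p(v), computed from log-softmax for stability.
    log_probs = F.log_softmax(logits, dim=-1)            # [N, vocab]
    return -(log_probs.exp() * log_probs).sum(dim=-1)    # [N]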
preprocess_entropies.py

@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import time
-from pathlib import Path
 
+import fsspec
+import jsonlines
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn
 
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import calculate_entropies
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
 
 
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find a id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader
+
+
 def main(
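The new helpers replace the ArrowFileIterator input path with a plain JSONL reader that works over any fsspec filesystem, and centralize the id/text key fallbacks. A hedged usage sketch, assuming the helpers added above are available in the script's namespace and a hypothetical local file docs.jsonl:

import fsspec

fs = fsspec.filesystem("file")  # stand-in for get_fs(path, s3_profile=...)
for doc in jsonl_file_iterator(fs, "docs.jsonl"):
    sample_id = get_id_from_doc(doc)  # falls back across sample_id/title/qid/... keys
    text = get_text(doc)              # accepts either a "text" or "content" field
    print(sample_id, len(text))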
@@ -16,39 +61,32 @@ def main(
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
-    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
     dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
 
     if dry_run:
         return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir, device=patching_device
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
-    entropy_model, _ = to_device(entropy_model, patching_device)
     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-    tokenizer = Tokenizer(
-        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
     )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
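main() now resolves its input path through get_fs from bytelatent.data.file_util, with an optional s3_profile for S3 URIs. The real helper is not shown in this diff; the sketch below is only the behavior it is assumed to provide (scheme-based dispatch to an fsspec filesystem, with the profile forwarded to s3fs), under a hypothetical name so it is not mistaken for the library function.

import fsspec

def get_fs_sketch(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    # Hypothetical stand-in for bytelatent.data.file_util.get_fs.
    if path.startswith("s3://"):
        # s3fs accepts a named AWS credentials profile via `profile`.
        return fsspec.filesystem("s3", profile=s3_profile)
    return fsspec.filesystem("file")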
@@ -59,8 +97,10 @@ def main(
     schema = pa.schema([sample_id_field, text_field, entropy_field])
     arrow_batch_size = 1_000
 
+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
-        with pa.OSFile(output_file, "wb") as sink:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
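Writing through output_fs.open instead of pa.OSFile lets the Arrow IPC file land on local disk or S3 alike while keeping the same pa.ipc.new_file writer. A short sketch of reading one of these shards back for inspection, with a hypothetical local output path:

import pyarrow as pa

with open("entropies_shard.arrow", "rb") as source:  # hypothetical output path
    table = pa.ipc.open_file(source).read_all()
print(table.schema)    # the sample_id / text / entropy fields defined above
print(table.num_rows)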
@@ -72,17 +112,9 @@ def main(
                     task = progress.add_task(
                         "[green]Calculating entropies...", total=None
                     )
-                    for doc in iterator:
+                    for doc in input_doc_iterator:
                         sample_id = get_id_from_doc(doc)
-                        if "text" in doc:
-                            text = doc["text"]
-                        elif "content" in doc:
-                            text = doc["content"]
-                        else:
-                            raise ValueError(
-                                f"Could not find a text key from: {doc.keys()}"
-                            )
+                        text = get_text(doc)
                         tokens = torch.tensor(tokenizer.encode(text))
                         patch_start = time.time()
                         scores = calculate_entropies(
@@ -128,9 +160,10 @@ def main(
                     entropies_buffer = []
                     id_buffer = []
                     text_buffer = []
-        Path(f"{output_file}.complete").touch()
+        output_fs.touch(f"{output_file}.complete")
     except:
-        Path(output_file).unlink(missing_ok=True)
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)
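The completion marker and the failure cleanup now go through the fsspec filesystem as well, so partial shards are removed even on S3 and the ".complete" sentinel only appears after the writer closes cleanly. A sketch of how a caller might check the marker before reprocessing a shard (the function name is illustrative):

import fsspec

def shard_is_complete(fs: fsspec.AbstractFileSystem, output_file: str) -> bool:
    # The marker is touched only after the Arrow file is fully written,
    # so its presence means the shard can be skipped.
    return fs.exists(f"{output_file}.complete")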
requirements.txt

@@ -21,3 +21,4 @@ submitit
 typer
 rich
 fsspec[full]
+orjson