Get generation working for BLT

Summary:

Create a script for simple generation from BLT

Test Plan:

```
python -m bytelatent.generate_blt config=../internal-blt/configs/eval_blt.yaml
```
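
For context, the CLI above builds an `EvalArgs` and calls `launch_generate` (see `bytelatent/generate_blt.py` below). A minimal programmatic sketch of the same flow; all paths and prompts are placeholders, not the contents of `eval_blt.yaml`:

```python
# Sketch only: a rough programmatic equivalent of the CLI test plan.
# Paths and prompts below are placeholders, not values from eval_blt.yaml.
from bytelatent.args import EvalArgs
from bytelatent.generate_blt import launch_generate

eval_args = EvalArgs(
    ckpt_dir="/path/to/blt/consolidated_checkpoint",        # placeholder
    entropy_ckpt_dir="/path/to/entropy_model_checkpoint",   # placeholder
    dump_dir="/tmp/blt_generate",                           # placeholder
    prompts=["My test prompt"],
)
launch_generate(eval_args)
```
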
Author: Pedro Rodriguez
Date:   2025-03-21 02:13:35 +00:00
Commit: 0c09a840b5 (parent 2dcf48bdd9)
6 changed files with 266 additions and 14 deletions

bytelatent/args.py:
```diff
@@ -7,7 +7,7 @@ import numpy as np
 import yaml
 from pydantic import BaseModel, ConfigDict
 
-from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, CheckpointArgs
 from bytelatent.data.data_types import Batch
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
@@ -270,8 +270,11 @@ class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     dump_dir: str | None = None
     ckpt_dir: str | None = None
+    entropy_ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+
+    prompts: list[str] | None = None
 
     run_ppl: bool = True
     run_tasks: bool = False
@@ -284,6 +287,8 @@ class EvalArgs(BaseModel):
     global_step: int | None = None  # for in-training evaluation
     s3_profile: str | None = None
+    consolidate_if_needed: bool = False
+    consolidate_folder: str = CONSOLIDATE_FOLDER
 
 
 class TrainArgs(BaseModel):
```

bytelatent/data/patcher.py:
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import math
+import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
@@ -476,7 +477,11 @@ class Patcher:
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
             entropy_model = load_entropy_model(
-                patcher_args.entropy_model_checkpoint_dir
+                patcher_args.entropy_model_checkpoint_dir,
+                os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir,
+                    "consolidated/consolidated.pth",
+                ),
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
```

bytelatent/distributed.py:
```diff
@@ -162,6 +162,12 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_min(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MIN, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_sum(
     x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
 ):
```
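
The new `dist_min` mirrors the existing `dist_max`; `bytelatent/generate_blt.py` below uses both to agree on a single decoding range across data-parallel ranks, so every rank runs the same number of forward calls under FSDP. A minimal sketch of that usage, with illustrative `prompt_tokens` and `max_gen_len` values:

```python
# Sketch: reduce per-rank prompt-length bounds so all ranks loop over the same
# [start_pos, end_pos) range (mirrors get_generation_range in generate_blt.py).
import torch

from bytelatent.distributed import dist_max, dist_min

prompt_tokens = [[1, 2, 3], [1, 2, 3, 4, 5]]  # illustrative per-rank prompts
max_gen_len = 256                             # illustrative

local_min = min(len(t) for t in prompt_tokens)
local_max = max(len(t) for t in prompt_tokens)
if torch.distributed.is_initialized():
    local_min = int(dist_min(local_min))
    local_max = int(dist_max(local_max))
start_pos, end_pos = local_min, local_max + max_gen_len
```
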

bytelatent/eval.py:
```diff
@@ -243,9 +243,20 @@ def launch_eval(eval_args: EvalArgs):
     ):
         consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
-        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+        if eval_args.consolidate_if_needed:
+            logger.info(
+                "Found a model checkpoint, but it has not been consolidated.... so consolidating the checkpoint"
+            )
+            consolidate_path = os.path.join(
+                eval_args.ckpt_dir, eval_args.consolidate_folder
+            )
+            if not fs.exists(consolidate_path) and get_global_rank() == 0:
+                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+            logger.info("Model consolidated to: %s", consolidate_path)
+        else:
+            raise ValueError(
+                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
+            )
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
```
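
A hedged sketch of opting into the new gate: with `consolidate_if_needed=True`, `launch_eval` consolidates an unconsolidated training checkpoint into `consolidate_folder` instead of raising `ValueError`. Paths below are placeholders:

```python
# Sketch only: placeholder paths; consolidate_if_needed enables on-the-fly consolidation.
from bytelatent.args import EvalArgs
from bytelatent.eval import launch_eval

eval_args = EvalArgs(
    ckpt_dir="/checkpoints/my_run/0000020000",  # unconsolidated training checkpoint (placeholder)
    dump_dir="/tmp/blt_eval",                   # placeholder
    consolidate_if_needed=True,
)
launch_eval(eval_args)
```
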

bytelatent/generate.py:
```diff
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
-from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
+from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -18,8 +18,14 @@ from bytelatent.base_transformer import (
     lengths_to_local_ids,
     lengths_to_start_ids,
 )
-from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.checkpoint import (
+    CONSOLIDATE_FOLDER,
+    CONSOLIDATE_NAME,
+    consolidate_checkpoints,
+)
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import get_global_rank
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -411,15 +417,25 @@ def load_consolidated_model_and_tokenizer(
 def main():
     # Load CLI arguments (overrides) and combine with a YAML config
-    cfg = OmegaConf.from_cli()
-    gen_cfg = dataclass_from_dict(
-        PackedCausalTransformerGeneratorArgs, cfg, strict=False
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path
     )
-    print(cfg)
-    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
 
-    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     # Allow multiple prompts
     prompts = []
```

bytelatent/generate_blt.py (new file, 209 lines):

```python
import logging
import os

import torch

from bytelatent.args import EvalArgs
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import Patcher
from bytelatent.distributed import (
    DistributedArgs,
    dist_max,
    dist_min,
    dist_sum,
    get_device_mesh,
    setup_torch_distributed,
)
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

logger = logging.getLogger()


def get_max_length(input_tokens: list[list[int]] | None) -> int:
    # reduce max length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        max_length = 0
    else:
        max_length = max([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        max_length = int(dist_max(max_length))
    return max_length


def get_min_length(input_tokens: list[list[int]] | None) -> int:
    # reduce min length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        # TODO: Double check this change from int(1e9) is correct
        min_length = 0
    else:
        min_length = min([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        min_length = int(dist_min(min_length))
    return min_length


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = get_min_length(prompt_tokens)
    batch_max_prompt_length = get_max_length(prompt_tokens)
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


@torch.inference_mode()
def generate_nocache(
    prompts: list[str] | None,
    *,
    model: ByteLatentTransformer,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    remove_prompts: bool = True,
) -> list[list[int]]:
    assert (
        patcher.realtime_patching
    ), "generate_nocache requires patcher.realtime_patching=True"
    model.eval()
    if prompts is None:
        prompt_tokens = None
        n_truncated_prompts = 0
        total_truncated_prompts = 0
    else:
        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
        total_truncated_prompts = dist_sum(n_truncated_prompts)

        # Truncation
        prompt_tokens = [
            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
            for t in prompt_tokens
        ]

    if total_truncated_prompts > 0:
        logger.info(
            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
            f"length greater than max_prompt_len = {max_prompt_len}, "
            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
        )

    if prompt_tokens is None:
        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]

    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    batch_size = len(prompt_tokens)
    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

    # Copy inputs to tensor for generated tokens
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
    input_text_mask = tokens != tokenizer.pad_id

    for i, curr_pos in enumerate(range(start_pos, end_pos)):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

    if remove_prompts:
        generated_tokens = [
            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    else:
        generated_tokens = [
            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    return generated_tokens


def launch_generate(eval_args: EvalArgs):
    assert eval_args.dump_dir is not None
    assert eval_args.ckpt_dir is not None
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    world_mesh = get_device_mesh(distributed_args)
    dp_mesh = world_mesh["dp_replicate"]
    assert distributed_args.dp_shard == 1
    world_size = dp_mesh.size()
    world_rank = dp_mesh.get_local_rank()

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir")

    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir
    patcher = patcher_args.build()
    outputs = generate_nocache(
        eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(eval_args.prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


def main():
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_generate(eval_args)


if __name__ == "__main__":
    main()
```
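
For completeness, a hedged sketch of calling `generate_nocache` directly once a model, tokenizer, and realtime patcher have been loaded as in `launch_generate`; the sampling values are illustrative, not defaults from this commit:

```python
# Sketch: model, tokenizer, and patcher are assumed to be loaded as in launch_generate,
# with patcher.realtime_patching=True.
outputs = generate_nocache(
    ["Once upon a time"],
    model=model,
    tokenizer=tokenizer,
    patcher=patcher,
    max_gen_len=128,
    use_sampling=True,
    temp=0.8,
    top_p=0.95,
)
print(tokenizer.decode(outputs[0]))
```
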