mirror of https://github.com/facebookresearch/blt.git
synced 2025-04-10 19:59:09 +00:00
Get generation working for BLT
Summary: Create a script for simple generation from BLT

Test Plan:

```
python -m bytelatent.generate_blt config=../internal-blt/configs/eval_blt.yaml
```
parent 2dcf48bdd9
commit 0c09a840b5
6 changed files with 266 additions and 14 deletions
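The `eval_blt.yaml` referenced in the test plan is internal; the fields the new script actually consumes are the `EvalArgs` ones touched by this commit (`ckpt_dir`, `entropy_ckpt_dir`, `dump_dir`, `prompts`). A minimal programmatic sketch of an equivalent run, assuming hypothetical checkpoint paths and a CUDA machine:

```python
# Sketch only: checkpoint paths and prompts are hypothetical, not part of this commit.
from bytelatent.args import EvalArgs
from bytelatent.generate_blt import launch_generate

eval_args = EvalArgs(
    ckpt_dir="/checkpoints/blt",                    # consolidated BLT checkpoint
    entropy_ckpt_dir="/checkpoints/entropy_model",  # entropy model for realtime patching
    dump_dir="/tmp/blt_generate",
    prompts=["The Byte Latent Transformer is"],
)
launch_generate(eval_args)
```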
```diff
@@ -7,7 +7,7 @@ import numpy as np
 import yaml
 from pydantic import BaseModel, ConfigDict
 
-from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, CheckpointArgs
 from bytelatent.data.data_types import Batch
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
@@ -270,8 +270,11 @@ class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     dump_dir: str | None = None
     ckpt_dir: str | None = None
+    entropy_ckpt_dir: str | None = None
     metric_log_dir: str | None = None
 
+    prompts: list[str] | None = None
+
     run_ppl: bool = True
     run_tasks: bool = False
 
@@ -284,6 +287,8 @@ class EvalArgs(BaseModel):
 
     global_step: int | None = None  # for in-training evaluation
     s3_profile: str | None = None
+    consolidate_if_needed: bool = False
+    consolidate_folder: str = CONSOLIDATE_FOLDER
 
 
 class TrainArgs(BaseModel):
```
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import math
+import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
@@ -476,7 +477,11 @@ class Patcher:
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
             entropy_model = load_entropy_model(
-                patcher_args.entropy_model_checkpoint_dir
+                patcher_args.entropy_model_checkpoint_dir,
+                os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir,
+                    "consolidated/consolidated.pth",
+                ),
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
```
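With this change, realtime patching loads the entropy model from a consolidated state dict inside the checkpoint directory rather than from the directory alone. A tiny sketch of the path it resolves (the root directory is hypothetical):

```python
# Sketch: the state-dict path Patcher now passes to load_entropy_model
# when realtime patching is enabled. The root directory is hypothetical.
import os

entropy_model_checkpoint_dir = "/checkpoints/entropy_model"
state_dict_path = os.path.join(
    entropy_model_checkpoint_dir, "consolidated/consolidated.pth"
)
print(state_dict_path)  # /checkpoints/entropy_model/consolidated/consolidated.pth
```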
```diff
@@ -162,6 +162,12 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_min(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MIN, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_sum(
     x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
 ):
```
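`dist_min` complements the existing `dist_max` so that, under FSDP, every rank can agree on the same generation range and therefore run the same number of forward passes. A small sketch of that pattern (the per-rank lengths are hypothetical; the reduction only runs when a process group is initialized):

```python
# Sketch: agreeing on a shared [min, max] prompt-length range across ranks,
# mirroring get_min_length/get_max_length in the new generate_blt.py below.
import torch
from bytelatent.distributed import dist_max, dist_min

local_prompt_lengths = [5, 12, 9]  # hypothetical lengths on this rank
lo, hi = min(local_prompt_lengths), max(local_prompt_lengths)
if torch.distributed.is_initialized():
    # Every rank must execute the same number of forward passes with FSDP,
    # so the range is reduced across the group.
    lo, hi = int(dist_min(lo)), int(dist_max(hi))
print(lo, hi)
```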
```diff
@@ -243,9 +243,20 @@ def launch_eval(eval_args: EvalArgs):
     ):
         consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
-        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+        if eval_args.consolidate_if_needed:
+            logger.info(
+                "Found a model checkpoint, but it has not been consolidated.... so consolidating the checkpoint"
+            )
+            consolidate_path = os.path.join(
+                eval_args.ckpt_dir, eval_args.consolidate_folder
+            )
+            if not fs.exists(consolidate_path) and get_global_rank() == 0:
+                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+            logger.info("Model consolidated to: %s", consolidate_path)
+        else:
+            raise ValueError(
+                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
+            )
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
```
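With the new flag, `launch_eval` only consolidates a raw training checkpoint when explicitly asked to; with the default `consolidate_if_needed=False` it now raises instead. A minimal sketch of opting in (paths hypothetical):

```python
# Sketch: the EvalArgs fields that opt into on-the-fly consolidation.
# The checkpoint and dump paths are hypothetical.
from bytelatent.args import EvalArgs

eval_args = EvalArgs(
    ckpt_dir="/checkpoints/blt/0000001000",  # raw, unconsolidated training checkpoint
    dump_dir="/tmp/blt_eval",
    consolidate_if_needed=True,  # consolidate into consolidate_folder when missing
    # consolidate_folder defaults to CONSOLIDATE_FOLDER
)
# Passing eval_args to launch_eval then consolidates on rank 0 instead of raising.
```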
```diff
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
-from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
+from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -18,8 +18,14 @@ from bytelatent.base_transformer import (
     lengths_to_local_ids,
     lengths_to_start_ids,
 )
-from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.checkpoint import (
+    CONSOLIDATE_FOLDER,
+    CONSOLIDATE_NAME,
+    consolidate_checkpoints,
+)
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import get_global_rank
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -411,15 +417,25 @@ def load_consolidated_model_and_tokenizer(
 
 def main():
     # Load CLI arguments (overrides) and combine with a YAML config
-    cfg = OmegaConf.from_cli()
-    gen_cfg = dataclass_from_dict(
-        PackedCausalTransformerGeneratorArgs, cfg, strict=False
-    )
-    print(cfg)
-
-    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
-
-    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path
+    )
+
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     # Allow multiple prompts
     prompts = []
```
bytelatent/generate_blt.py (new file, 209 lines)

```python
import logging
import os

import torch

from bytelatent.args import EvalArgs
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.patcher import Patcher
from bytelatent.distributed import (
    DistributedArgs,
    dist_max,
    dist_min,
    dist_sum,
    get_device_mesh,
    setup_torch_distributed,
)
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

logger = logging.getLogger()


def get_max_length(input_tokens: list[list[int]] | None) -> int:
    # reduce max length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        max_length = 0
    else:
        max_length = max([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        max_length = int(dist_max(max_length))
    return max_length


def get_min_length(input_tokens: list[list[int]] | None) -> int:
    # reduce min length prompt over all processes to have an equal number of call on each process with fsdp
    if input_tokens is None:
        # TODO: Double check this change from int(1e9) is correct
        min_length = 0
    else:
        min_length = min([len(t) for t in input_tokens])
    if torch.distributed.is_initialized():
        min_length = int(dist_min(min_length))
    return min_length


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = get_min_length(prompt_tokens)
    batch_max_prompt_length = get_max_length(prompt_tokens)
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


@torch.inference_mode()
def generate_nocache(
    prompts: list[str] | None,
    *,
    model: ByteLatentTransformer,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    remove_prompts: bool = True,
) -> list[list[int]]:
    assert (
        patcher.realtime_patching
    ), "generate_nocache requires patcher.realtime_patching=True"
    model.eval()
    if prompts is None:
        prompt_tokens = None
        n_truncated_prompts = 0
        total_truncated_prompts = 0
    else:
        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
        total_truncated_prompts = dist_sum(n_truncated_prompts)

        # Truncation
        prompt_tokens = [
            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
            for t in prompt_tokens
        ]

    if total_truncated_prompts > 0:
        logger.info(
            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
            f"length greater than max_prompt_len = {max_prompt_len}, "
            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
        )

    if prompt_tokens is None:
        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]

    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    batch_size = len(prompt_tokens)
    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

    # Copy inputs to tensor for generated tokens
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
    input_text_mask = tokens != tokenizer.pad_id

    for i, curr_pos in enumerate(range(start_pos, end_pos)):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

    if remove_prompts:
        generated_tokens = [
            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    else:
        generated_tokens = [
            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    return generated_tokens


def launch_generate(eval_args: EvalArgs):
    assert eval_args.dump_dir is not None
    assert eval_args.ckpt_dir is not None
    distributed_args = DistributedArgs()
    distributed_args.configure_world()
    if not torch.distributed.is_initialized():
        setup_torch_distributed(distributed_args)

    world_mesh = get_device_mesh(distributed_args)
    dp_mesh = world_mesh["dp_replicate"]
    assert distributed_args.dp_shard == 1
    world_size = dp_mesh.size()
    world_rank = dp_mesh.get_local_rank()

    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
    if (
        fs.exists(eval_args.ckpt_dir)
        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
    ):
        consolidate_path = eval_args.ckpt_dir
    else:
        raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir")

    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        consolidate_path,
    )
    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir
    patcher = patcher_args.build()
    outputs = generate_nocache(
        eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher
    )
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(eval_args.prompts, text_outputs):
        print(f'Prompt: "{p}" Completion: "{t}"')
        print()


def main():
    eval_args = parse_args_to_pydantic_model(EvalArgs)
    launch_generate(eval_args)


if __name__ == "__main__":
    main()
```
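Beyond the CLI entry point, `generate_nocache` can be driven directly once a process group is set up (it calls `dist_sum` and allocates CUDA tensors, so it needs the same distributed setup that `launch_generate` performs). A minimal sketch that mirrors `launch_generate`, with hypothetical checkpoint paths:

```python
# Sketch: driving generate_nocache directly with sampling enabled.
# Checkpoint paths are hypothetical; the setup mirrors launch_generate above.
import torch

from bytelatent.distributed import DistributedArgs, setup_torch_distributed
from bytelatent.generate import load_consolidated_model_and_tokenizer
from bytelatent.generate_blt import generate_nocache

distributed_args = DistributedArgs()
distributed_args.configure_world()
if not torch.distributed.is_initialized():
    setup_torch_distributed(distributed_args)

model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer("/checkpoints/blt")

# Build a realtime Patcher pointing at the entropy-model checkpoint.
patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
patcher_args.realtime_patching = True
patcher_args.entropy_model_checkpoint_dir = "/checkpoints/entropy_model"
patcher = patcher_args.build()

outputs = generate_nocache(
    ["The Byte Latent Transformer is"],
    model=model,
    tokenizer=tokenizer,
    patcher=patcher,
    max_gen_len=64,
    use_sampling=True,
    top_p=0.9,
)
print(tokenizer.decode(outputs[0]))
```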