diff --git a/bytelatent/args.py b/bytelatent/args.py
index 6927f1c..8ee5a67 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -7,7 +7,7 @@ import numpy as np
 import yaml
 from pydantic import BaseModel, ConfigDict
 
-from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, CheckpointArgs
 from bytelatent.data.data_types import Batch
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
@@ -270,8 +270,11 @@ class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     dump_dir: str | None = None
     ckpt_dir: str | None = None
+    entropy_ckpt_dir: str | None = None
     metric_log_dir: str | None = None
 
+    prompts: list[str] | None = None
+
     run_ppl: bool = True
     run_tasks: bool = False
 
@@ -284,6 +287,8 @@
 
     global_step: int | None = None  # for in-training evaluation
     s3_profile: str | None = None
+    consolidate_if_needed: bool = False
+    consolidate_folder: str = CONSOLIDATE_FOLDER
 
 
 class TrainArgs(BaseModel):
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 328b716..3e6fe12 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import math
+import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
@@ -476,7 +477,11 @@ class Patcher:
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
             entropy_model = load_entropy_model(
-                patcher_args.entropy_model_checkpoint_dir
+                patcher_args.entropy_model_checkpoint_dir,
+                os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir,
+                    "consolidated/consolidated.pth",
+                ),
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 284c717..7c99380 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -162,6 +162,12 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_min(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MIN, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_sum(
     x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
 ):
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 61e4a2d..2401fc0 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -243,9 +243,20 @@ def launch_eval(eval_args: EvalArgs):
     ):
         consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
-        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+        if eval_args.consolidate_if_needed:
+            logger.info(
+                "Found a model checkpoint, but it has not been consolidated; consolidating the checkpoint now"
+            )
+            consolidate_path = os.path.join(
+                eval_args.ckpt_dir, eval_args.consolidate_folder
+            )
+            if not fs.exists(consolidate_path) and get_global_rank() == 0:
+                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+            logger.info("Model consolidated to: %s", consolidate_path)
+        else:
+            raise ValueError(
+                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
+            )
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index 9d44f30..c76360e 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
-from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
+from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -18,8 +18,14 @@ from bytelatent.base_transformer import (
     lengths_to_local_ids,
     lengths_to_start_ids,
 )
-from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.checkpoint import (
+    CONSOLIDATE_FOLDER,
+    CONSOLIDATE_NAME,
+    consolidate_checkpoints,
+)
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import get_global_rank
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -411,15 +417,25 @@ def load_consolidated_model_and_tokenizer(
 
 
 def main():
     # Load CLI arguments (overrides) and combine with a YAML config
-    cfg = OmegaConf.from_cli()
-    gen_cfg = dataclass_from_dict(
-        PackedCausalTransformerGeneratorArgs, cfg, strict=False
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path
     )
-    print(cfg)
-    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
-
-    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     # Allow multiple prompts
     prompts = []
diff --git a/bytelatent/generate_blt.py b/bytelatent/generate_blt.py
new file mode 100644
index 0000000..ace4ecd
--- /dev/null
+++ b/bytelatent/generate_blt.py
@@ -0,0 +1,209 @@
+import logging
+import os
+
+import torch
+
+from bytelatent.args import EvalArgs
+from bytelatent.config_parser import parse_args_to_pydantic_model
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import Patcher
+from bytelatent.distributed import (
+    DistributedArgs,
+    dist_max,
+    dist_min,
+    dist_sum,
+    get_device_mesh,
+    setup_torch_distributed,
+)
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+logger = logging.getLogger()
+
+
+def get_max_length(input_tokens: list[list[int]] | None) -> int:
+    # Reduce the max prompt length over all processes so that each process makes an equal number of calls with FSDP
+    if input_tokens is None:
+        max_length = 0
+    else:
+        max_length = max([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        max_length = int(dist_max(max_length))
+    return max_length
+
+
+def get_min_length(input_tokens: list[list[int]] | None) -> int:
+    # Reduce the min prompt length over all processes so that each process makes an equal number of calls with FSDP
+    if input_tokens is None:
+        # TODO: Double-check that this change from int(1e9) is correct
+        min_length = 0
+    else:
+        min_length = min([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        min_length = int(dist_min(min_length))
+    return min_length
+
+
+def get_generation_range(
+    prompt_tokens: list[list[int]] | None, max_gen_len: int
+) -> tuple[int, int]:
+    batch_min_prompt_length = get_min_length(prompt_tokens)
+    batch_max_prompt_length = get_max_length(prompt_tokens)
+    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len
+
+
+def sample_top_k(probs, k):
+    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
+    min_value_top_k = topk_value[:, [-1]]
+    probs[probs < min_value_top_k] = 0.0
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+@torch.inference_mode()
+def generate_nocache(
+    prompts: list[str] | None,
+    *,
+    model: ByteLatentTransformer,
+    tokenizer: BltTokenizer,
+    patcher: Patcher,
+    max_prompt_len: int = 256,
+    max_gen_len: int = 256,
+    use_sampling: bool = False,
+    temp: float = 1.0,
+    top_k: int = 0,
+    top_p: float = 0.0,
+    remove_prompts: bool = True,
+) -> list[list[int]]:
+    assert (
+        patcher.realtime_patching
+    ), "generate_nocache requires patcher.realtime_patching=True"
+    model.eval()
+    if prompts is None:
+        prompt_tokens = None
+        n_truncated_prompts = 0
+        total_truncated_prompts = 0
+    else:
+        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
+        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
+        total_truncated_prompts = dist_sum(n_truncated_prompts)
+
+        # Truncation
+        prompt_tokens = [
+            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
+            for t in prompt_tokens
+        ]
+
+    if total_truncated_prompts > 0:
+        logger.info(
+            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
+            f"length greater than max_prompt_len = {max_prompt_len}, "
+            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
+        )
+
+    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
+    if prompt_tokens is None:
+        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]
+
+    batch_size = len(prompt_tokens)
+    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()
+
+    # Copy inputs to tensor for generated tokens
+    for i, row_tokens in enumerate(prompt_tokens):
+        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
+    input_text_mask = tokens != tokenizer.pad_id
+
+    for i, curr_pos in enumerate(range(start_pos, end_pos)):
+        current_tokens = tokens[:, :curr_pos]
+        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
+        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
+
+        if use_sampling:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = sample_top_p(probs, top_p)
+            elif top_k > 0:
+                next_token = sample_top_k(probs, top_k)
+            else:
+                next_token = torch.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1)
+
+        next_token = torch.where(
+            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
+        )
+        tokens[:, curr_pos] = next_token
+
+    if remove_prompts:
+        generated_tokens = [
+            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    else:
+        generated_tokens = [
+            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    return generated_tokens
+
+
+def launch_generate(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir")
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path,
+    )
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir
+    patcher = patcher_args.build()
+    outputs = generate_nocache(
+        eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher
+    )
+    text_outputs = [tokenizer.decode(t) for t in outputs]
+    for p, t in zip(eval_args.prompts, text_outputs):
+        print(f'Prompt: "{p}" Completion: "{t}"')
+        print()
+
+
+def main():
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+    launch_generate(eval_args)
+
+
+if __name__ == "__main__":
+    main()