diff --git a/bytelatent/args.py b/bytelatent/args.py
index 6927f1c..8ee5a67 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -7,7 +7,7 @@ import numpy as np
 import yaml
 from pydantic import BaseModel, ConfigDict
 
-from bytelatent.checkpoint import CheckpointArgs
+from bytelatent.checkpoint import CONSOLIDATE_FOLDER, CheckpointArgs
 from bytelatent.data.data_types import Batch
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
@@ -270,8 +270,11 @@ class EvalArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     dump_dir: str | None = None
     ckpt_dir: str | None = None
+    entropy_ckpt_dir: str | None = None
     metric_log_dir: str | None = None
 
+    prompts: list[str] | None = None
+
     run_ppl: bool = True
     run_tasks: bool = False
 
@@ -284,6 +287,8 @@ class EvalArgs(BaseModel):
 
     global_step: int | None = None  # for in-training evaluation
     s3_profile: str | None = None
+    consolidate_if_needed: bool = False
+    consolidate_folder: str = CONSOLIDATE_FOLDER
 
 
 class TrainArgs(BaseModel):
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index 328b716..3e6fe12 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import math
+import os
 import time
 from collections import defaultdict
 from contextlib import nullcontext
@@ -476,7 +477,11 @@ class Patcher:
                 patcher_args.entropy_model_checkpoint_dir is not None
             ), "Cannot require realtime patching without an entropy model checkpoint"
             entropy_model = load_entropy_model(
-                patcher_args.entropy_model_checkpoint_dir
+                patcher_args.entropy_model_checkpoint_dir,
+                os.path.join(
+                    patcher_args.entropy_model_checkpoint_dir,
+                    "consolidated/consolidated.pth",
+                ),
             )
             entropy_model, _ = to_device(entropy_model, patcher_args.patching_device)
             self.entropy_model = entropy_model
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 284c717..7c99380 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -162,6 +162,12 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor
 
 
+def dist_min(x: Union[int, float], mesh: DeviceMesh = None):
+    tensor = torch.tensor(x).cuda()
+    dist.all_reduce(tensor, op=ReduceOp.MIN, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_sum(
     x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
 ):
diff --git a/bytelatent/eval.py b/bytelatent/eval.py
index 61e4a2d..2401fc0 100644
--- a/bytelatent/eval.py
+++ b/bytelatent/eval.py
@@ -243,9 +243,20 @@ def launch_eval(eval_args: EvalArgs):
     ):
         consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
-        if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+        if eval_args.consolidate_if_needed:
+            logger.info(
+                "Found a model checkpoint, but it has not been consolidated.... so consolidating the checkpoint"
+            )
+            consolidate_path = os.path.join(
+                eval_args.ckpt_dir, eval_args.consolidate_folder
+            )
+            if not fs.exists(consolidate_path) and get_global_rank() == 0:
+                consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+            logger.info("Model consolidated to: %s", consolidate_path)
+        else:
+            raise ValueError(
+                "Did not find a consolidated checkpoint and consolidate_if_needed is False"
+            )
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index 9d44f30..c76360e 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm
 
-from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
+from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -18,8 +18,14 @@ from bytelatent.base_transformer import (
     lengths_to_local_ids,
     lengths_to_start_ids,
 )
-from bytelatent.checkpoint import CONSOLIDATE_NAME
+from bytelatent.checkpoint import (
+    CONSOLIDATE_FOLDER,
+    CONSOLIDATE_NAME,
+    consolidate_checkpoints,
+)
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.distributed import get_global_rank
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -411,15 +417,25 @@ def load_consolidated_model_and_tokenizer(
 
 def main():
     # Load CLI arguments (overrides) and combine with a YAML config
-    cfg = OmegaConf.from_cli()
-    gen_cfg = dataclass_from_dict(
-        PackedCausalTransformerGeneratorArgs, cfg, strict=False
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path
     )
-    print(cfg)
 
-    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)
-
-    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
 
     # Allow multiple prompts
     prompts = []
diff --git a/bytelatent/generate_blt.py b/bytelatent/generate_blt.py
new file mode 100644
index 0000000..ace4ecd
--- /dev/null
+++ b/bytelatent/generate_blt.py
@@ -0,0 +1,209 @@
+import logging
+import os
+
+import torch
+
+from bytelatent.args import EvalArgs
+from bytelatent.config_parser import parse_args_to_pydantic_model
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import Patcher
+from bytelatent.distributed import (
+    DistributedArgs,
+    dist_max,
+    dist_min,
+    dist_sum,
+    get_device_mesh,
+    setup_torch_distributed,
+)
+from bytelatent.generate import load_consolidated_model_and_tokenizer
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
+
+logger = logging.getLogger()
+
+
+def get_max_length(input_tokens: list[list[int]] | None) -> int:
+    # reduce max length prompt over all processes to have an equal number of call on each process with fsdp
+    if input_tokens is None:
+        max_length = 0
+    else:
+        max_length = max([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        max_length = int(dist_max(max_length))
+    return max_length
+
+
+def get_min_length(input_tokens: list[list[int]] | None) -> int:
+    # reduce min length prompt over all processes to have an equal number of call on each process with fsdp
+    if input_tokens is None:
+        # TODO: Double check this change from int(1e9) is correct
+        min_length = 0
+    else:
+        min_length = min([len(t) for t in input_tokens])
+    if torch.distributed.is_initialized():
+        min_length = int(dist_min(min_length))
+    return min_length
+
+
+def get_generation_range(
+    prompt_tokens: list[list[int]] | None, max_gen_len: int
+) -> tuple[int, int]:
+    batch_min_prompt_length = get_min_length(prompt_tokens)
+    batch_max_prompt_length = get_max_length(prompt_tokens)
+    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len
+
+
+def sample_top_k(probs, k):
+    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
+    min_value_top_k = topk_value[:, [-1]]
+    probs[probs < min_value_top_k] = 0.0
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs, p):
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+@torch.inference_mode()
+def generate_nocache(
+    prompts: list[str] | None,
+    *,
+    model: ByteLatentTransformer,
+    tokenizer: BltTokenizer,
+    patcher: Patcher,
+    max_prompt_len: int = 256,
+    max_gen_len: int = 256,
+    use_sampling: bool = False,
+    temp: float = 1.0,
+    top_k: int = 0,
+    top_p: float = 0.0,
+    remove_prompts: bool = True,
+) -> list[list[int]]:
+    assert (
+        patcher.realtime_patching
+    ), "generate_nocache requires patcher.realtime_patching=True"
+    model.eval()
+    if prompts is None:
+        prompt_tokens = None
+        n_truncated_prompts = 0
+        total_truncated_prompts = 0
+    else:
+        prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
+        n_truncated_prompts = sum([max_prompt_len < len(t) for t in prompt_tokens])
+        total_truncated_prompts = dist_sum(n_truncated_prompts)
+
+        # Truncation
+        prompt_tokens = [
+            t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
+            for t in prompt_tokens
+        ]
+
+    if total_truncated_prompts > 0:
+        logger.info(
+            f"There are {total_truncated_prompts} prompts that are truncated on the left, "
+            f"length greater than max_prompt_len = {max_prompt_len}, "
+            f"maximum prompt length = {get_max_length(prompt_tokens)} across all gpus."
+        )
+
+    if prompt_tokens is None:
+        prompt_tokens = [[tokenizer.bos_id] for _ in range(end_pos)]
+
+    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
+    batch_size = len(prompt_tokens)
+    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()
+
+    # Copy inputs to tensor for generated tokens
+    for i, row_tokens in enumerate(prompt_tokens):
+        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
+    input_text_mask = tokens != tokenizer.pad_id
+
+    for i, curr_pos in enumerate(range(start_pos, end_pos)):
+        current_tokens = tokens[:, :curr_pos]
+        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
+        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]
+
+        if use_sampling:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = sample_top_p(probs, top_p)
+            elif top_k > 0:
+                next_token = sample_top_k(probs, top_k)
+            else:
+                next_token = torch.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1)
+
+        next_token = torch.where(
+            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
+        )
+        tokens[:, curr_pos] = next_token
+
+    if remove_prompts:
+        generated_tokens = [
+            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    else:
+        generated_tokens = [
+            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
+            for i, t in enumerate(tokens)
+        ]
+    return generated_tokens
+
+
+def launch_generate(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
+    if not torch.distributed.is_initialized():
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
+    if (
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
+    ):
+        consolidate_path = eval_args.ckpt_dir
+    else:
+        raise ValueError("Did not find a consolidated checkpoint in the ckpt_dir")
+
+    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
+        consolidate_path,
+    )
+    patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
+    patcher_args.realtime_patching = True
+    patcher_args.entropy_model_checkpoint_dir = eval_args.entropy_ckpt_dir
+    patcher = patcher_args.build()
+    outputs = generate_nocache(
+        eval_args.prompts, model=model, tokenizer=tokenizer, patcher=patcher
+    )
+    text_outputs = [tokenizer.decode(t) for t in outputs]
+    for p, t in zip(eval_args.prompts, text_outputs):
+        print(f'Prompt: "{p}" Completion: "{t}"')
+        print()
+
+
+def main():
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
+    launch_generate(eval_args)
+
+
+if __name__ == "__main__":
+    main()