# demo_hf.py
import json
import os

import torch
import typer
from huggingface_hub import hf_hub_download

from blt_one_file import ByteLatentTransformer, ByteLatentTransformerArgs
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

# generate_blt_consolidated.py
import logging

from blt_one_file import Patcher

logger = logging.getLogger()


def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = min([len(t) for t in prompt_tokens])
    batch_max_prompt_length = max([len(t) for t in prompt_tokens])
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len


def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


@torch.inference_mode()
def generate(
    prompts: list[str] | None,
    *,
    model: ByteLatentTransformer,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    remove_prompts: bool = True,
) -> list[list[int]]:
    assert (
        patcher.realtime_patching
    ), "generate_nocache requires patcher.realtime_patching=True"
    model.eval()

    prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
    # Truncation: keep only the last max_prompt_len tokens of each prompt
    prompt_tokens = [
        t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
        for t in prompt_tokens
    ]
    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    batch_size = len(prompt_tokens)
    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

    # Copy inputs to tensor for generated tokens
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
    input_text_mask = tokens != tokenizer.pad_id

    for i, curr_pos in enumerate(range(start_pos, end_pos)):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            # Sampling returns shape (batch, 1); flatten to (batch,) to match the greedy path
            next_token = next_token.reshape(-1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

    if remove_prompts:
        generated_tokens = [
            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    else:
        generated_tokens = [
            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    return generated_tokens


def main(prompt: str = "my name is", model_name: str = "blt-1b"):
    # Distributed setup is not needed for this single-GPU demo:
    # distributed_args = DistributedArgs()
    # distributed_args.configure_world()
    # if not torch.distributed.is_initialized():
    #     setup_torch_distributed(distributed_args)

    # Set device and ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but not available")
    device = torch.device("cuda")
    torch.cuda.empty_cache()  # Clear any existing CUDA memory

    assert model_name in ["blt-1b", "blt-7b"]
    # Hugging Face repo that hosts the weights, e.g. facebook/blt-1b
    blt_repo = f"facebook/{model_name}"

    # Get the model's default configuration and entropy model params
    print("Loading model configuration...")
    config_path = hf_hub_download(repo_id=blt_repo, filename="config.json")
    entropy_params_path = hf_hub_download(
        repo_id=blt_repo, filename="entropy_model/params.json"
    )
    with open(config_path, "r") as f:
        config = json.load(f)
    with open(entropy_params_path, "r") as f:
        entropy_params = json.load(f)

    # Create model args from config
    model_args = ByteLatentTransformerArgs(**config["args"])

    # Update patch parameters from entropy model params
    patcher_args = entropy_params["data"]["patcher_args"]
    model_args.patch_in_forward = True
    model_args.patch_size = patcher_args["patch_size"]
    model_args.patching_mode = patcher_args["patching_mode"]
    model_args.patching_threshold = patcher_args["threshold"]
    model_args.patching_threshold_add = patcher_args["threshold_add"]
    model_args.max_patch_length = patcher_args["max_patch_length"]
    model_args.patching_batch_size = patcher_args["patching_batch_size"]
    model_args.patching_device = patcher_args["patching_device"]
    model_args.monotonicity = patcher_args["monotonicity"]

    # Load the model with updated arguments
    print("Loading model with updated arguments...")
    model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device)

    # Configure the model's patcher; the entropy model checkpoint is expected to
    # have been downloaded locally to hf-weights/entropy_model
    model.patcher.realtime_patching = True
    model.patcher.entropy_model_checkpoint_dir = os.path.join(
        "hf-weights", "entropy_model"
    )

    # Create tokenizer
    tokenizer = BltTokenizer(
        vocab_size_unit_1=model_args.vocab_size, add_bos=True, add_eos=True
    )

    # Generate text
    print("Generating text...")
    outputs = generate(
        [prompt],
        model=model,
        tokenizer=tokenizer,
        patcher=model.patcher,  # Use the model's patcher
        max_gen_len=100,
    )

    # Decode and print results
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip([prompt], text_outputs):
        print(f'Prompt: "{p}"')
        print(f'Completion: "{t}"')
        print()

    # Clean up
    torch.cuda.empty_cache()


if __name__ == "__main__":
    typer.run(main)
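# One-time setup and usage (a sketch, not part of the original script): the patcher
# above looks for the entropy model checkpoint under hf-weights/entropy_model, which
# can be fetched with huggingface_hub before running the demo, e.g.:
#
#   from huggingface_hub import snapshot_download
#   snapshot_download(
#       repo_id="facebook/blt-1b",
#       local_dir="hf-weights",
#       allow_patterns=["entropy_model/*"],
#   )
#
# Example run (typer exposes the arguments of main() as CLI options):
#
#   python demo_hf.py --prompt "my name is" --model-name blt-1b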