mirror of https://github.com/facebookresearch/blt.git
synced 2025-09-06 12:39:04 +00:00
consolidated model file
parent 4ae7a62594
commit 4f86b6e7ab
2 changed files with 2829 additions and 0 deletions
blt_one_file.py (new file, 2618 lines; diff suppressed because it is too large)
demo_hf.py (new file, 211 lines)
@@ -0,0 +1,211 @@
# demo_hf.py

import json
import logging
import os

import torch
import typer
from huggingface_hub import hf_hub_download

from blt_one_file import ByteLatentTransformer, ByteLatentTransformerArgs, Patcher
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer

# Generation helpers below are adapted from generate_blt_consolidated.py.

logger = logging.getLogger()

def get_generation_range(
    prompt_tokens: list[list[int]] | None, max_gen_len: int
) -> tuple[int, int]:
    batch_min_prompt_length = min([len(t) for t in prompt_tokens])
    batch_max_prompt_length = max([len(t) for t in prompt_tokens])
    return batch_min_prompt_length, batch_max_prompt_length + max_gen_len

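For instance, with prompts of 3 and 5 tokens and max_gen_len=10, the helper returns (3, 15): decoding starts at the shortest prompt and the buffer is sized for the longest prompt plus the generation budget. A quick check (illustration only, not part of the diff):

assert get_generation_range([[1, 2, 3], [1, 2, 3, 4, 5]], max_gen_len=10) == (3, 15)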

def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token


def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

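Both samplers expect a batch of probability rows and return one sampled index per row as a (batch, 1) tensor; note that they renormalize their input in place, so pass a copy if you still need the original distribution. A minimal sketch on a toy distribution (illustration only, not part of the diff):

probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, 0.1]]), dim=-1)
tk = sample_top_k(probs.clone(), k=2)    # index drawn from the 2 most likely entries
tp = sample_top_p(probs.clone(), p=0.9)  # index drawn from the smallest set covering ~90% of the mass
print(tk.shape, tp.shape)                # torch.Size([1, 1]) torch.Size([1, 1])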

@torch.inference_mode()
def generate_nocache(
    prompts: list[str] | None,
    *,
    model: ByteLatentTransformer,
    tokenizer: BltTokenizer,
    patcher: Patcher,
    max_prompt_len: int = 256,
    max_gen_len: int = 256,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    remove_prompts: bool = True,
) -> list[list[int]]:
    assert (
        patcher.realtime_patching
    ), "generate_nocache requires patcher.realtime_patching=True"
    model.eval()
    prompt_tokens = [tokenizer.encode(t, add_eos=False) for t in prompts]
    # Truncate prompts that exceed max_prompt_len, keeping their tail
    prompt_tokens = [
        t if len(t) < max_prompt_len else t[len(t) - max_prompt_len :]
        for t in prompt_tokens
    ]
    start_pos, end_pos = get_generation_range(prompt_tokens, max_gen_len)
    batch_size = len(prompt_tokens)
    tokens = torch.full((batch_size, end_pos), tokenizer.pad_id).cuda().long()

    # Copy prompt tokens into the generation buffer
    for i, row_tokens in enumerate(prompt_tokens):
        tokens[i, : len(row_tokens)] = torch.tensor(row_tokens).long()
    input_text_mask = tokens != tokenizer.pad_id

    for curr_pos in range(start_pos, end_pos):
        current_tokens = tokens[:, :curr_pos]
        patch_lengths, _ = patcher.patch(current_tokens, include_next_token=True)
        logits = model(current_tokens, patch_lengths=patch_lengths)[:, -1]

        if use_sampling:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = sample_top_p(probs, top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, top_k)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            # Samplers return (batch, 1); flatten to match the greedy branch
            next_token = next_token.reshape(-1)
        else:
            next_token = torch.argmax(logits, dim=-1)

        # Keep prompt tokens untouched wherever they still cover this position
        next_token = torch.where(
            input_text_mask[:, curr_pos], tokens[:, curr_pos], next_token
        )
        tokens[:, curr_pos] = next_token

    if remove_prompts:
        generated_tokens = [
            t[len(prompt_tokens[i]) : len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    else:
        generated_tokens = [
            t[: len(prompt_tokens[i]) + max_gen_len].tolist()
            for i, t in enumerate(tokens)
        ]
    return generated_tokens

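main() below calls generate_nocache with its default greedy decoding (use_sampling=False). As a hedged sketch, not part of the diff, nucleus sampling would be enabled like this, assuming a model and tokenizer already set up as in main():

outputs = generate_nocache(
    ["my name is"],
    model=model,          # a loaded ByteLatentTransformer (see main())
    tokenizer=tokenizer,  # a BltTokenizer (see main())
    patcher=model.patcher,
    max_gen_len=64,
    use_sampling=True,
    temp=0.8,
    top_p=0.9,
)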

def main(prompt: str = "my name is", model_name: str = "blt-1b"):
    # distributed_args = DistributedArgs()
    # distributed_args.configure_world()
    # if not torch.distributed.is_initialized():
    #     setup_torch_distributed(distributed_args)

    # Set device and ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required but not available")
    device = torch.device("cuda")
    torch.cuda.empty_cache()  # Clear any existing CUDA memory

    assert model_name in ["blt-1b", "blt-7b"]

    # Hugging Face repo that hosts the consolidated weights and configs
    blt_repo = f"facebook/{model_name}"

    # Get the model's default configuration and entropy model params
    print("Loading model configuration...")
    config_path = hf_hub_download(repo_id=blt_repo, filename="config.json")
    entropy_params_path = hf_hub_download(
        repo_id=blt_repo, filename="entropy_model/params.json"
    )

    with open(config_path, "r") as f:
        config = json.load(f)
    with open(entropy_params_path, "r") as f:
        entropy_params = json.load(f)

    # Create model args from config
    model_args = ByteLatentTransformerArgs(**config["args"])

    # Update patch parameters from the entropy model params
    patcher_args = entropy_params["data"]["patcher_args"]
    model_args.patch_in_forward = True
    model_args.patch_size = patcher_args["patch_size"]
    model_args.patching_mode = patcher_args["patching_mode"]
    model_args.patching_threshold = patcher_args["threshold"]
    model_args.patching_threshold_add = patcher_args["threshold_add"]
    model_args.max_patch_length = patcher_args["max_patch_length"]
    model_args.patching_batch_size = patcher_args["patching_batch_size"]
    model_args.patching_device = patcher_args["patching_device"]
    model_args.monotonicity = patcher_args["monotonicity"]

    # Load the model with updated arguments
    print("Loading model with updated arguments...")
    model = ByteLatentTransformer.from_pretrained(blt_repo, args=model_args).to(device)

    # Configure the model's patcher; assumes the entropy model checkpoint has
    # already been downloaded locally to hf-weights/entropy_model
    model.patcher.realtime_patching = True
    model.patcher.entropy_model_checkpoint_dir = os.path.join(
        "hf-weights", "entropy_model"
    )

    # Create tokenizer
    tokenizer = BltTokenizer(
        vocab_size_unit_1=model_args.vocab_size,
        add_bos=True,
        add_eos=True,
    )

    # Generate text
    print("Generating text...")
    prompts = [prompt]
    outputs = generate_nocache(
        prompts,
        model=model,
        tokenizer=tokenizer,
        patcher=model.patcher,  # Use the model's patcher
        max_gen_len=100,
    )

    # Decode and print results
    text_outputs = [tokenizer.decode(t) for t in outputs]
    for p, t in zip(prompts, text_outputs):
        print(f'Prompt: "{p}"')
        print(f'Completion: "{t}"')
        print()

    # Clean up
    torch.cuda.empty_cache()


if __name__ == "__main__":
    typer.run(main)
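Since main() is exposed through typer, the demo can be run from the command line along these lines (flag spellings follow typer's default underscore-to-dash convention):

python demo_hf.py --prompt "my name is" --model-name blt-1b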