# Copyright (c) Meta Platforms, Inc. and affiliates.

import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from lingua.args import dataclass_from_dict
from lingua.tokenizers.abstract_tokenizer import Tokenizer
from lingua.tokenizers.build_tokenizer import build_tokenizer
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm

from bytelatent.base_transformer import (
    Attention,
    causal_mask,
    generate_doc_mask_mod,
    lengths_to_local_ids,
    lengths_to_start_ids,
)
from bytelatent.checkpoint import CONSOLIDATE_NAME
from bytelatent.transformer import LMTransformer, LMTransformerArgs

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token

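# Illustrative note (not part of the original code): sample_top_p implements
# nucleus (top-p) sampling. It keeps the smallest set of highest-probability
# tokens whose cumulative probability reaches p and samples among them, e.g.:
#   probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]]); p = 0.8
#   cumulative sums are [0.5, 0.8, 0.95, 1.0], so only the 0.05 token is
#   masked out and sampling happens over {0.5, 0.3, 0.15} (torch.multinomial
#   renormalizes the remaining mass internally).
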
def sample_top_k(probs, k):
    topk_value, _ = torch.topk(probs, k)  # batch_sz x topk
    min_value_top_k = topk_value[:, [-1]]
    probs[probs < min_value_top_k] = 0.0
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

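# Illustrative note (not part of the original code): sample_top_k zeroes out,
# in place, every probability below the k-th largest one and renormalizes
# before sampling. With probs = [[0.5, 0.3, 0.15, 0.05]] and k = 2, only
# 0.5 and 0.3 survive and become [0.625, 0.375] after renormalization.
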
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
    shape = logits.shape
    logits = logits.flatten(end_dim=-2)
    if temperature > 0.0:
        probs = torch.softmax(logits / temperature, dim=-1)

        if top_p is not None:
            next_token = sample_top_p(probs, top_p)
        elif top_k is not None:
            next_token = sample_top_k(probs, top_k)
        else:
            next_token = torch.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1)
    return next_token.view(shape[:-1])

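# Illustrative usage sketch (not part of the original code), assuming logits of
# shape (batch, seq_len, vocab):
#   greedy  = sample_tokens(logits)                                # temperature 0.0 -> argmax
#   sampled = sample_tokens(logits, temperature=0.8, top_p=0.95)   # nucleus sampling
# Both calls return token ids of shape (batch, seq_len).
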
def pack_prompts(prompts: List[List[int]]):
    res = []
    lengths = []
    for i, p in enumerate(prompts):
        p = torch.tensor(p, dtype=torch.long)
        l = p.size(0)
        res.append(p)
        lengths.append(l)
    lengths = torch.tensor(lengths, dtype=torch.long)
    res = torch.cat(res)
    return res, lengths

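# Illustrative example (not part of the original code):
#   pack_prompts([[1, 2, 3], [4, 5]])
# returns the packed tensor tensor([1, 2, 3, 4, 5]) together with the per-prompt
# lengths tensor([3, 2]), which downstream code uses to build document masks.
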
def batch_prompts(prompts, max_elements, lengths=None):
    batches = []
    current_batch = []
    current_count = 0

    for i in range(len(prompts)):
        prt = prompts[i]
        prompt_size = len(prt) if lengths is None else lengths[i]
        if current_count + prompt_size <= max_elements:
            current_batch.append(prt)
            current_count += prompt_size
        else:
            if current_batch:  # Add the current batch to batches
                batches.append(current_batch)
            # Start a new batch with the current prompt
            current_batch = [prt]
            current_count = prompt_size

    # Add the last batch if it contains any prompts
    if current_batch:
        batches.append(current_batch)

    return batches

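# Illustrative example (not part of the original code): batch_prompts greedily
# groups prompts so that each batch holds at most max_elements tokens. With
# prompt lengths [4, 3, 5] and max_elements=8, the first two prompts share a
# batch (4 + 3 <= 8) and the third starts a new one: [[p0, p1], [p2]]. A single
# prompt longer than max_elements still ends up alone in its own batch.
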
class KVCache(nn.Module):
    def __init__(self, bsz, seqlen, n_heads, head_dim, dtype, device):
        super().__init__()
        shape = (bsz, seqlen, n_heads, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype, device=device))
        self.offset = 0

    def reset(self):
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.offset = 0

    def update(self, k_val, v_val, tok_idx):
        # input_pos: [B], k_val: [B, S, H, D]
        self.k_cache.index_copy_(1, self.offset + tok_idx, k_val)
        self.v_cache.index_copy_(1, self.offset + tok_idx, v_val)
        return self.k_cache, self.v_cache

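# Illustrative note (not part of the original code): KVCache.update writes the
# new keys/values at absolute cache positions offset + tok_idx along the
# sequence dimension. offset is a scalar after __init__/reset, but the
# generator below overwrites it with a tensor of per-token (prefilling) or
# per-document (generation) start positions, so each packed sequence lands in
# its own slice of the fixed-size cache.
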
@dataclass
class PackedCausalTransformerGeneratorArgs:
    temperature: float = 0.0
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    max_gen_len: int = 512  # Maximum number of tokens to generate
    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
    max_prompt_len: Optional[int] = None
    until: List[str] = field(default_factory=list)
    compile_prefilling: bool = False
    reduce_generation_overhead: bool = False
    show_progress: bool = False
    dtype: Optional[str] = "bf16"
    device: Optional[str] = "cuda"

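# Illustrative usage sketch (not part of the original code): the arguments can
# be constructed directly or, as in main() below, from CLI overrides via
# dataclass_from_dict. For example:
#   gen_args = PackedCausalTransformerGeneratorArgs(
#       temperature=0.7, top_p=0.95, max_gen_len=256, max_tokens=4096
#   )
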
class PackedCausalTransformerGenerator:
    def __init__(
        self,
        cfg: PackedCausalTransformerGeneratorArgs,
        model: nn.Module,
        tokenizer: Tokenizer,
    ):
        """
        This class wraps a causal transformer model with its corresponding tokenizer
        and provides an efficient way to pack prompts together and do generation on
        the packed sequence.

        For example, given the prompts "Hello, I am a " and "Initiating calibration ",
        this class will concatenate those sequences (pack them together) into
        "Hello, I am a Initiating calibration "
        and build the necessary attention masks so that each sequence only attends
        to itself during prefilling and generation.

        This class creates a fixed-size cache of size max_tokens, or the sum of
        prompt sizes plus the maximum number of generated tokens per sequence.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.temperature = cfg.temperature
        self.top_p = cfg.top_p
        self.top_k = cfg.top_k

        self.max_gen_len = cfg.max_gen_len
        self.max_tokens = cfg.max_tokens
        self.max_prompt_len = cfg.max_prompt_len
        self.until = cfg.until
        self.max_until_size = max([len(e) for e in self.until]) if self.until else 1
        self.device = cfg.device

        # Compile if necessary
        self.prefill = torch.compile(self.prefill, disable=not cfg.compile_prefilling)
        self.generate_next_token = torch.compile(
            self.generate_next_token,
            mode="reduce-overhead",
            disable=not cfg.reduce_generation_overhead,
        )

        self.show_progress = cfg.show_progress
        self.dtype = dict(fp32=torch.float32, bf16=torch.bfloat16)[cfg.dtype]

        self.prefill_doc_id, self.prefill_tok_id = None, None
        self.padded_doc_id, self.padded_tok_id = None, None
        self.current_doc_id, self.current_tok_id = None, None
        self.padded_doc_start = None
        self.prefill_mask = None

    def clear_cache(self, offset):
        for module in self.model.modules():
            if isinstance(module, Attention):
                if not hasattr(module, "kv_cache"):
                    module.kv_cache = KVCache(
                        1,
                        self.max_tokens,
                        module.n_kv_heads,
                        module.head_dim,
                        self.dtype,
                        self.device,
                    )
                module.kv_cache.offset = offset

    @torch.compiler.disable
    def setup_prefilling(self, lengths: torch.Tensor):
        # The KV cache is a fixed size tensor of size max_tokens that we need
        # to update in order to do correct autoregressive generation.

        # Here we will generate token by token but on multiple sequences
        # at once. To do so, we need to have an attention mask that makes
        # each sequence independent.

        # Each sequence will write to its allocated space in the KV Cache.
        # We allocate len(seq) + max_gen_len to each sequence in the cache.

        # We will generate max_gen_len for each document
        padded_lengths = lengths + self.max_gen_len
        max_tokens = self.max_tokens or padded_lengths.sum().item()
        # The last document might have more padding to fill up to max_tokens
        padded_lengths[-1] += max_tokens - padded_lengths.sum()

        # This is the start index in the cache for each document
        self.padded_doc_start = lengths_to_start_ids(padded_lengths)
        # For example with ab--123--cdef--
        # this would be 0, 4, 9 if max_gen_len is 2

        # We repeat interleave to align with tokens for prefilling
        # Ex: the packed prefill tokens ab123cdef
        #     get the offsets           004449999
        prefill_offset = torch.repeat_interleave(self.padded_doc_start, lengths)
        # This offset will make sure the tokens are written to the
        # correct positions in the cache during prefilling

        # We either init the cache or clear it by resetting the offset to prefill_offset
        self.clear_cache(prefill_offset)

        # The prefilling mask looks like the following for
        # the two packed sequences ab and 123 : ab123
        # where dashes are empty cache positions
        #                 keys
        #                 ab---123---
        #   queries  a    10000000000
        #            b    11000000000
        #            1    00000100000
        #            2    00000110000
        #            3    00000111000
        # We make sure to skip the empty cache positions
        # and only attend to positions within the same sequence
        doc_mask_mod = generate_doc_mask_mod(causal_mask, lengths, padded_lengths)
        self.prefill_mask = create_block_mask(
            doc_mask_mod, 1, None, lengths.sum(), max_tokens
        )

        # This creates the prefilling token ids which look like
        # the following for the packed sequence abcdefg1234
        #   abcdefg1234
        #   01234560123
        # The token id gives us the position within each sequence
        # This is used to compute ROPE and to update the cache
        # At each forward pass the current tokens are written to
        # offset + tok_id
        self.prefill_doc_id, self.prefill_tok_id = lengths_to_local_ids(lengths)

        # This creates the padded token and document ids
        # which look like the following for the packed sequence ab123
        #                 ab---123---                  ab---123---
        #   padded_doc_id 00000111111   padded_tok_id  01234012345
        # This will later be useful for the attention mask at generation
        self.padded_doc_id, self.padded_tok_id = lengths_to_local_ids(padded_lengths)

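    # Illustrative example (not part of the original code) of the bookkeeping
    # above, matching the ab---123--- layout from the comments: prompts "ab"
    # and "123" (lengths [2, 3]) with max_gen_len=3 and max_tokens=11 give
    #   padded_lengths    -> [5, 6]
    #   padded_doc_start  -> [0, 5]
    #   prefill_offset    -> [0, 0, 5, 5, 5]  (one entry per prefill token ab123)
    #   prefill_tok_id    -> [0, 1, 0, 1, 2]
    #   padded_doc_id     -> [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    #   padded_tok_id     -> [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5]
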
    @torch.compiler.disable
    def setup_generation(self, lengths):
        # KV Cache offset is set to the start of the padded documents
        for module in self.model.modules():
            if isinstance(module, Attention):
                module.kv_cache.offset = self.padded_doc_start
        # The token ids during generation correspond to the lengths of each doc
        # current_tok_id will be incremented during generation
        self.current_tok_id = lengths.clone()
        # Since we're generating one token per document
        # the document id is just an arange
        self.current_doc_id = torch.arange(lengths.size(0), device=lengths.device)

    # From here on some methods for generation
    def prefill(self, tokens: torch.Tensor, lengths: torch.Tensor):
        # Prefilling is done by taking multiple packed sequences and
        # doing block diagonal attention on them so they remain independent
        self.setup_prefilling(lengths=lengths)
        prefill_out = self.model.forward(
            tokens,
            tok_idx=self.prefill_tok_id,
            mask=self.prefill_mask,
            attn_impl="flex_attention",
        )
        self.setup_generation(lengths=lengths)
        return prefill_out

    def generate_next_token(self, current_token):
        # Since we're doing generation with multiple sequences at once
        # we need to ignore tokens and cache entries from other sequences
        # or in the future.
        # Example mask :
        #                 keys
        #                 abc--1234--
        #   queries  c    11100000000
        #            4    00000111100

        # mask shape : (n_seqs, cache_size)
        doc_mask = self.current_doc_id.unsqueeze(1) == self.padded_doc_id.unsqueeze(0)
        caus_mask = self.current_tok_id.unsqueeze(1) >= self.padded_tok_id.unsqueeze(0)
        mask = doc_mask & caus_mask
        out = self.model.forward(
            current_token,
            tok_idx=self.current_tok_id,  # n_seqs
            mask=mask,
            attn_impl="sdpa",
        )
        self.current_tok_id += 1
        return out

    @torch.inference_mode()
    def generate(self, prompts):
        # Tokenize
        prompts = [
            self.tokenizer.encode(p, add_bos=True, add_eos=False) for p in prompts
        ]
        # Truncate
        max_seqlen = (
            self.max_tokens
            if not hasattr(self.model, "max_seqlen")
            else self.model.max_seqlen
        )
        max_prompt_len = self.max_prompt_len or min(
            max_seqlen - self.max_gen_len, self.max_tokens - self.max_gen_len
        )
        prompts = [p[-max_prompt_len:] for p in prompts]
        # Account for the generation in lengths
        padded_lengths = [len(p) + self.max_gen_len for p in prompts]
        generation = []
        loglikelihood = []
        greedy = []
        it = batch_prompts(prompts, self.max_tokens, lengths=padded_lengths)
        if self.show_progress:
            it = tqdm(it)
        for batch in it:
            n_seqs = len(batch)
            generated_tokens = [[] for _ in range(n_seqs)]
            is_done = [False for _ in range(n_seqs)]
            packed_batch, lengths = pack_prompts(batch)
            packed_batch, lengths = packed_batch.cuda(), lengths.cuda()
            n_seqs = lengths.size(0)

            # Prefilling cache
            prompt_logits = self.prefill(packed_batch.unsqueeze(0), lengths)
            # Selecting last token in each prompt
            all_tokens = sample_tokens(
                prompt_logits, self.temperature, self.top_p, self.top_k
            )
            start_token = all_tokens[:, lengths.cumsum(0) - 1]

            for seq_id, tok in enumerate(start_token.squeeze(0).tolist()):
                generated_tokens[seq_id].append(tok)

            current_token = start_token
            for i in range(1, self.max_gen_len):
                next_logits = self.generate_next_token(current_token)
                next_token = sample_tokens(
                    next_logits.clone(), self.temperature, self.top_p, self.top_k
                )

                for seq_id, tok in enumerate(next_token.squeeze(0).tolist()):
                    if not is_done[seq_id]:
                        generated_tokens[seq_id].append(tok)
                        current_end_str = self.tokenizer.decode(
                            generated_tokens[seq_id][-self.max_until_size :]
                        )
                        contains_end_string = any(
                            [e in current_end_str for e in self.until]
                        )
                        is_done[seq_id] = (
                            contains_end_string or tok == self.tokenizer.eos_id
                        )
                if all(is_done):
                    break

                current_token = next_token

            generation.extend([self.tokenizer.decode(g) for g in generated_tokens])

            for p, logit in zip(
                batch, prompt_logits.squeeze(0).split(lengths.tolist())
            ):
                x = logit[:-1]
                y = torch.tensor(p[1:], device=x.device)
                loglikelihood.append(-F.cross_entropy(x, y, reduction="none").cpu())
                greedy.append((x.argmax(dim=-1) == y).cpu())

        return generation, loglikelihood, greedy

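# Illustrative usage sketch (not part of the original code), assuming model,
# tokenizer and gen_args were obtained as in main() below:
#   generator = PackedCausalTransformerGenerator(gen_args, model, tokenizer)
#   generation, loglikelihood, greedy = generator.generate(
#       ["Hello, I am a ", "Initiating calibration "]
#   )
# generation holds the decoded continuations, loglikelihood the per-token log
# probabilities of the prompt tokens under the model, and greedy whether each
# prompt token matched the argmax prediction (useful for likelihood-based evals).
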
def load_consolidated_model_and_tokenizer(
    consolidated_path,
    model_cls=LMTransformer,
    model_args_cls=LMTransformerArgs,
):
    ckpt_path = Path(consolidated_path)
    config = ckpt_path / "params.json"
    config = OmegaConf.load(config)

    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
    model = model_cls(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"])
    model = model.cuda().eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    return model, tokenizer, config

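# Illustrative usage sketch (not part of the original code); the path is a
# placeholder for a consolidated checkpoint directory containing params.json
# and the consolidated weights:
#   model, tokenizer, config = load_consolidated_model_and_tokenizer(
#       "path/to/consolidated_checkpoint"
#   )
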
def main():
    # Load CLI arguments (overrides) and combine with a YAML config
    cfg = OmegaConf.from_cli()
    gen_cfg = dataclass_from_dict(
        PackedCausalTransformerGeneratorArgs, cfg, strict=False
    )
    print(cfg)

    model, tokenizer, _ = load_consolidated_model_and_tokenizer(cfg.ckpt)

    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)

    # Allow multiple prompts
    prompts = []
    while True:
        prompt = input("Enter a prompt (or press enter to finish): ")
        if not prompt:
            break
        prompts.append(prompt)

    # Start generation
    start_time = time.time()
    generation, loglikelihood, greedy = generator.generate(prompts)
    end_time = time.time()

    # Calculate tokens per second
    total_tokens = sum(len(tokenizer.encode(gen, False, False)) for gen in generation)
    tokens_per_second = total_tokens / (end_time - start_time)

    # Display the results
    for i, gen in enumerate(generation):
        print(f"\nPrompt {i+1}: {prompts[i]}")
        print(f"Generated Text: {gen}")

    print(f"\nTokens per second: {tokens_per_second:.2f}")

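# Illustrative invocation (not part of the original code): main() reads
# key=value overrides from the command line via OmegaConf.from_cli, so a run
# might look like (script name assumed):
#   python generate.py ckpt=path/to/consolidated temperature=0.7 top_p=0.95
# after which prompts are entered interactively.
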
if __name__ == "__main__":
    main()