# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from dataclasses import replace

import numpy as np
import pytest
import torch

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import Batch
from bytelatent.data.ngram_processor import NgramProcessor
from bytelatent.model.blt import (
    ByteLatentTransformer,
    ByteLatentTransformerArgs,
    EmbeddingType,
    compute_hash_embeddings,
    create_global_transformer,
    create_local_decoder,
    create_local_encoder,
    cross_attn_mask,
    decoder_patch_ids_from_lengths,
    get_blt_input,
    init_embeddings,
    patch_ids_from_lengths,
)
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss


def batch_to_tensors_and_gpu(batch):
    x = torch.from_numpy(batch.x)
    y = torch.from_numpy(batch.y)
    mask = None if batch.mask is None else torch.from_numpy(batch.mask)
    patch_lengths = (
        None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
    )
    ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
        if mask is not None:
            mask = mask.cuda()
        if patch_lengths is not None:
            patch_lengths = patch_lengths.cuda()
        if ngram_ids is not None:
            ngram_ids = ngram_ids.cuda()
    return x, y, mask, patch_lengths, ngram_ids


def fake_batch():
    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
    del batch_dict["x2"]
    del batch_dict["y2"]
    del batch_dict["src_names"]
    return Batch(**batch_dict)


def create_args(cross_attention=False):
    transformer_args = ByteLatentTransformerArgs(
        # Base args provided
        n_heads=8,
        dim=512,
        vocab_size=260,
        # Additional args from command line
        dim_token=256,
        patch_size=6,
        tokenization_mode="bytes",
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        data_loader_patching=True,
        max_encoder_seq_length=12288,
        pad_to_max_length=True,
        encoder_lm_loss=False,
        patching_threshold=3.1439168453216553,
        encoder_hash_byte_group_size=[4],
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
        cross_attn_encoder=cross_attention,  # True,
        cross_attn_decoder=cross_attention,  # True,
        cross_attn_window_encoder=512,
        cross_attn_window_decoder=512,
        dim_local_encoder=256,
        dim_local_decoder=256,
        cross_attn_k=8,
        cross_attn_nheads=4,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=True,
        cross_attn_init_by_pooling=True,
        log_patch_lengths=True,
        non_linearity="swiglu",
        use_rope=True,
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        use_local_encoder_transformer=True,
        init_use_gaussian=True,
        init_use_depth="current",
        attn_bias_type="block_causal",
        attn_impl="xformers",
        alpha_depth="disabled",
        max_length=256,
        local_attention_window_len=512,
        max_seqlen=12288,
        downsampling_by_pooling="max",
        eos_id=EOS_ID,
    )
    return transformer_args
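
# A minimal CPU-only sketch of the patch bookkeeping the GPU tests below rely
# on. The exact id layout is an assumption inferred from the function's name
# and its use in test_local_encoder (each byte position gets the index of the
# patch containing it), so the assertions stay deliberately loose.
def test_patch_ids_from_lengths_sketch():
    patch_lengths = torch.tensor([[3, 2, 1]])  # one sequence, three patches
    patch_ids = patch_ids_from_lengths(patch_lengths, 6)  # 3 + 2 + 1 = 6 bytes
    assert patch_ids.shape == (1, 6)
    # Ids should be non-decreasing along the sequence and never exceed the
    # number of patches.
    assert (patch_ids[0].diff() >= 0).all()
    assert patch_ids.max().item() < patch_lengths.shape[1]
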
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestByteLatentTransformer:
    def test_local_encoder(self):
        args = create_args()
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)
        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
        reference_tokens = torch.load(reference_path).to(device)
        torch.testing.assert_close(
            local_encoder_tokens,
            reference_tokens,
            msg="Generated tokens don't match reference tokens",
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=None,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape

    def test_local_encoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)
        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)
        cross_attn_mask_enc = cross_attn_mask(
            patch_ids,
            patch_lengths,
            local_encoder_tokens.shape[-1],
            patches_as_queries=True,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_encoder,
            block_mask=True,
        )
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is not None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape
        assert h_cross.shape == (2, 2048, local_encoder.dim)

    def test_local_decoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)
        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_cross_embeds.pt",
        }
        batch = fake_batch()
        _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
        )
        cross_attn_mask_dec = cross_attn_mask(
            decoder_patch_ids,
            patch_lengths,
            tensors["decoder_tokens"].shape[-1],
            patches_as_queries=False,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_decoder,
            block_mask=True,
        )
        output, _ = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=cross_attn_mask_dec,
            cache=None,
        )
        assert output is not None
        assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)
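
    # Note on mask orientation in the two cross-attention tests above: with
    # patches_as_queries=True (encoder side) the patch representations act as
    # queries over byte positions, yielding cross_attn_k query rows per patch,
    # while patches_as_queries=False (decoder side) flips this so byte-level
    # queries attend to patch-level keys/values. This is an informal reading
    # of the calls above, not documentation from the repo.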
    def test_local_decoder(self):
        args = create_args()
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)
        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_embeds.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        output, cache_decoder = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=None,
            cache=None,
        )
        assert output is not None
        expected_shape = (
            tensors["decoder_tokens"].shape[0],
            tensors["decoder_tokens"].shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape
        assert cache_decoder is None

    def test_global_transformer(self):
        args = create_args()
        device = torch.device("cuda")
        global_transformer = create_global_transformer(args).to(device)
        test_files = {
            "global_embeds": "global_embeds.pt",
            "global_tokens": "global_tokens.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        h, cache = global_transformer(
            embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
        )
        assert h is not None
        assert h.shape == (2, 256, 512)
        assert cache is None

    def test_blt_transformer_init(self):
        args = create_args()
        model = ByteLatentTransformer(args)
        assert model is not None

    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
    def test_blt_transformer_forward(self, attn_impl):
        args = create_args()
        if attn_impl == "sdpa":
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
        else:
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
        args = args.model_copy(update=dict(attn_impl=attn_impl))
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_blt_transformer_cross_attn_forward(self):
        args = create_args(cross_attention=True)
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_cross_attention_rand(self):
        x = torch.randn(2, 256, 512, device="cuda")
        kv = torch.randn(2, 256, 512, device="cuda")
        cross_attention = CrossAttention(
            dim=512,
            head_dim=64,
            n_heads=8,
            n_kv_heads=4,
            norm_eps=1e-6,
        ).to("cuda")
        mask = create_causal_mask(
            x.shape[1], "flex_attention", None, sliding_window=None
        )
        output = cross_attention(x, kv, mask)
        assert output is not None
        assert output.shape == (2, 256, 512)
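
    # The n-gram test below builds ids on the CPU with NgramProcessor before
    # moving them to the GPU; the 2..8 n-gram orders in ngram_to_size mirror
    # the encoder_ngram_to_size_str passed to the model so the embedding
    # tables and the precomputed ids stay consistent.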
    def test_ngram_embeddings(self):
        ngram_to_size = {
            2: 38396,
            3: 50000,
            4: 50000,
            5: 50000,
            6: 50000,
            7: 50000,
            8: 50000,
        }
        batch = fake_batch()
        ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
        ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
        ngram_ids = np.stack(ngram_ids, axis=0)
        batch = replace(batch, ngram_ids=ngram_ids)
        args = create_args(cross_attention=True)
        args = args.model_copy(
            update=dict(
                encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
                encoder_enable_byte_ngrams=True,
                ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
            )
        )
        model = ByteLatentTransformer(args)
        model = model.cuda()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_loss_backward(self):
        args = create_args()
        args = args.model_copy(update=dict(attn_impl="xformers"))
        batch = fake_batch()
        model = ByteLatentTransformer(args)
        steps = 10
        optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
        model = model.cuda()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        initial_loss = None
        final_loss = None
        for step in range(steps):
            output = model(
                tokens=x,
                patch_lengths=patch_lengths,
                ngram_ids=ngram_ids,
            )
            loss, _ = compute_loss(output, y, mask, 1.0)
            if step == 0:
                initial_loss = loss.item()
            if step == steps - 1:
                final_loss = loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        assert (
            final_loss < initial_loss
        ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"
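

# Convenience entry point so the module can be run directly while debugging;
# the usual way to execute these tests remains `pytest` on this file.
if __name__ == "__main__":
    pytest.main([__file__, "-q"])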