# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
from dataclasses import replace

import numpy as np
import pytest
import torch

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import Batch
from bytelatent.data.ngram_processor import NgramProcessor
from bytelatent.model.blt import (
    ByteLatentTransformer,
    ByteLatentTransformerArgs,
    EmbeddingType,
    compute_hash_embeddings,
    create_global_transformer,
    create_local_decoder,
    create_local_encoder,
    cross_attn_mask,
    decoder_patch_ids_from_lengths,
    get_blt_input,
    init_embeddings,
    patch_ids_from_lengths,
)
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss


def batch_to_tensors_and_gpu(batch):
    x = torch.from_numpy(batch.x)
    y = torch.from_numpy(batch.y)
    mask = None if batch.mask is None else torch.from_numpy(batch.mask)
    patch_lengths = (
        None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
    )
    ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)

    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
        if mask is not None:
            mask = mask.cuda()
        if patch_lengths is not None:
            patch_lengths = patch_lengths.cuda()
        if ngram_ids is not None:
            ngram_ids = ngram_ids.cuda()
    return x, y, mask, patch_lengths, ngram_ids


def fake_batch():
    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
    del batch_dict["x2"]
    del batch_dict["y2"]
    del batch_dict["src_names"]
    return Batch(**batch_dict)


def create_args(cross_attention=False):
    transformer_args = ByteLatentTransformerArgs(
        # Base args provided
        n_heads=8,
        dim=512,
        vocab_size=260,
        # Additional args from command line
        dim_token=256,
        patch_size=6,
        tokenization_mode="bytes",
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        data_loader_patching=True,
        max_encoder_seq_length=12288,
        pad_to_max_length=True,
        encoder_lm_loss=False,
        patching_threshold=3.1439168453216553,
        encoder_hash_byte_group_size=[4],
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
        cross_attn_encoder=cross_attention,  # True,
        cross_attn_decoder=cross_attention,  # True,
        cross_attn_window_encoder=512,
        cross_attn_window_decoder=512,
        dim_local_encoder=256,
        dim_local_decoder=256,
        cross_attn_k=8,
        cross_attn_nheads=4,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=True,
        cross_attn_init_by_pooling=True,
        log_patch_lengths=True,
        non_linearity="swiglu",
        use_rope=True,
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        use_local_encoder_transformer=True,
        init_use_gaussian=True,
        init_use_depth="current",
        attn_bias_type="block_causal",
        attn_impl="xformers",
        alpha_depth="disabled",
        max_length=256,
        local_attention_window_len=512,
        max_seqlen=12288,
        downsampling_by_pooling="max",
        eos_id=EOS_ID,
    )
    return transformer_args


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
class TestByteLatentTransformer:
    def test_local_encoder(self):
        args = create_args()
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)

        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )

        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )

        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)

        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )

        reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
        reference_tokens = torch.load(reference_path).to(device)
        torch.testing.assert_close(
            local_encoder_tokens,
            reference_tokens,
            msg="Generated tokens don't match reference tokens",
        )

        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=None,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )

        assert h_encoder is not None
        assert h_cross is None
        assert cache_encoder is None

        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape

    def test_local_encoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)

        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )

        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )

        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)

        cross_attn_mask_enc = cross_attn_mask(
            patch_ids,
            patch_lengths,
            local_encoder_tokens.shape[-1],
            patches_as_queries=True,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_encoder,
            block_mask=True,
        )

        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is not None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape
        assert h_cross.shape == (2, 2048, local_encoder.dim)

    def test_local_decoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)

        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_cross_embeds.pt",
        }
        batch = fake_batch()
        _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)

        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
        )
        cross_attn_mask_dec = cross_attn_mask(
            decoder_patch_ids,
            patch_lengths,
            tensors["decoder_tokens"].shape[-1],
            patches_as_queries=False,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_decoder,
            block_mask=True,
        )
        output, _ = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=cross_attn_mask_dec,
            cache=None,
        )
        assert output is not None
        assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)

    def test_local_decoder(self):
        args = create_args()
        device = torch.device("cuda")

        local_decoder = create_local_decoder(args).to(device)

        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_embeds.pt",
        }

        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }

        output, cache_decoder = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=None,
            cache=None,
        )
        assert output is not None
        expected_shape = (
            tensors["decoder_tokens"].shape[0],
            tensors["decoder_tokens"].shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape
        assert cache_decoder is None

    def test_global_transformer(self):
        args = create_args()
        device = torch.device("cuda")
        global_transformer = create_global_transformer(args).to(device)

        test_files = {
            "global_embeds": "global_embeds.pt",
            "global_tokens": "global_tokens.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        h, cache = global_transformer(
            embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
        )
        h is not None
        assert h.shape == (2, 256, 512)
        assert cache is None

    def test_blt_transformer_init(self):
        args = create_args()
        model = ByteLatentTransformer(args)
        assert model is not None

    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
    def test_blt_transformer_forward(self, attn_impl):
        args = create_args()
        if attn_impl == "sdpa":
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
        else:
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"

        args = args.model_copy(update=dict(attn_impl=attn_impl))
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_blt_transformer_cross_attn_forward(self):
        args = create_args(cross_attention=True)
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_cross_attention_rand(self):
        x = torch.randn(2, 256, 512, device="cuda")
        kv = torch.randn(2, 256, 512, device="cuda")
        cross_attention = CrossAttention(
            dim=512,
            head_dim=64,
            n_heads=8,
            n_kv_heads=4,
            norm_eps=1e-6,
        ).to("cuda")
        mask = create_causal_mask(
            x.shape[1], "flex_attention", None, sliding_window=None
        )
        output = cross_attention(x, kv, mask)
        assert output is not None
        assert output.shape == (2, 256, 512)

    def test_ngram_embeddings(self):
        ngram_to_size = {
            2: 38396,
            3: 50000,
            4: 50000,
            5: 50000,
            6: 50000,
            7: 50000,
            8: 50000,
        }
        batch = fake_batch()
        ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
        ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
        ngram_ids = np.stack(ngram_ids, axis=0)
        batch = replace(batch, ngram_ids=ngram_ids)
        args = create_args(cross_attention=True)
        args = args.model_copy(
            update=dict(
                encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
                encoder_enable_byte_ngrams=True,
                ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
            )
        )
        model = ByteLatentTransformer(args)
        model = model.cuda()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_loss_backward(self):
        args = create_args()
        args = args.model_copy(update=dict(attn_impl="xformers"))
        batch = fake_batch()
        model = ByteLatentTransformer(args)
        steps = 10
        optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
        model = model.cuda()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)

        initial_loss = None
        final_loss = None
        for step in range(steps):
            output = model(
                tokens=x,
                patch_lengths=patch_lengths,
                ngram_ids=ngram_ids,
            )
            loss, _ = compute_loss(output, y, mask, 1.0)
            if step == 0:
                initial_loss = loss.item()
            if step == steps - 1:
                final_loss = loss.item()
            prev_loss = loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        assert (
            final_loss < initial_loss
        ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"