From c2108e7256a30482b2e3aedbf8c4d647f709b56b Mon Sep 17 00:00:00 2001
From: "ita.zaporozhets@huggingface.co"
Date: Tue, 3 Jun 2025 15:29:01 +0000
Subject: [PATCH] move args out

---
 blt_args.py     | 248 +++++++++++++++++++++++
 blt_one_file.py | 508 ++++++++++--------------------------------------
 demo_hf.py      |   9 +-
 3 files changed, 351 insertions(+), 414 deletions(-)
 create mode 100644 blt_args.py

diff --git a/blt_args.py b/blt_args.py
new file mode 100644
index 0000000..94527b5
--- /dev/null
+++ b/blt_args.py
@@ -0,0 +1,248 @@
+from enum import Enum, auto
+from typing import Any, List, Optional, Tuple, Union
+from pydantic import BaseModel, ConfigDict, model_validator
+from typing_extensions import Self
+from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID
+
+
+class InitStdFactor(str, Enum):
+    DISABLED = "disabled"  # Init std is divided by 1.0
+    GLOBAL_DEPTH = "global_depth"  # Init std is divided by sqrt(2*n_layers)
+    CURRENT_DEPTH = "current_depth"  # Init std is divided by sqrt(2*depth)
+    DIM_RATIO = "dim_ratio"  # Init std is divided by model_dim/4096
+
+
+class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dim: int = 512
+    n_layers: int = 8
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None
+
+    ffn_dim_multiplier: float | None = None
+
+    multiple_of: int = 256
+
+    norm_eps: float = 1e-5
+
+    rope_theta: float = 10000.0
+    rope_use_fp32_in_outer_product: bool = False
+
+    init_base_std: float | None = None
+    init_std_factor: InitStdFactor = InitStdFactor.DISABLED
+
+    max_seqlen: int = 1024
+
+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID
+
+
+class ByteLatentTransformerArgs(BaseTransformerArgs):
+    # Basic model configuration
+    seed: int = 42
+    vocab_size: int = -1
+    dim: int = 512
+    n_layers: int = 8
+    n_heads: int = 8
+    # TODO: What is the purpose of this parameter?
+ weight_tying: bool = False + patch_in_forward: bool = False + + # Architecture and dimensions + dim_token: int | None = None + dim_global: int = 512 + dim_local_decoder: int = 512 + dim_local_encoder: int = 512 + n_layers_global: int = 8 + n_layers_local_decoder: int = 8 + n_layers_local_encoder: int = 8 + + # Tokenization and patching + patch_size: float | None = None + patching_mode: str | None = None + patching_threshold: float | None = None + patching_threshold_add: float | None = None + monotonicity: bool = False + patching_batch_size: int = 1 + patching_device: str = "cuda" + max_patch_length: int | None = None + + # Encoder/Decoder configuration + tie_local_encoder_decoder_logits: bool = False + use_local_encoder_transformer: bool = False + encoder_lm_loss: bool = False + max_encoder_seq_length: int | None = None + pad_to_max_length: bool = False + encoder_enable_byte_ngrams: bool = False + encoder_enable_byte_group_hash: bool = False + ngram_vocab_sizes: int | None = None + + # Cross attention configurations + cross_attn_encoder: bool = False + cross_attn_decoder: bool = False + cross_attn_window_encoder: int | None = None + cross_attn_window_decoder: int | None = None + cross_attn_k: int | None = None + cross_attn_nheads: int | None = None + cross_attn_all_layers_decoder: bool = False + cross_attn_all_layers_encoder: bool = False + cross_attn_use_flex_attention: bool = True + cross_attn_init_by_pooling: bool = False + + # Encoder hash configurations + encoder_hash_byte_group_size: Any | None = None + encoder_hash_byte_group_vocab: int = 30000 + encoder_hash_byte_group_nb_functions: int = 3 + + # Model behavior and optimization + log_patch_lengths: bool = False + non_linearity: str = "swiglu" + use_rope: bool = True + recompute_fc1_out: bool = False + recompute_fc3_out: bool = False + recompute_attn: bool = True + custom_bwd: bool = False + layer_ckpt: str = "all" + + # Initialization and attention + init_use_gaussian: bool = True + init_use_depth: str = "current" + attn_bias_type: str = "causal" + alpha_depth: str = "disabled" + max_length: int = 2048 + + # Norm configuration + norm_eps: float = 1e-5 + norm_affine: bool = True + pre_norm: bool = True + norm_type: str = "rmsnorm" + + # Additional configurations + multiple_of: int = 256 + ffn_dim_multiplier: float = 1.0 + dropout: float = 0 + output_size: int = -1 + + # Additional parameters from ModelArgs + architecture: str = "vanilla" + share_encoder_decoder_emb: bool = True + global_local_decoder_residual_layer: str | None = None + + tokenize_with_bpe_delimiter: bool = False + patching_thresholds_str: str | None = None + tie_local_encoder_decoder: bool = False + encoder_preds_low_entropy_toks: float | None = None + encoder_preds_random_toks: float | None = None + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + encoder_ngram_table_dir: str | None = None + encoder_ngram_to_size_str: str | None = None + + # Model architecture params + entropy_model_checkpoint_dir: str | None = None + entropy_model_is_ngram_model: bool = False + downsampling_by_pooling: str | None = None + n_heads_global: int = 8 + n_heads_local_decoder: int = 8 + n_heads_local_encoder: int = 8 + n_kv_heads: int | None = None + n_kv_heads_global: int | None = None + conv_kernel_size: int | None = None + local_attention_window_len: int | None = None + + # Performance optimization + sequence_parallel: bool = False + loss_parallel: bool = False + fuse_sequence_parallel: bool = False + use_fsdp: bool = True + attn_to_keep: str = "all" + + # 
Parameter mixing + pm_size: int = 0 + + # Logging + full_logging_n_layers: int = 4 + + @model_validator(mode="after") + def check_hash_byte_sizes(self) -> Self: + if ( + self.encoder_hash_byte_group_size is not None + and type(self.encoder_hash_byte_group_size) == str + ): + self.encoder_hash_byte_group_size = [ + int(x) + for x in self.encoder_hash_byte_group_size.split(",") + if len(x) > 0 + ] + return self + + +class GlobalTransformerArgs(ByteLatentTransformerArgs): + # Global encoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with global encoder specific values + self.dim = self.dim_global + self.n_layers = self.n_layers_global + self.n_heads = self.n_heads_global + self.n_kv_heads = self.n_kv_heads_global + self.local_attention_window_len = None + self.cross_attn_encoder = False + self.cross_attn_decoder = False + + +class LocalDecoderArgs(ByteLatentTransformerArgs): + # Local decoder specific dimensions + dim_token_emb: int | None = None + dim_patch_emb: int | None = None + + def __post_init__(self): + # Override base args with local decoder specific values + self.dim = self.dim_local_decoder + self.n_layers = self.n_layers_local_decoder + self.n_heads = self.n_heads_local_decoder + self.cross_attn_encoder = False + self.cross_attn_init_by_pooling = False + self.attn_bias_type = "local_block_causal" + + +class LocalModelArgs(BaseTransformerArgs): + model_config = ConfigDict(extra="forbid") + # Override defaults + attn_impl: str | None = "xformers" + attn_bias_type: str | None = "local_block_causal" + + # Local encoder specific dimensions + dropout: float + vocab_size: int + patch_size: float + sliding_window: int | None + use_rope: bool + cross_attn_encoder: bool | None + cross_attn_decoder: bool | None + cross_attn_k: int | None + cross_attn_init_by_pooling: bool + patching_mode: str + use_local_encoder_transformer: bool + downsampling_by_pooling: str | None + encoder_hash_byte_group_size: Any | None = None + cross_attn_all_layers_encoder: bool = False + cross_attn_all_layers_decoder: bool = False + cross_attn_nheads: int | None + + dim_token_emb: int + dim_patch_emb: int | None + + +class LMTransformerArgs(BaseTransformerArgs): + seed: int = 42 + + vocab_size: int = -1 + weight_tying: bool = False + + sliding_window: int | None = None + diff --git a/blt_one_file.py b/blt_one_file.py index 38219c7..e7a82a5 100644 --- a/blt_one_file.py +++ b/blt_one_file.py @@ -34,6 +34,17 @@ from bytelatent.distributed import get_local_rank logger = logging.getLogger() +from blt_args import ( + BaseTransformerArgs, + ByteLatentTransformerArgs, + GlobalTransformerArgs, + LocalDecoderArgs, + LocalModelArgs, + LMTransformerArgs, + +) + + if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0: flex_attention_comp = torch.compile(flex_attention) @@ -71,65 +82,6 @@ def patch_reduce(h, max_num_patches, reduction, patch_ids): return reduced_embs -def concat_downsample(h, patch_lengths, patch_size): - # The assumption in this function is that seq_len = patch_size * num_patches. - bs, seq_len, emb_dim = h.shape - patch_end_ids = torch.cumsum(patch_lengths, dim=1) - patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to( - patch_end_ids.device - ) - # Is clamp ok here? 
- patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1]) - patch_ids = patch_ids.view(bs, -1, emb_dim) - # after gather h.shape = [batch_size, seq_len, dim] - h = torch.gather(h, 1, patch_ids) - h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1)) - return h - - -def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids): - cat = [] - if "avg" in pooling_mode or "mean" in pooling_mode: - cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids)) - if "min" in pooling_mode: - cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids)) - if "max" in pooling_mode: - cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids)) - assert len(cat) > 0 - h = torch.cat(cat, dim=-1) - return h - - -def downsample( - h, - num_patches, - patch_lengths=None, - patch_ids=None, - downsampling_by_pooling=None, - patch_size=4, -): - """ - Downsampling: - a. concatenating embeddings in the patch - Note: with dynamic patching, patch the last patch_size tokens. - b. pooling embeddings in the patch - """ - # input: h.shape = [batch_size, seq_len, dim] - # input: pool h.shape = [batch_size, seq_len / patch_size, dim] - # if we don't use the cros_attn, we pool so that we convert bytes rep to patch rep - if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0: - # By pooling - max_num_patches = num_patches - assert patch_ids is not None - h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids) - else: - # TODO: remove this condition - # By concatenating (fixed lengths patching) - assert patch_lengths is not None - h = concat_downsample(h, patch_lengths, patch_size) - return h - - def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx @@ -215,35 +167,6 @@ class InitStdFactor(str, Enum): CURRENT_DEPTH = "current_depth" # Init std is divided by sqrt(2*depth) DIM_RATIO = "dim_ratio" # Init std is divided by model_dim/4096 - -class BaseTransformerArgs(BaseModel): - model_config = ConfigDict(extra="forbid") - dim: int = 512 - n_layers: int = 8 - head_dim: int | None = None - n_heads: int | None = None - n_kv_heads: int | None = None - - ffn_dim_multiplier: float | None = None - - multiple_of: int = 256 - - norm_eps: float = 1e-5 - - rope_theta: float = 10000.0 - rope_use_fp32_in_outer_product: bool = False - - init_base_std: float | None = None - init_std_factor: InitStdFactor = InitStdFactor.DISABLED - - max_seqlen: int = 1024 - - attn_impl: str | None = "sdpa" - attn_bias_type: str | None = None - # Special token config - eos_id: int | None = EOS_ID - - def cross_entropy(pred, target, **kwargs): return F.nll_loss( F.log_softmax(pred.flatten(end_dim=-2).float(), -1), @@ -730,15 +653,6 @@ class BaseTransformer(nn.Module, SequenceModelWithOutput): layer.init_weights(self.init_base_std, factor) -class LMTransformerArgs(BaseTransformerArgs): - seed: int = 42 - - vocab_size: int = -1 - weight_tying: bool = False - - sliding_window: int | None = None - - class LMTransformer( BaseTransformer, PyTorchModelHubMixin, @@ -1401,177 +1315,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len): return patch_ids -class ByteLatentTransformerArgs(BaseTransformerArgs): - # Basic model configuration - seed: int = 42 - vocab_size: int = -1 - dim: int = 512 - n_layers: int = 8 - n_heads: int = 8 - # TODO: What is the purpose of this parameter? 
- weight_tying: bool = False - patch_in_forward: bool = False - - # Architecture and dimensions - dim_token: int | None = None - dim_global: int = 512 - dim_local_decoder: int = 512 - dim_local_encoder: int = 512 - n_layers_global: int = 8 - n_layers_local_decoder: int = 8 - n_layers_local_encoder: int = 8 - - # Tokenization and patching - patch_size: float | None = None - patching_mode: str | None = None - patching_threshold: float | None = None - patching_threshold_add: float | None = None - monotonicity: bool = False - patching_batch_size: int = 1 - patching_device: str = "cuda" - max_patch_length: int | None = None - - # Encoder/Decoder configuration - tie_local_encoder_decoder_logits: bool = False - use_local_encoder_transformer: bool = False - encoder_lm_loss: bool = False - max_encoder_seq_length: int | None = None - pad_to_max_length: bool = False - encoder_enable_byte_ngrams: bool = False - encoder_enable_byte_group_hash: bool = False - ngram_vocab_sizes: int | None = None - - # Cross attention configurations - cross_attn_encoder: bool = False - cross_attn_decoder: bool = False - cross_attn_window_encoder: int | None = None - cross_attn_window_decoder: int | None = None - cross_attn_k: int | None = None - cross_attn_nheads: int | None = None - cross_attn_all_layers_decoder: bool = False - cross_attn_all_layers_encoder: bool = False - cross_attn_use_flex_attention: bool = True - cross_attn_init_by_pooling: bool = False - - # Encoder hash configurations - encoder_hash_byte_group_size: Any | None = None - encoder_hash_byte_group_vocab: int = 30000 - encoder_hash_byte_group_nb_functions: int = 3 - - # Model behavior and optimization - log_patch_lengths: bool = False - non_linearity: str = "swiglu" - use_rope: bool = True - recompute_fc1_out: bool = False - recompute_fc3_out: bool = False - recompute_attn: bool = True - custom_bwd: bool = False - layer_ckpt: str = "all" - - # Initialization and attention - init_use_gaussian: bool = True - init_use_depth: str = "current" - attn_bias_type: str = "causal" - alpha_depth: str = "disabled" - max_length: int = 2048 - - # Norm configuration - norm_eps: float = 1e-5 - norm_affine: bool = True - pre_norm: bool = True - norm_type: str = "rmsnorm" - - # Additional configurations - multiple_of: int = 256 - ffn_dim_multiplier: float = 1.0 - dropout: float = 0 - output_size: int = -1 - - # Additional parameters from ModelArgs - architecture: str = "vanilla" - share_encoder_decoder_emb: bool = True - global_local_decoder_residual_layer: str | None = None - - tokenize_with_bpe_delimiter: bool = False - patching_thresholds_str: str | None = None - tie_local_encoder_decoder: bool = False - encoder_preds_low_entropy_toks: float | None = None - encoder_preds_random_toks: float | None = None - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - encoder_ngram_table_dir: str | None = None - encoder_ngram_to_size_str: str | None = None - - # Model architecture params - entropy_model_checkpoint_dir: str | None = None - entropy_model_is_ngram_model: bool = False - downsampling_by_pooling: str | None = None - n_heads_global: int = 8 - n_heads_local_decoder: int = 8 - n_heads_local_encoder: int = 8 - n_kv_heads: int | None = None - n_kv_heads_global: int | None = None - conv_kernel_size: int | None = None - local_attention_window_len: int | None = None - - # Performance optimization - sequence_parallel: bool = False - loss_parallel: bool = False - fuse_sequence_parallel: bool = False - use_fsdp: bool = True - attn_to_keep: str = "all" - - # 
Parameter mixing - pm_size: int = 0 - - # Logging - full_logging_n_layers: int = 4 - - @model_validator(mode="after") - def check_hash_byte_sizes(self) -> Self: - if ( - self.encoder_hash_byte_group_size is not None - and type(self.encoder_hash_byte_group_size) == str - ): - self.encoder_hash_byte_group_size = [ - int(x) - for x in self.encoder_hash_byte_group_size.split(",") - if len(x) > 0 - ] - return self - - -class GlobalTransformerArgs(ByteLatentTransformerArgs): - # Global encoder specific dimensions - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with global encoder specific values - self.dim = self.dim_global - self.n_layers = self.n_layers_global - self.n_heads = self.n_heads_global - self.n_kv_heads = self.n_kv_heads_global - self.local_attention_window_len = None - self.cross_attn_encoder = False - self.cross_attn_decoder = False - - -class LocalDecoderArgs(ByteLatentTransformerArgs): - # Local decoder specific dimensions - dim_token_emb: int | None = None - dim_patch_emb: int | None = None - - def __post_init__(self): - # Override base args with local decoder specific values - self.dim = self.dim_local_decoder - self.n_layers = self.n_layers_local_decoder - self.n_heads = self.n_heads_local_decoder - self.cross_attn_encoder = False - self.cross_attn_init_by_pooling = False - self.attn_bias_type = "local_block_causal" - def create_global_transformer(args: ByteLatentTransformerArgs): global_args = args.model_copy( @@ -1592,34 +1335,6 @@ def create_global_transformer(args: ByteLatentTransformerArgs): return GlobalTransformer(global_args) -class LocalModelArgs(BaseTransformerArgs): - model_config = ConfigDict(extra="forbid") - # Override defaults - attn_impl: str | None = "xformers" - attn_bias_type: str | None = "local_block_causal" - - # Local encoder specific dimensions - dropout: float - vocab_size: int - patch_size: float - sliding_window: int | None - use_rope: bool - cross_attn_encoder: bool | None - cross_attn_decoder: bool | None - cross_attn_k: int | None - cross_attn_init_by_pooling: bool - patching_mode: str - use_local_encoder_transformer: bool - downsampling_by_pooling: str | None - encoder_hash_byte_group_size: Any | None = None - cross_attn_all_layers_encoder: bool = False - cross_attn_all_layers_decoder: bool = False - cross_attn_nheads: int | None - - dim_token_emb: int - dim_patch_emb: int | None - - class LocalModelBase(nn.Module): def __init__(self, args: LocalModelArgs): super().__init__() @@ -1850,13 +1565,14 @@ class LocalEncoder(LocalModelBase): ): # apply pooling and project if self.cross_attn_init_by_pooling and patch_embeds is None: - patch_embeds = downsample( - h, - num_patches, - patch_ids=patch_ids, - downsampling_by_pooling=self.downsampling_by_pooling, - patch_size=self.patch_size, - ) + # patch_embeds = downsample( + # h, + # num_patches, + # patch_ids=patch_ids, + # downsampling_by_pooling=self.downsampling_by_pooling, + # patch_size=self.patch_size, + # ) + patch_embeds = patch_reduce(h, num_patches, "amax", patch_ids) if self.patch_embedding_projection is not None: patch_embeds = self.patch_embedding_projection(patch_embeds) patch_embeds = patch_embeds.reshape( @@ -2146,94 +1862,6 @@ class GlobalTransformer(BaseTransformer): b=3 * std, ) - - -def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: - local_encoder_args = LocalModelArgs( - # Updated args - dim=args.dim_local_encoder, - n_layers=args.n_layers_local_encoder, - 
n_heads=args.n_heads_local_encoder, - dim_token_emb=get_encoder_dim_token_emb(args), - dim_patch_emb=get_encoder_dim_patch_emb(args), - cross_attn_encoder=args.cross_attn_encoder, - cross_attn_decoder=False, - cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, - cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, - # Defaults - head_dim=args.head_dim, - max_seqlen=args.max_encoder_seq_length, - dropout=args.dropout, - vocab_size=args.vocab_size + args.pm_size, - norm_eps=args.norm_eps, - patch_size=args.patch_size, - sliding_window=args.local_attention_window_len, - use_rope=args.use_rope, - rope_theta=args.rope_theta, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, - init_base_std=args.init_base_std, - init_std_factor=args.init_std_factor, - n_kv_heads=args.n_kv_heads, - attn_impl=args.attn_impl, - attn_bias_type="local_block_causal", - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - patching_mode=args.patching_mode, - use_local_encoder_transformer=args.use_local_encoder_transformer, - downsampling_by_pooling=args.downsampling_by_pooling, - encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, - cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, - cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, - cross_attn_nheads=args.cross_attn_nheads, - eos_id=args.eos_id, - ) - - return LocalEncoder(local_encoder_args) - - -def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: - # First deep copy the original args - local_decoder_args = LocalModelArgs( - dim=args.dim_local_decoder, - n_layers=args.n_layers_local_decoder, - n_heads=args.n_heads_local_decoder, - dim_token_emb=get_decoder_dim_token_emb(args), - dim_patch_emb=args.dim_global, - cross_attn_encoder=False, - cross_attn_decoder=args.cross_attn_decoder, - cross_attn_init_by_pooling=False, # states are already defined - cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, - # Defaults - head_dim=args.head_dim, - max_seqlen=args.max_encoder_seq_length, - dropout=args.dropout, - vocab_size=args.vocab_size + args.pm_size, - norm_eps=args.norm_eps, - patch_size=args.patch_size, - sliding_window=args.local_attention_window_len, - use_rope=args.use_rope, - rope_theta=args.rope_theta, - rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, - init_base_std=args.init_base_std, - init_std_factor=args.init_std_factor, - n_kv_heads=args.n_kv_heads, - attn_impl=args.attn_impl, - attn_bias_type="local_block_causal", - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, - patching_mode=args.patching_mode, - use_local_encoder_transformer=args.use_local_encoder_transformer, - downsampling_by_pooling=args.downsampling_by_pooling, - encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, - cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, - cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, - cross_attn_nheads=args.cross_attn_nheads, - eos_id=args.eos_id, - ) - - return LocalDecoder(local_decoder_args) - - class EmbeddingType(Enum): HASH_TOK = auto() NGRAM = auto() @@ -2381,9 +2009,85 @@ class ByteLatentTransformer( ) # ByteLatent modules - self.local_encoder = create_local_encoder(args) + local_encoder_args = LocalModelArgs( + # Updated args + dim=args.dim_local_encoder, + n_layers=args.n_layers_local_encoder, + n_heads=args.n_heads_local_encoder, + dim_token_emb=get_encoder_dim_token_emb(args), + 
dim_patch_emb=get_encoder_dim_patch_emb(args), + cross_attn_encoder=args.cross_attn_encoder, + cross_attn_decoder=False, + cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None, + cross_attn_init_by_pooling=args.cross_attn_init_by_pooling, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + self.local_encoder = LocalEncoder(local_encoder_args) + local_decoder_args = LocalModelArgs( + dim=args.dim_local_decoder, + n_layers=args.n_layers_local_decoder, + n_heads=args.n_heads_local_decoder, + dim_token_emb=get_decoder_dim_token_emb(args), + dim_patch_emb=args.dim_global, + cross_attn_encoder=False, + cross_attn_decoder=args.cross_attn_decoder, + cross_attn_init_by_pooling=False, # states are already defined + cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None, + # Defaults + head_dim=args.head_dim, + max_seqlen=args.max_encoder_seq_length, + dropout=args.dropout, + vocab_size=args.vocab_size + args.pm_size, + norm_eps=args.norm_eps, + patch_size=args.patch_size, + sliding_window=args.local_attention_window_len, + use_rope=args.use_rope, + rope_theta=args.rope_theta, + rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + init_base_std=args.init_base_std, + init_std_factor=args.init_std_factor, + n_kv_heads=args.n_kv_heads, + attn_impl=args.attn_impl, + attn_bias_type="local_block_causal", + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + patching_mode=args.patching_mode, + use_local_encoder_transformer=args.use_local_encoder_transformer, + downsampling_by_pooling=args.downsampling_by_pooling, + encoder_hash_byte_group_size=args.encoder_hash_byte_group_size, + cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder, + cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, + cross_attn_nheads=args.cross_attn_nheads, + eos_id=args.eos_id, + ) + self.global_transformer = create_global_transformer(args) - self.local_decoder = create_local_decoder(args) + self.local_decoder = LocalDecoder(local_decoder_args) self.encoder_hash_tok_embedding = init_embeddings( args, EmbeddingType.HASH_TOK, @@ -2534,21 +2238,7 @@ class ByteLatentTransformer( ) # Downsampling - if not self.cross_attn_encoder: - assert ( - patch_ids.shape[1] == h_encoder.shape[1] - ), f"{patch_ids.shape[1]} != {h_encoder.shape[1]}" - h = downsample( - h_encoder, - patch_lengths.shape[1], - patch_lengths, - patch_ids, - downsampling_by_pooling=self.downsampling_by_pooling, - patch_size=self.patch_size, - ) - else: - # Reshape h_cross - h = 
h_cross.view(bs, patch_lengths.shape[1], -1) + h = h_cross.view(bs, patch_lengths.shape[1], -1) # Global transformer global_tokens = tokens.new(h.shape[0], h.shape[1]).fill_(self.boe_id) @@ -2615,4 +2305,4 @@ class ByteLatentTransformer( std=emb_std, a=-3 * emb_std, b=3 * emb_std, - ) + ) \ No newline at end of file diff --git a/demo_hf.py b/demo_hf.py index 40d5c7d..f7187cc 100644 --- a/demo_hf.py +++ b/demo_hf.py @@ -56,7 +56,7 @@ def sample_top_p(probs, p): @torch.inference_mode() -def generate_nocache( +def generate( prompts: list[str] | None, *, model: ByteLatentTransformer, @@ -186,9 +186,8 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): # Generate text print("Generating text...") - prompts = [prompt] - outputs = generate_nocache( - prompts, + outputs = generate( + [prompt], model=model, tokenizer=tokenizer, patcher=model.patcher, # Use the model's patcher @@ -197,7 +196,7 @@ def main(prompt: str = "my name is", model_name: str = "blt-1b"): # Decode and print results text_outputs = [tokenizer.decode(t) for t in outputs] - for p, t in zip(prompts, text_outputs): + for p, t in zip([prompt], text_outputs): print(f'Prompt: "{p}"') print(f'Completion: "{t}"') print()
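
Usage sketch (illustrative, not part of the patch): once blt_args.py exists, the config
classes can be imported from it directly instead of from blt_one_file.py. The snippet
below assumes the bytelatent package (which provides the token-id constants blt_args.py
imports) and pydantic v2 are installed; the field values are placeholders, not
recommended settings.

    from pydantic import ValidationError

    from blt_args import ByteLatentTransformerArgs, LMTransformerArgs

    # Defaults come from the class bodies shown in the diff above.
    args = ByteLatentTransformerArgs(
        vocab_size=260,              # placeholder value, not a project default
        dim_global=512,
        n_layers_global=8,
        patching_mode="entropy",     # placeholder; the field accepts any string
    )

    # The check_hash_byte_sizes validator normalizes a comma-separated string into ints.
    hashed = ByteLatentTransformerArgs(vocab_size=260, encoder_hash_byte_group_size="3,4,5")
    assert hashed.encoder_hash_byte_group_size == [3, 4, 5]

    # extra="forbid" is inherited from BaseTransformerArgs, so unknown fields fail loudly.
    try:
        LMTransformerArgs(vocab_size=260, nonexistent_field=1)
    except ValidationError as exc:
        print(f"rejected unknown field: {exc.error_count()} error(s)")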