diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py
index 87d7334..217224f 100644
--- a/bytelatent/base_transformer.py
+++ b/bytelatent/base_transformer.py
@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor

         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(
@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std / factor,
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )
@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        in_init_std = in_init_std
-        out_init_std = out_init_std / factor
-        for w in [self.w1, self.w3]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=in_init_std,
-                a=-3 * in_init_std,
-                b=3 * in_init_std,
-            )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,
@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )


 class TransformerBlock(nn.Module):
diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index 168cb7c..298d3d4 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -463,13 +463,21 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")

     if distributed_args.selective_activation_checkpointing:
-        model = checkpoint_wrapper(
-            model,
-            context_fn=partial(
-                create_selective_checkpoint_contexts,
-                get_default_policy(no_recompute_ops),
-            ),
-        )
+        # only works for blt models
+        # assuming that entropy models will not use checkpointing
+        for module in [
+            model.global_transformer,
+            model.local_encoder,
+            model.local_decoder,
+        ]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )

     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
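
Note: the per-layer wrapping above follows the standard checkpoint_wrapper pattern. A minimal sketch of that pattern, assuming a toy module with a `.layers` ModuleList; `ToyStack` and the sizes are illustrative placeholders, and the selective `context_fn` used in the real change is omitted here.

# Minimal sketch of per-layer activation checkpointing (illustrative, not from the repo).
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


class ToyStack(nn.Module):
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


stack = ToyStack()
for i in range(len(stack.layers)):
    # Wrap each block so its activations are recomputed during backward
    # instead of being kept alive for the whole forward pass.
    stack.layers[i] = checkpoint_wrapper(stack.layers[i])

out = stack(torch.randn(2, 64, requires_grad=True))
out.sum().backward()
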
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 53a3be6..b8586fe 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )

         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None
@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):

         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight

         # Patcher module
         if args.patch_in_forward:
@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_embeds = local_encoder_embeds + ngram_embeds

         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=h_cross if self.cross_attn_encoder else None,
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
         return output

-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.reset_parameters()
-        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
-        for depth, layer in enumerate(self.layers):
-            factor = {
-                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-                InitStdFactor.DIM_RATIO: self.dim / 4096,
-                InitStdFactor.DISABLED: 1.0,
-            }[self.init_std_factor]
-
-            layer.init_weights(self.init_base_std, factor)
-
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()
+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=self.init_base_std,
-                a=-3 * self.init_base_std,
-                b=3 * self.init_base_std,
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
             )
diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py
index 21c3f0c..d91f49f 100644
--- a/bytelatent/model/latent_transformer.py
+++ b/bytelatent/model/latent_transformer.py
@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-        x = self.cross_attn_norm_q(x)
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)

-        xq = self.wq(x)
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)

@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output

     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std * factor
+        std = base_std or (self.dim ** (-0.5)) / factor

         nn.init.trunc_normal_(
             self.wq.weight,
@@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
             b=3 * std,
         )

-        output_std = std / (2**0.5)
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=output_std,
-            a=-3 * output_std,
-            b=3 * output_std,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
         )
         self.cross_attn_norm_q.reset_parameters()
         self.cross_attn_norm_kv.reset_parameters()
@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache

-    def init_weights(self, init_base_std: float):
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=init_base_std,
-                a=-3 * init_base_std,
-                b=3 * init_base_std,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
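
Note: the recurring initializer in these hunks is a truncated normal with std = dim ** (-0.5), clipped at three standard deviations. A minimal standalone sketch of that pattern; the embedding sizes are placeholders, not taken from the model configs.

# Truncated-normal init with std = dim ** -0.5 and +/-3*std bounds (illustrative).
import torch.nn as nn

emb = nn.Embedding(260, 512)       # placeholder vocab_size x dim
std = emb.embedding_dim ** (-0.5)  # 512 ** -0.5 ~= 0.044

nn.init.trunc_normal_(
    emb.weight,
    mean=0.0,
    std=std,
    a=-3 * std,  # truncate the distribution at -3 and +3 standard deviations
    b=3 * std,
)

# Every sampled weight now lies inside [-3 * std, 3 * std].
assert emb.weight.abs().max().item() <= 3 * std
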
diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index d0e24c0..d6eab14 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size: int
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None
@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb

         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):

     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()

         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(
@@ -156,7 +158,16 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)
+
+        if hasattr(self, "output"):
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )

         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
@@ -168,21 +179,13 @@ class LocalModelBase(nn.Module):
             )

         if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
                 self.patch_embedding_projection.weight,
                 mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
-        if hasattr(self, "output"):
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
             )

         if self.cross_attn_layers is not None:
@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)


 class LocalEncoder(LocalModelBase):
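
Note: the factor that LocalModelBase.init_weights now passes alongside init_std=None comes from the depth-dependent table shown in the surrounding context. A small sketch of how that factor shrinks the per-layer std; the enum is modeled on InitStdFactor from the diff, and dim / n_layers are placeholder values.

# How the depth-dependent factor scales the init std (illustrative sketch).
from enum import Enum


class InitStdFactor(Enum):
    CURRENT_DEPTH = "current_depth"
    GLOBAL_DEPTH = "global_depth"
    DIM_RATIO = "dim_ratio"
    DISABLED = "disabled"


def layer_factor(kind: InitStdFactor, depth: int, n_layers: int, dim: int) -> float:
    # Same table as in the diff context above.
    return {
        InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
        InitStdFactor.GLOBAL_DEPTH: (2 * (n_layers + 1)) ** 0.5,
        InitStdFactor.DIM_RATIO: dim / 4096,
        InitStdFactor.DISABLED: 1.0,
    }[kind]


# With init_std=None, the updated reset_parameters derives
#   std = dim ** (-0.5) / factor
# so later layers (larger factor under CURRENT_DEPTH) start with smaller weights.
dim, n_layers = 1024, 8
for depth in range(n_layers):
    factor = layer_factor(InitStdFactor.CURRENT_DEPTH, depth, n_layers, dim)
    print(f"layer {depth}: std = {dim ** (-0.5) / factor:.5f}")
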
group_plan.append((f"encoder_hash_tok_embedding.{i}", True)) return group_plan
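
Note: the else-branch above can be exercised on its own. A self-contained sketch that rebuilds the BLT grouping plan from placeholder layer counts; BLTArgs is a stand-in for the real model args, and the boolean flag is simply carried through as in the diff.

# Rebuilding the BLT-style FSDP grouping plan from placeholder args (illustrative).
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class BLTArgs:
    n_layers_local_encoder: int = 2
    n_layers_local_decoder: int = 2
    n_layers_global: int = 4
    encoder_hash_byte_group_size: List[int] = field(default_factory=lambda: [3, 4])


def blt_grouping_plan(args: BLTArgs) -> List[Tuple[str, bool]]:
    plan: List[Tuple[str, bool]] = []
    for i in range(args.n_layers_local_encoder):
        plan.append((f"local_encoder.layers.{i}", True))
        plan.append((f"local_encoder.cross_attn_layers.{i}", True))
    for i in range(args.n_layers_local_decoder):
        plan.append((f"local_decoder.layers.{i}", True))
        plan.append((f"local_decoder.cross_attn_layers.{i}", True))
    for i in range(args.n_layers_global):
        plan.append((f"global_transformer.layers.{i}", True))
    for i in range(len(args.encoder_hash_byte_group_size)):
        plan.append((f"encoder_hash_tok_embedding.{i}", True))
    return plan


for path, flag in blt_grouping_plan(BLTArgs()):
    print(path, flag)
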