Fix init and repro

Srini Iyer 2025-02-06 20:01:32 +00:00
parent 936d9437be
commit 30f82211c4
6 changed files with 83 additions and 101 deletions

View file

@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor
 
         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(
@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std / factor,
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )
@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output
 
     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        in_init_std = in_init_std
-        out_init_std = out_init_std / factor
-        for w in [self.w1, self.w3]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=in_init_std,
-                a=-3 * in_init_std,
-                b=3 * in_init_std,
-            )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,
@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
 
 
 class TransformerBlock(nn.Module):
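Net effect of the hunks above: the depth factor is folded into the base std itself (fan_in ** -0.5 / factor) and applied to every projection, instead of dividing only the output projections. A minimal sketch of that pattern, with an illustrative helper name and a made-up layer (not code from this repository):

```python
import torch.nn as nn

def trunc_normal_init_(weight, fan_in: int, factor: float = 1.0) -> None:
    # Truncated-normal init with std = fan_in**-0.5 / factor, clipped at +/-3 std,
    # mirroring the convention the hunks above converge on.
    std = fan_in ** (-0.5) / factor
    nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

# Hypothetical usage: a depth-dependent factor shrinks the init of deeper layers.
proj = nn.Linear(512, 512, bias=False)
depth = 7
factor = (2 * (depth + 1)) ** 0.5  # CURRENT_DEPTH-style scaling from InitStdFactor
trunc_normal_init_(proj.weight, fan_in=512, factor=factor)
```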

View file

@@ -463,13 +463,15 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
 
     if distributed_args.selective_activation_checkpointing:
-        model = checkpoint_wrapper(
-            model,
-            context_fn=partial(
-                create_selective_checkpoint_contexts,
-                get_default_policy(no_recompute_ops),
-            ),
-        )
+        for module in [model.global_transformer, model.local_encoder, model.local_decoder]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )
 
     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
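The hunk above stops wrapping the whole model in a single checkpoint_wrapper and instead wraps each layer of the three BLT sub-modules individually. A stripped-down sketch of that per-layer pattern on a toy module (the selective-checkpoint context_fn from the diff is omitted here):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

class TinyStack(nn.Module):
    # Stand-in for one of the BLT sub-modules; only the .layers ModuleList matters.
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = TinyStack()
for i in range(len(model.layers)):
    # Same shape as the loop in the hunk: replace each layer with its wrapped version.
    model.layers[i] = checkpoint_wrapper(model.layers[i])
```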

View file

@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )
 
         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None
@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):
         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
         self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight
-
 
         # Patcher module
         if args.patch_in_forward:
@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_embeds = local_encoder_embeds + ngram_embeds
 
         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=h_cross if self.cross_attn_encoder else None,
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
 
         return output
 
-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.reset_parameters()
-        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
-        for depth, layer in enumerate(self.layers):
-            factor = {
-                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-                InitStdFactor.DIM_RATIO: self.dim / 4096,
-                InitStdFactor.DISABLED: 1.0,
-            }[self.init_std_factor]
-            layer.init_weights(self.init_base_std, factor)
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()
 
+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=self.init_base_std,
-                a=-3 * self.init_base_std,
-                b=3 * self.init_base_std,
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
             )
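After this hunk, ByteLatentTransformer.init_weights no longer computes a shared init_base_std and threads it into the sub-modules; each sub-module derives its own std from its own width, and the hash-token embeddings use the local encoder's dim. A toy sketch of that shape, with made-up class names (not the repository's):

```python
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim, bias=False)

    def init_weights(self):
        # Per-module std derived from this module's own width; nothing is passed in.
        std = self.dim ** (-0.5)
        nn.init.trunc_normal_(self.proj.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SubModule(256)
        self.global_block = SubModule(512)

    def init_weights(self):
        # The parent only delegates; it no longer owns a base std.
        self.encoder.init_weights()
        self.global_block.init_weights()
```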

View file

@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
 
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-        x = self.cross_attn_norm_q(x)
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)
 
-        xq = self.wq(x)
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)
@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output
 
     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std * factor
+        std = base_std or (self.dim ** (-0.5)) / factor
 
         nn.init.trunc_normal_(
             self.wq.weight,
@@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
             b=3 * std,
         )
 
-        output_std = std / (2**0.5)
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=output_std,
-            a=-3 * output_std,
-            b=3 * output_std,
+            std=std,
+            a=-3 * std,
+            b=3 * std,
         )
         self.cross_attn_norm_q.reset_parameters()
         self.cross_attn_norm_kv.reset_parameters()
@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb
 
         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache
 
-    def init_weights(self, init_base_std: float):
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=init_base_std,
-                a=-3 * init_base_std,
-                b=3 * init_base_std,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
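The first hunk in this file keeps the residual path on the un-normalized input: cross_attn_norm_q no longer overwrites x, the normalized copy only feeds the query projection, and return x + output adds back the original x. A stripped-down sketch of that pre-norm residual pattern (toy shapes, LayerNorm standing in for the cross-attention norm, the attention and output projection omitted):

```python
import torch
import torch.nn as nn

dim = 64
norm_q = nn.LayerNorm(dim)
wq = nn.Linear(dim, dim, bias=False)

x = torch.randn(2, 16, dim)   # B S D
x_norm = norm_q(x)            # normalize a copy instead of overwriting x
xq = wq(x_norm)               # queries come from the normalized input
out = x + xq                  # residual keeps the original x, as in `return x + output`
```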

View file

@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size: int
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None
@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb
 
         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):
 
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(
@@ -156,7 +158,16 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)
+
+        if hasattr(self, "output"):
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
 
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
@@ -168,21 +179,13 @@ class LocalModelBase(nn.Module):
             )
 
         if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
                 self.patch_embedding_projection.weight,
                 mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
-        if hasattr(self, "output"):
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
             )
 
         if self.cross_attn_layers is not None:
@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]
 
-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)
 
 
 class LocalEncoder(LocalModelBase):

View file

@@ -137,14 +137,25 @@ def get_no_recompute_ops():
 
 def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
     group_plan: Tuple[int, bool] = []
 
     # Grouping and output seperately
-    group_plan.append(("tok_embeddings", False))
+    if isinstance(model_args, LMTransformerArgs):
+        group_plan.append(("tok_embeddings", False))
 
-    # Grouping by layers
-    for i in range(model_args.n_layers):
-        group_plan.append((f"layers.{i}", False))
+        for i in range(model_args.n_layers):
+            group_plan.append((f"layers.{i}", False))
 
-    group_plan.append(("output", True))
+        group_plan.append(("output", True))
+    else:
+        for i in range(model_args.n_layers_local_encoder):
+            group_plan.append((f"local_encoder.layers.{i}", True))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_local_decoder):
+            group_plan.append((f"local_decoder.layers.{i}", True))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_global):
+            group_plan.append((f"global_transformer.layers.{i}", True))
+        for i in range(len(model_args.encoder_hash_byte_group_size)):
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
 
     return group_plan
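Each entry in the plan is a (module path, flag) pair that the FSDP wrapping code consumes. Purely as an illustration, here is what the BLT branch added above would produce for a tiny, made-up configuration (2 encoder layers, 1 decoder layer, 1 global layer, 3 hash-embedding tables); the counts are hypothetical, not defaults from the repository:

```python
n_enc, n_dec, n_glob, n_hash = 2, 1, 1, 3  # hypothetical sizes

group_plan: list[tuple[str, bool]] = []
for i in range(n_enc):
    group_plan.append((f"local_encoder.layers.{i}", True))
    group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
for i in range(n_dec):
    group_plan.append((f"local_decoder.layers.{i}", True))
    group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
for i in range(n_glob):
    group_plan.append((f"global_transformer.layers.{i}", True))
for i in range(n_hash):
    group_plan.append((f"encoder_hash_tok_embedding.{i}", True))

print(group_plan[:3])
# [('local_encoder.layers.0', True), ('local_encoder.cross_attn_layers.0', True),
#  ('local_encoder.layers.1', True)]
```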