mirror of
https://github.com/facebookresearch/blt.git
synced 2025-02-22 13:02:14 +00:00
Fix init and repro (#48)
* Fix init and repro * comment + black --------- Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent
936d9437be
commit
aebdc481a8
|
@ -445,7 +445,7 @@ class Attention(nn.Module):
|
|||
return output
|
||||
|
||||
def reset_parameters(self, init_std=None, factor=1.0):
|
||||
init_std = init_std or (self.dim ** (-0.5))
|
||||
init_std = init_std or (self.dim ** (-0.5)) / factor
|
||||
|
||||
for w in [self.wq, self.wk, self.wv]:
|
||||
nn.init.trunc_normal_(
|
||||
|
@ -459,7 +459,7 @@ class Attention(nn.Module):
|
|||
nn.init.trunc_normal_(
|
||||
self.wo.weight,
|
||||
mean=0.0,
|
||||
std=init_std / factor,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
|
@ -509,18 +509,16 @@ class FeedForward(nn.Module):
|
|||
return output
|
||||
|
||||
def reset_parameters(self, init_std=None, factor=1.0):
|
||||
in_init_std = init_std or (self.dim ** (-0.5))
|
||||
out_init_std = init_std or (self.hidden_dim ** (-0.5))
|
||||
in_init_std = in_init_std
|
||||
out_init_std = out_init_std / factor
|
||||
for w in [self.w1, self.w3]:
|
||||
nn.init.trunc_normal_(
|
||||
w.weight,
|
||||
mean=0.0,
|
||||
std=in_init_std,
|
||||
a=-3 * in_init_std,
|
||||
b=3 * in_init_std,
|
||||
)
|
||||
in_init_std = init_std or (self.dim ** (-0.5)) / factor
|
||||
out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
|
||||
|
||||
nn.init.trunc_normal_(
|
||||
self.w1.weight,
|
||||
mean=0.0,
|
||||
std=in_init_std,
|
||||
a=-3 * in_init_std,
|
||||
b=3 * in_init_std,
|
||||
)
|
||||
nn.init.trunc_normal_(
|
||||
self.w2.weight,
|
||||
mean=0.0,
|
||||
|
@ -528,6 +526,13 @@ class FeedForward(nn.Module):
|
|||
a=-3 * out_init_std,
|
||||
b=3 * out_init_std,
|
||||
)
|
||||
nn.init.trunc_normal_(
|
||||
self.w3.weight,
|
||||
mean=0.0,
|
||||
std=in_init_std,
|
||||
a=-3 * in_init_std,
|
||||
b=3 * in_init_std,
|
||||
)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
|
|
@ -463,13 +463,21 @@ def parallelize_model(
|
|||
raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")
|
||||
|
||||
if distributed_args.selective_activation_checkpointing:
|
||||
model = checkpoint_wrapper(
|
||||
model,
|
||||
context_fn=partial(
|
||||
create_selective_checkpoint_contexts,
|
||||
get_default_policy(no_recompute_ops),
|
||||
),
|
||||
)
|
||||
# only works for blt models
|
||||
# assuming that entropy models will not use checkpointing
|
||||
for module in [
|
||||
model.global_transformer,
|
||||
model.local_encoder,
|
||||
model.local_decoder,
|
||||
]:
|
||||
for i in range(len(module.layers)):
|
||||
module.layers[i] = checkpoint_wrapper(
|
||||
module.layers[i],
|
||||
context_fn=partial(
|
||||
create_selective_checkpoint_contexts,
|
||||
get_default_policy(no_recompute_ops),
|
||||
),
|
||||
)
|
||||
|
||||
if distributed_args.compile:
|
||||
torch._dynamo.config.cache_size_limit = (
|
||||
|
|
|
@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
|
|||
local_encoder_dim=self.local_encoder.dim,
|
||||
encoder_hash_byte_group_size=None,
|
||||
)
|
||||
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
|
||||
|
||||
# Transformer layers
|
||||
self.layers = nn.ModuleList(
|
||||
[TransformerBlock(args) for _ in range(args.n_layers)]
|
||||
)
|
||||
|
||||
# Encoder ngram embedding tables
|
||||
self.encoder_ngram_embedding = None
|
||||
|
@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):
|
|||
|
||||
# Output layer
|
||||
assert args.vocab_size > 0, "vocab_size must be greater than 0"
|
||||
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
|
||||
if args.weight_tying:
|
||||
self.output.weight = self.tok_embeddings.weight
|
||||
|
||||
# Patcher module
|
||||
if args.patch_in_forward:
|
||||
|
@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
|
|||
local_encoder_embeds = local_encoder_embeds + ngram_embeds
|
||||
|
||||
# Local encoder
|
||||
h_cross = None
|
||||
(h_encoder, h_cross), cache_encoder = self.local_encoder(
|
||||
tokens=local_encoder_tokens,
|
||||
embeds=local_encoder_embeds,
|
||||
patch_embeds=h_cross if self.cross_attn_encoder else None,
|
||||
patch_embeds=None,
|
||||
cross_mask=cross_attn_mask_enc,
|
||||
num_patches=patch_lengths.shape[1],
|
||||
patch_ids=patch_ids,
|
||||
|
@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
|
|||
)
|
||||
return output
|
||||
|
||||
def reset_parameters(self, init_std=None):
|
||||
# Either use fixed base std or sqrt model dim
|
||||
init_std = init_std or (self.dim ** (-0.5))
|
||||
nn.init.trunc_normal_(
|
||||
self.tok_embeddings.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
if not self.weight_tying:
|
||||
nn.init.trunc_normal_(
|
||||
self.output.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
|
||||
def init_weights(self):
|
||||
self.reset_parameters()
|
||||
self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
|
||||
for depth, layer in enumerate(self.layers):
|
||||
factor = {
|
||||
InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
|
||||
InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
|
||||
InitStdFactor.DIM_RATIO: self.dim / 4096,
|
||||
InitStdFactor.DISABLED: 1.0,
|
||||
}[self.init_std_factor]
|
||||
|
||||
layer.init_weights(self.init_base_std, factor)
|
||||
|
||||
self.local_decoder.init_weights(self.init_base_std)
|
||||
self.global_transformer.init_weights(self.init_base_std)
|
||||
self.local_encoder.init_weights(self.init_base_std)
|
||||
self.local_encoder.init_weights()
|
||||
self.global_transformer.init_weights()
|
||||
self.local_decoder.init_weights()
|
||||
|
||||
emb_std = self.local_encoder.dim ** (-0.5)
|
||||
for emb in self.encoder_hash_tok_embedding:
|
||||
nn.init.trunc_normal_(
|
||||
emb.weight,
|
||||
mean=0.0,
|
||||
std=self.init_base_std,
|
||||
a=-3 * self.init_base_std,
|
||||
b=3 * self.init_base_std,
|
||||
std=emb_std,
|
||||
a=-3 * emb_std,
|
||||
b=3 * emb_std,
|
||||
)
|
||||
|
|
|
@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
|
|||
# B S D
|
||||
bsz, seq_len, _ = x.shape
|
||||
_, slen_kv, _ = kv.shape
|
||||
x = self.cross_attn_norm_q(x)
|
||||
x_norm = self.cross_attn_norm_q(x)
|
||||
kv = self.cross_attn_norm_kv(kv)
|
||||
|
||||
xq = self.wq(x)
|
||||
xq = self.wq(x_norm)
|
||||
xk = self.wk(kv)
|
||||
xv = self.wv(kv)
|
||||
|
||||
|
@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
|
|||
return x + output
|
||||
|
||||
def init_weights(self, base_std: float, factor: float = 1.0):
|
||||
std = base_std * factor
|
||||
std = base_std or (self.dim ** (-0.5)) / factor
|
||||
|
||||
nn.init.trunc_normal_(
|
||||
self.wq.weight,
|
||||
|
@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
|
|||
b=3 * std,
|
||||
)
|
||||
|
||||
output_std = std / (2**0.5)
|
||||
nn.init.trunc_normal_(
|
||||
self.wo.weight,
|
||||
mean=0.0,
|
||||
std=output_std,
|
||||
a=-3 * output_std,
|
||||
b=3 * output_std,
|
||||
std=std,
|
||||
a=-3 * std,
|
||||
b=3 * std,
|
||||
)
|
||||
self.cross_attn_norm_q.reset_parameters()
|
||||
self.cross_attn_norm_kv.reset_parameters()
|
||||
|
@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
|
|||
super().__init__(args)
|
||||
self.dropout = args.dropout
|
||||
self.eos_id = args.eos_id
|
||||
self.dim_token_emb = args.dim_token_emb
|
||||
|
||||
self.token_embedding_projection = None
|
||||
if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
|
||||
|
@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
|
|||
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
|
||||
return h, cache
|
||||
|
||||
def init_weights(self, init_base_std: float):
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
std = self.dim_token_emb ** (-0.5)
|
||||
if self.token_embedding_projection is not None:
|
||||
nn.init.trunc_normal_(
|
||||
self.token_embedding_projection.weight,
|
||||
mean=0.0,
|
||||
std=init_base_std,
|
||||
a=-3 * init_base_std,
|
||||
b=3 * init_base_std,
|
||||
std=std,
|
||||
a=-3 * std,
|
||||
b=3 * std,
|
||||
)
|
||||
|
|
|
@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
|
|||
# Local encoder specific dimensions
|
||||
dropout: float
|
||||
vocab_size: int
|
||||
patch_size: int
|
||||
patch_size: float
|
||||
sliding_window: int | None
|
||||
use_rope: bool
|
||||
cross_attn_encoder: bool | None
|
||||
|
@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
|
|||
self.dropout = args.dropout
|
||||
self.vocab_size = args.vocab_size
|
||||
self.patch_size = args.patch_size
|
||||
self.dim_patch_emb = args.dim_patch_emb
|
||||
|
||||
self.attn_impl = args.attn_impl
|
||||
self.sliding_window = args.sliding_window
|
||||
|
@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):
|
|||
|
||||
def init_weights(self, init_std=None):
|
||||
self.rope.reset_parameters()
|
||||
self.norm.reset_parameters()
|
||||
|
||||
init_std = init_std or (self.dim ** (-0.5))
|
||||
nn.init.trunc_normal_(
|
||||
|
@ -156,7 +158,16 @@ class LocalModelBase(nn.Module):
|
|||
InitStdFactor.DISABLED: 1.0,
|
||||
}[self.init_std_factor]
|
||||
|
||||
layer.init_weights(init_std, factor)
|
||||
layer.init_weights(None, factor)
|
||||
|
||||
if hasattr(self, "output"):
|
||||
nn.init.trunc_normal_(
|
||||
self.output.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
|
||||
if self.token_embedding_projection is not None:
|
||||
nn.init.trunc_normal_(
|
||||
|
@ -168,21 +179,13 @@ class LocalModelBase(nn.Module):
|
|||
)
|
||||
|
||||
if self.patch_embedding_projection is not None:
|
||||
patch_emb_std = self.dim_patch_emb ** (-0.5)
|
||||
nn.init.trunc_normal_(
|
||||
self.patch_embedding_projection.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
)
|
||||
|
||||
if hasattr(self, "output"):
|
||||
nn.init.trunc_normal_(
|
||||
self.output.weight,
|
||||
mean=0.0,
|
||||
std=init_std,
|
||||
a=-3 * init_std,
|
||||
b=3 * init_std,
|
||||
std=patch_emb_std,
|
||||
a=-3 * patch_emb_std,
|
||||
b=3 * patch_emb_std,
|
||||
)
|
||||
|
||||
if self.cross_attn_layers is not None:
|
||||
|
@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
|
|||
InitStdFactor.DISABLED: 1.0,
|
||||
}[self.init_std_factor]
|
||||
|
||||
layer.init_weights(init_std, factor)
|
||||
layer.init_weights(None, factor)
|
||||
|
||||
|
||||
class LocalEncoder(LocalModelBase):
|
||||
|
|
|
@ -137,14 +137,25 @@ def get_no_recompute_ops():
|
|||
def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
|
||||
group_plan: Tuple[int, bool] = []
|
||||
|
||||
# Grouping and output seperately
|
||||
group_plan.append(("tok_embeddings", False))
|
||||
if isinstance(model_args, LMTransformerArgs):
|
||||
group_plan.append(("tok_embeddings", False))
|
||||
|
||||
# Grouping by layers
|
||||
for i in range(model_args.n_layers):
|
||||
group_plan.append((f"layers.{i}", False))
|
||||
for i in range(model_args.n_layers):
|
||||
group_plan.append((f"layers.{i}", False))
|
||||
|
||||
group_plan.append(("output", True))
|
||||
group_plan.append(("output", True))
|
||||
else:
|
||||
for i in range(model_args.n_layers_local_encoder):
|
||||
group_plan.append((f"local_encoder.layers.{i}", True))
|
||||
group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
|
||||
for i in range(model_args.n_layers_local_decoder):
|
||||
group_plan.append((f"local_decoder.layers.{i}", True))
|
||||
group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
|
||||
for i in range(model_args.n_layers_global):
|
||||
group_plan.append((f"global_transformer.layers.{i}", True))
|
||||
|
||||
for i in range(len(model_args.encoder_hash_byte_group_size)):
|
||||
group_plan.append((f"encoder_hash_tok_embedding.{i}", True))
|
||||
|
||||
return group_plan
|
||||
|
||||
|
|
Loading…
Reference in a new issue