Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-23 05:22:16 +00:00)

Commit: 30f82211c4 ("Fix init and repro")
Parent: 936d9437be
@@ -445,7 +445,7 @@ class Attention(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        init_std = init_std or (self.dim ** (-0.5))
+        init_std = init_std or (self.dim ** (-0.5)) / factor

         for w in [self.wq, self.wk, self.wv]:
             nn.init.trunc_normal_(
@@ -459,7 +459,7 @@ class Attention(nn.Module):
         nn.init.trunc_normal_(
             self.wo.weight,
             mean=0.0,
-            std=init_std / factor,
+            std=init_std,
             a=-3 * init_std,
             b=3 * init_std,
         )
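Note: after this change the depth factor is folded into the base std once (init_std = dim ** -0.5 / factor) and the output projection reuses that same std, instead of only wo being divided by the factor. A minimal standalone sketch of the resulting pattern; ToyAttention and its sizes are made up for illustration, and it condenses the repo's separate wq/wk/wv loop and wo call into one loop:

import torch.nn as nn


class ToyAttention(nn.Module):
    # Hypothetical stand-in, not the repo's Attention class.
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def reset_parameters(self, init_std=None, factor=1.0):
        # Base std scales with 1/sqrt(dim) and is divided by the depth factor once.
        init_std = init_std or (self.dim ** (-0.5)) / factor
        for w in [self.wq, self.wk, self.wv, self.wo]:
            nn.init.trunc_normal_(
                w.weight, mean=0.0, std=init_std, a=-3 * init_std, b=3 * init_std
            )


ToyAttention(dim=512).reset_parameters(factor=2.0)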
@@ -509,18 +509,16 @@ class FeedForward(nn.Module):
         return output

     def reset_parameters(self, init_std=None, factor=1.0):
-        in_init_std = init_std or (self.dim ** (-0.5))
-        out_init_std = init_std or (self.hidden_dim ** (-0.5))
-        in_init_std = in_init_std
-        out_init_std = out_init_std / factor
-        for w in [self.w1, self.w3]:
-            nn.init.trunc_normal_(
-                w.weight,
-                mean=0.0,
-                std=in_init_std,
-                a=-3 * in_init_std,
-                b=3 * in_init_std,
-            )
+        in_init_std = init_std or (self.dim ** (-0.5)) / factor
+        out_init_std = init_std or (self.hidden_dim ** (-0.5)) / factor
+
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )
         nn.init.trunc_normal_(
             self.w2.weight,
             mean=0.0,
@@ -528,6 +526,13 @@ class FeedForward(nn.Module):
             a=-3 * out_init_std,
             b=3 * out_init_std,
         )
+        nn.init.trunc_normal_(
+            self.w3.weight,
+            mean=0.0,
+            std=in_init_std,
+            a=-3 * in_init_std,
+            b=3 * in_init_std,
+        )


 class TransformerBlock(nn.Module):
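The FeedForward change follows the same scheme: the gate/up projections (w1, w3) get a std tied to the model width and the down projection (w2) a std tied to the hidden width, each divided by the depth factor, and w3 now receives its own explicit init. A condensed sketch under the same caveats (made-up ToySwiGLU class, one loop instead of three separate calls):

import torch.nn as nn


class ToySwiGLU(nn.Module):
    # Hypothetical stand-in, not the repo's FeedForward class.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.dim, self.hidden_dim = dim, hidden_dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def reset_parameters(self, init_std=None, factor=1.0):
        in_std = init_std or (self.dim ** (-0.5)) / factor
        out_std = init_std or (self.hidden_dim ** (-0.5)) / factor
        for w, std in [(self.w1, in_std), (self.w3, in_std), (self.w2, out_std)]:
            nn.init.trunc_normal_(
                w.weight, mean=0.0, std=std, a=-3 * std, b=3 * std
            )


ToySwiGLU(dim=512, hidden_dim=1376).reset_parameters(factor=2.0)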
@@ -463,13 +463,15 @@ def parallelize_model(
         raise ValueError(f"Invalid fsdp_type: {distributed_args.fsdp_type}")

     if distributed_args.selective_activation_checkpointing:
-        model = checkpoint_wrapper(
-            model,
-            context_fn=partial(
-                create_selective_checkpoint_contexts,
-                get_default_policy(no_recompute_ops),
-            ),
-        )
+        for module in [model.global_transformer, model.local_encoder, model.local_decoder]:
+            for i in range(len(module.layers)):
+                module.layers[i] = checkpoint_wrapper(
+                    module.layers[i],
+                    context_fn=partial(
+                        create_selective_checkpoint_contexts,
+                        get_default_policy(no_recompute_ops),
+                    ),
+                )

     if distributed_args.compile:
         torch._dynamo.config.cache_size_limit = (
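Selective activation checkpointing is now applied per transformer layer inside each of the three sub-models rather than wrapping the whole model once. A rough sketch of that wrapping pattern; it drops the selective context_fn that the repo passes and assumes a model exposing global_transformer / local_encoder / local_decoder, each with a .layers list:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)


def wrap_layers(model: nn.Module) -> nn.Module:
    # Replace every layer of every sub-module with a checkpointed version of itself.
    for module in [model.global_transformer, model.local_encoder, model.local_decoder]:
        for i in range(len(module.layers)):
            module.layers[i] = checkpoint_wrapper(module.layers[i])
    return model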
@@ -825,12 +825,6 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_dim=self.local_encoder.dim,
             encoder_hash_byte_group_size=None,
         )
-        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
-
-        # Transformer layers
-        self.layers = nn.ModuleList(
-            [TransformerBlock(args) for _ in range(args.n_layers)]
-        )

         # Encoder ngram embedding tables
         self.encoder_ngram_embedding = None
@@ -848,9 +842,6 @@ class ByteLatentTransformer(nn.Module):

         # Output layer
         assert args.vocab_size > 0, "vocab_size must be greater than 0"
-        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
-        if args.weight_tying:
-            self.output.weight = self.tok_embeddings.weight

         # Patcher module
         if args.patch_in_forward:
@@ -954,11 +945,10 @@ class ByteLatentTransformer(nn.Module):
             local_encoder_embeds = local_encoder_embeds + ngram_embeds

         # Local encoder
-        h_cross = None
         (h_encoder, h_cross), cache_encoder = self.local_encoder(
             tokens=local_encoder_tokens,
             embeds=local_encoder_embeds,
-            patch_embeds=h_cross if self.cross_attn_encoder else None,
+            patch_embeds=None,
             cross_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
@@ -1033,47 +1023,17 @@ class ByteLatentTransformer(nn.Module):
         )
         return output

-    def reset_parameters(self, init_std=None):
-        # Either use fixed base std or sqrt model dim
-        init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
-        if not self.weight_tying:
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
     def init_weights(self):
-        self.reset_parameters()
-        self.init_base_std = self.init_base_std or (self.dim ** (-0.5))
-        for depth, layer in enumerate(self.layers):
-            factor = {
-                InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
-                InitStdFactor.GLOBAL_DEPTH: (2 * (len(self.layers) + 1)) ** 0.5,
-                InitStdFactor.DIM_RATIO: self.dim / 4096,
-                InitStdFactor.DISABLED: 1.0,
-            }[self.init_std_factor]
-
-            layer.init_weights(self.init_base_std, factor)
-
-        self.local_decoder.init_weights(self.init_base_std)
-        self.global_transformer.init_weights(self.init_base_std)
-        self.local_encoder.init_weights(self.init_base_std)
+        self.local_encoder.init_weights()
+        self.global_transformer.init_weights()
+        self.local_decoder.init_weights()

+        emb_std = self.local_encoder.dim ** (-0.5)
         for emb in self.encoder_hash_tok_embedding:
             nn.init.trunc_normal_(
                 emb.weight,
                 mean=0.0,
-                std=self.init_base_std,
-                a=-3 * self.init_base_std,
-                b=3 * self.init_base_std,
+                std=emb_std,
+                a=-3 * emb_std,
+                b=3 * emb_std,
             )
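init_weights on the top-level model now just delegates to the three sub-models and sizes the hash-token embedding init from the local encoder width instead of a global base std. A small sketch of the embedding part only, with made-up table sizes:

import torch.nn as nn


def init_hash_embeddings(tables: nn.ModuleList, encoder_dim: int) -> None:
    # std tied to the local encoder width, truncated at +/- 3 std.
    emb_std = encoder_dim ** (-0.5)
    for emb in tables:
        nn.init.trunc_normal_(
            emb.weight, mean=0.0, std=emb_std, a=-3 * emb_std, b=3 * emb_std
        )


# Hypothetical sizes purely for illustration.
tables = nn.ModuleList([nn.Embedding(1024, 512) for _ in range(4)])
init_hash_embeddings(tables, encoder_dim=512)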
@@ -78,10 +78,10 @@ class CrossAttention(nn.Module):
         # B S D
         bsz, seq_len, _ = x.shape
         _, slen_kv, _ = kv.shape
-        x = self.cross_attn_norm_q(x)
+        x_norm = self.cross_attn_norm_q(x)
         kv = self.cross_attn_norm_kv(kv)

-        xq = self.wq(x)
+        xq = self.wq(x_norm)
         xk = self.wk(kv)
         xv = self.wv(kv)

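This forward fix keeps the residual on the raw input while computing the queries from the normalised activations (previously the normalised tensor overwrote x, so the residual added the normalised values). A minimal pre-norm cross-attention sketch of that pattern; shapes, head count, and LayerNorm (standing in for the repo's norm) are arbitrary:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, n_heads, head_dim = 64, 4, 16
norm_q, norm_kv = nn.LayerNorm(dim), nn.LayerNorm(dim)
wq = nn.Linear(dim, dim, bias=False)
wk = nn.Linear(dim, dim, bias=False)
wv = nn.Linear(dim, dim, bias=False)
wo = nn.Linear(dim, dim, bias=False)

x = torch.randn(2, 16, dim)   # queries        (B, S, D)
kv = torch.randn(2, 8, dim)   # keys / values  (B, S_kv, D)

x_norm = norm_q(x)            # only the normalised copy feeds the projections
kv_norm = norm_kv(kv)

B, S, _ = x.shape
q = wq(x_norm).view(B, S, n_heads, head_dim).transpose(1, 2)
k = wk(kv_norm).view(B, -1, n_heads, head_dim).transpose(1, 2)
v = wv(kv_norm).view(B, -1, n_heads, head_dim).transpose(1, 2)

attn = F.scaled_dot_product_attention(q, k, v)
out = wo(attn.transpose(1, 2).reshape(B, S, dim))
y = x + out                   # residual uses the raw, un-normalised x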
@@ -104,7 +104,7 @@ class CrossAttention(nn.Module):
         return x + output

     def init_weights(self, base_std: float, factor: float = 1.0):
-        std = base_std * factor
+        std = base_std or (self.dim ** (-0.5)) / factor

         nn.init.trunc_normal_(
             self.wq.weight,
|
@ -130,13 +130,12 @@ class CrossAttention(nn.Module):
|
||||||
b=3 * std,
|
b=3 * std,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_std = std / (2**0.5)
|
|
||||||
nn.init.trunc_normal_(
|
nn.init.trunc_normal_(
|
||||||
self.wo.weight,
|
self.wo.weight,
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
std=output_std,
|
std=std,
|
||||||
a=-3 * output_std,
|
a=-3 * std,
|
||||||
b=3 * output_std,
|
b=3 * std,
|
||||||
)
|
)
|
||||||
self.cross_attn_norm_q.reset_parameters()
|
self.cross_attn_norm_q.reset_parameters()
|
||||||
self.cross_attn_norm_kv.reset_parameters()
|
self.cross_attn_norm_kv.reset_parameters()
|
||||||
|
@@ -147,6 +146,7 @@ class GlobalTransformer(BaseTransformer):
         super().__init__(args)
         self.dropout = args.dropout
         self.eos_id = args.eos_id
+        self.dim_token_emb = args.dim_token_emb

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -192,13 +192,14 @@ class GlobalTransformer(BaseTransformer):
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache

-    def init_weights(self, init_base_std: float):
+    def init_weights(self):
         super().init_weights()
+        std = self.dim_token_emb ** (-0.5)
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
                 self.token_embedding_projection.weight,
                 mean=0.0,
-                std=init_base_std,
-                a=-3 * init_base_std,
-                b=3 * init_base_std,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
             )
@@ -34,7 +34,7 @@ class LocalModelArgs(BaseTransformerArgs):
     # Local encoder specific dimensions
     dropout: float
     vocab_size: int
-    patch_size: int
+    patch_size: float
     sliding_window: int | None
     use_rope: bool
     cross_attn_encoder: bool | None
@@ -61,6 +61,7 @@ class LocalModelBase(nn.Module):
         self.dropout = args.dropout
         self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size
+        self.dim_patch_emb = args.dim_patch_emb

         self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
@@ -130,6 +131,7 @@ class LocalModelBase(nn.Module):

     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
+        self.norm.reset_parameters()

         init_std = init_std or (self.dim ** (-0.5))
         nn.init.trunc_normal_(
@@ -156,7 +158,16 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)

+        if hasattr(self, "output"):
+            nn.init.trunc_normal_(
+                self.output.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
+
         if self.token_embedding_projection is not None:
             nn.init.trunc_normal_(
@@ -168,21 +179,13 @@ class LocalModelBase(nn.Module):
             )

         if self.patch_embedding_projection is not None:
+            patch_emb_std = self.dim_patch_emb ** (-0.5)
             nn.init.trunc_normal_(
                 self.patch_embedding_projection.weight,
                 mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
-            )
-
-        if hasattr(self, "output"):
-            nn.init.trunc_normal_(
-                self.output.weight,
-                mean=0.0,
-                std=init_std,
-                a=-3 * init_std,
-                b=3 * init_std,
+                std=patch_emb_std,
+                a=-3 * patch_emb_std,
+                b=3 * patch_emb_std,
             )

         if self.cross_attn_layers is not None:
@@ -194,7 +197,7 @@ class LocalModelBase(nn.Module):
                 InitStdFactor.DISABLED: 1.0,
             }[self.init_std_factor]

-            layer.init_weights(init_std, factor)
+            layer.init_weights(None, factor)


 class LocalEncoder(LocalModelBase):
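The factor passed to layer.init_weights(None, factor) comes from the InitStdFactor mapping visible above, so each layer's effective std becomes (dim ** -0.5) / factor. A standalone re-creation of that lookup for illustration; the enum itself lives in the repo, and the auto() member values here are placeholders:

from enum import Enum, auto


class InitStdFactor(Enum):
    # Member names mirror the diff; the values are invented.
    CURRENT_DEPTH = auto()
    GLOBAL_DEPTH = auto()
    DIM_RATIO = auto()
    DISABLED = auto()


def depth_factor(kind: InitStdFactor, depth: int, n_layers: int, dim: int) -> float:
    return {
        InitStdFactor.CURRENT_DEPTH: (2 * (depth + 1)) ** 0.5,
        InitStdFactor.GLOBAL_DEPTH: (2 * (n_layers + 1)) ** 0.5,
        InitStdFactor.DIM_RATIO: dim / 4096,
        InitStdFactor.DISABLED: 1.0,
    }[kind]


# e.g. layer index 5 of a 12-layer, 1024-dim model with CURRENT_DEPTH scaling:
# factor = sqrt(2 * 6) ~ 3.46, so that layer's init std is (1024 ** -0.5) / 3.46.
print(depth_factor(InitStdFactor.CURRENT_DEPTH, depth=5, n_layers=12, dim=1024))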
@@ -137,14 +137,25 @@ def get_no_recompute_ops():
 def build_fsdp_grouping_plan(model_args: LMTransformerArgs):
     group_plan: Tuple[int, bool] = []

-    # Grouping and output seperately
-    group_plan.append(("tok_embeddings", False))
+    if isinstance(model_args, LMTransformerArgs):
+        group_plan.append(("tok_embeddings", False))

-    # Grouping by layers
-    for i in range(model_args.n_layers):
-        group_plan.append((f"layers.{i}", False))
+        for i in range(model_args.n_layers):
+            group_plan.append((f"layers.{i}", False))

-    group_plan.append(("output", True))
+        group_plan.append(("output", True))
+    else:
+        for i in range(model_args.n_layers_local_encoder):
+            group_plan.append((f"local_encoder.layers.{i}", True))
+            group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_local_decoder):
+            group_plan.append((f"local_decoder.layers.{i}", True))
+            group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
+        for i in range(model_args.n_layers_global):
+            group_plan.append((f"global_transformer.layers.{i}", True))
+
+        for i in range(len(model_args.encoder_hash_byte_group_size)):
+            group_plan.append((f"encoder_hash_tok_embedding.{i}", True))

     return group_plan

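For BLT-style args the new grouping plan enumerates every local encoder/decoder layer, their cross-attention layers, the global layers, and the hash embedding tables as separate FSDP groups. A rough illustration of the tuples it produces, with made-up counts (2 encoder layers, 2 decoder layers, 3 global layers, 4 hash tables):

group_plan = []
for i in range(2):
    group_plan.append((f"local_encoder.layers.{i}", True))
    group_plan.append((f"local_encoder.cross_attn_layers.{i}", True))
for i in range(2):
    group_plan.append((f"local_decoder.layers.{i}", True))
    group_plan.append((f"local_decoder.cross_attn_layers.{i}", True))
for i in range(3):
    group_plan.append((f"global_transformer.layers.{i}", True))
for i in range(4):
    group_plan.append((f"encoder_hash_tok_embedding.{i}", True))

print(group_plan[:3])
# [('local_encoder.layers.0', True), ('local_encoder.cross_attn_layers.0', True),
#  ('local_encoder.layers.1', True)]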