diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py
index d6eab14..d92a1fb 100644
--- a/bytelatent/model/local_models.py
+++ b/bytelatent/model/local_models.py
@@ -74,12 +74,10 @@ class LocalModelBase(nn.Module):
 
         self.boe_id = BOE_ID
 
-        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.layers = nn.ModuleList(
             [TransformerBlock(args) for _ in range(args.n_layers)]
         )
 
-        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
         if not self.use_rope:
             self.pos_embeddings = nn.Embedding(args.max_length, args.dim)
         else:
@@ -131,16 +129,18 @@ class LocalModelBase(nn.Module):
 
     def init_weights(self, init_std=None):
         self.rope.reset_parameters()
-        self.norm.reset_parameters()
+        if hasattr(self, "norm"):
+            self.norm.reset_parameters()
 
         init_std = init_std or (self.dim ** (-0.5))
-        nn.init.trunc_normal_(
-            self.tok_embeddings.weight,
-            mean=0.0,
-            std=init_std,
-            a=-3 * init_std,
-            b=3 * init_std,
-        )
+        if hasattr(self, "tok_embeddings"):
+            nn.init.trunc_normal_(
+                self.tok_embeddings.weight,
+                mean=0.0,
+                std=init_std,
+                a=-3 * init_std,
+                b=3 * init_std,
+            )
         if self.pos_embeddings is not None:
             nn.init.trunc_normal_(
                 self.pos_embeddings.weight,
@@ -212,6 +212,8 @@ class LocalEncoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
+
         if self.cross_attn_encoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_encoder else 1
@@ -314,6 +316,8 @@ class LocalDecoder(LocalModelBase):
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
         self.cross_attn_nheads = args.cross_attn_nheads
 
+        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
+
         if self.cross_attn_decoder:
             self.cross_attn_layers = torch.nn.ModuleList()
             layers_to_add = args.n_layers if self.cross_attn_all_layers_decoder else 1
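
Note for reviewers: the change moves tok_embeddings into LocalEncoder and norm into LocalDecoder, so the shared init_weights() in LocalModelBase now guards both with hasattr() and only initializes what the concrete subclass actually constructed. A minimal sketch of the pattern, standalone and not taken from the PR (the Base/Encoder/Decoder names are hypothetical, and nn.LayerNorm stands in for the repo's RMSNorm):

import torch.nn as nn


class Base(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def init_weights(self, init_std=None):
        init_std = init_std or (self.dim ** (-0.5))
        # Probe with hasattr() so the shared initializer stays valid for
        # subclasses that do not own a given submodule.
        if hasattr(self, "norm"):
            self.norm.reset_parameters()
        if hasattr(self, "tok_embeddings"):
            nn.init.trunc_normal_(
                self.tok_embeddings.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )


class Encoder(Base):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__(dim)
        # Only the encoder embeds raw tokens, so it owns tok_embeddings.
        self.tok_embeddings = nn.Embedding(vocab_size, dim)


class Decoder(Base):
    def __init__(self, dim: int):
        super().__init__(dim)
        # Only the decoder applies a final norm before its output head.
        self.norm = nn.LayerNorm(dim)

Keeping the initializer in the base class avoids duplicating the trunc_normal_ bounds in each subclass, at the cost of the hasattr() probes; the alternative would be for each subclass to override init_weights() and call super().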