diff --git a/.gitignore b/.gitignore
index 6c664b8..d1d7c2a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/
diff --git a/bytelatent/args.py b/bytelatent/args.py
index a332c89..56de22d 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
 
+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
 
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
             looping_iterator,
             patcher_args=self.patcher_args,
             tokenizer_args=self.tokenizer_args,
+            add_patches=self.add_patches,
         )
         sequence_iterator = SequenceIterator(
             preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model
diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml
index 4ae4459..1098ff5 100644
--- a/bytelatent/configs/debug.yaml
+++ b/bytelatent/configs/debug.yaml
@@ -26,10 +26,9 @@ model:
   vocab_size: 260
   dim_token: 256
   patch_size: 6
-  tokenization_mode: "bytes"
   patching_mode: "space"
   tie_local_encoder_decoder_logits: false
-  data_loader_patching: true
+  patch_in_forward: false
   max_encoder_seq_length: 12288
   pad_to_max_length: true
   patching_threshold: 3.1439168453216553
diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml
new file mode 100644
index 0000000..51b65d4
--- /dev/null
+++ b/bytelatent/configs/entropy_model.yaml
@@ -0,0 +1,82 @@
+# Template config, need to change dump_dir, data.root_dir and tokenizer.path
+# Evals can be activated by uncommenting its config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+train_entropy_model: true
+model: null
+entropy_model:
+  dim: 768
+  n_layers: 14
+  n_heads: 12
+  max_seqlen: 8192
+  # vocab_size: -1
+  vocab_size: 260
+  ffn_dim_multiplier: 1.0
+  sliding_window: 512
+  attn_bias_type: "local_block_causal"
+  attn_impl: "xformers"
+
+data:
+  s3_profile: blt
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  # seqlen is in terms of patches and
+  # max_encoder_seq_length is in terms of bytes.
+  # For entropy model, these are the same since 1 patch=1 byte
+  seq_len: 8192
+  max_encoder_seq_length: 8192
+  load_async: true
+  preprocess_dir: ???
+  # We don't need patches for this model
+  add_patches: false
+  patcher_args:
+    # This doesn't matter since byte entropy model doesn't use patching,
+    # so pick the most efficient, so static
+    patching_mode: byte
+  tokenizer_args:
+    name: bytes
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: ???
+  tasks: ???
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+  mp_size: 1
diff --git a/bytelatent/data/data_types.py b/bytelatent/data/data_types.py
index 7e142e4..aa2daa9 100644
--- a/bytelatent/data/data_types.py
+++ b/bytelatent/data/data_types.py
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None
 
 
 @dataclass
diff --git a/bytelatent/data/iterators/packing_iterator.py b/bytelatent/data/iterators/packing_iterator.py
index 361fc03..fa29149 100644
--- a/bytelatent/data/iterators/packing_iterator.py
+++ b/bytelatent/data/iterators/packing_iterator.py
@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str
 
 
 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:
diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py
index 14e3747..d90ea31 100644
--- a/bytelatent/data/iterators/sequence_iterator.py
+++ b/bytelatent/data/iterators/sequence_iterator.py
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
-            assert example.patch_lengths is not None
+            if self.preprocess_iterator.add_patches:
+                assert example.patch_lengths is not None
+                assert len(example.tokens) == sum(example.patch_lengths)
+            else:
+                assert example.patch_lengths is None
             assert len(example.tokens) != 0
             assert len(example.mask) != 0
             assert len(example.tokens) == len(example.mask)
-            assert len(example.tokens) == sum(example.patch_lengths)
             tokens.extend(example.tokens)
             mask.extend(example.mask)
-            patch_lengths.extend(example.patch_lengths)
+            if self.preprocess_iterator.add_patches:
+                patch_lengths.extend(example.patch_lengths)
+            else:
+                # This lets the rest of the code work as expected and just yield byte seqs
+                patch_lengths.extend([1] * len(example.tokens))
 
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                     == len(seq_mask[idx])
                 ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                 assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
-                yield BltSequence(
-                    tokens=seq_tokens[idx],
-                    mask=seq_mask[idx],
-                    patch_lengths=seq_patch_lengths[idx],
-                )
+                if self.preprocess_iterator.add_patches:
+                    yield BltSequence(
+                        tokens=seq_tokens[idx],
+                        mask=seq_mask[idx],
+                        patch_lengths=seq_patch_lengths[idx],
+                    )
+                else:
+                    yield BltSequence(
+                        tokens=seq_tokens[idx],
+                        mask=seq_mask[idx],
+                        patch_lengths=None,
+                    )
diff --git a/bytelatent/data/patcher.py b/bytelatent/data/patcher.py
index afcfa2e..44ff5e9 100644
--- a/bytelatent/data/patcher.py
+++ b/bytelatent/data/patcher.py
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"
 
 
 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
     max_patch_length: int | None = None
     patch_size: float = 4.5
     patching_batch_size: int = 1
-    data_loader_patching: bool = False
     device: str = "cuda"
     monotonicity: bool = False
     log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
         self.max_patch_length = patcher_args.max_patch_length
         self.patch_size = patcher_args.patch_size
         self.patching_batch_size = patcher_args.patching_batch_size
-        self.data_loader_patching = patcher_args.data_loader_patching
         self.device = patcher_args.device
         self.monotonicity = patcher_args.monotonicity
         self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode is None:
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:
diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py
index 843ad34..a62be23 100644
--- a/bytelatent/model/blt.py
+++ b/bytelatent/model/blt.py
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
+    patch_in_forward: bool = False
 
     # Architecture and dimensions
     dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_layers_local_encoder: int = 8
 
     # Tokenization and patching
-    tokenization_mode: str = "bpe"
     patch_size: float | None = None
     patching_mode: str | None = None
     patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     monotonicity: bool = False
     patching_batch_size: int = 1
     patching_device: str = "cuda"
-    data_loader_patching: bool = False
     max_patch_length: int | None = None
 
     # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
-        if not args.data_loader_patching:
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,
diff --git a/bytelatent/test_blt.py b/bytelatent/test_blt.py
index 36a9882..eb94df3 100644
--- a/bytelatent/test_blt.py
+++ b/bytelatent/test_blt.py
@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
         # Additional args from command line
         dim_token=256,
         patch_size=6,
-        tokenization_mode="bytes",
         patching_mode="space",
         tie_local_encoder_decoder_logits=False,
-        data_loader_patching=True,
+        patch_in_forward=False,
         max_encoder_seq_length=12288,
         pad_to_max_length=True,
         encoder_lm_loss=False,
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 80bd393..1d0fa40 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
 
 
 def validate_train_args(args: TrainArgs, output_size: int):
-    if args.model.vocab_size < 0:
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
 
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size
+
     assert args.dump_dir, "Dump dir not set"
 
     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
         and args.distributed.dp_replicate == get_world_size()
     )
 
-    args.model.max_seqlen = args.data.seq_len
+    if args.model is not None:
+        args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len
 
     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
 
     # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
     with torch.device("meta"):
-        model = ByteLatentTransformer(args.model)
+        if args.train_entropy_model:
+            assert args.entropy_model is not None
+            model = LMTransformer(args.entropy_model)
+            model_args = args.entropy_model
+        else:
+            assert args.model is not None
+            model = ByteLatentTransformer(args.model)
+            model_args = args.model
     logger.info("Model is built !")
 
     model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
         world_mesh,
         args.model,
         args.distributed,
-        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+        fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
    )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
         with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-            torch.manual_seed(args.model.seed)
+            torch.manual_seed(model_args.seed)
             model.init_weights()
 
     check_model_value_range(model, range=10.0, std=1.0)
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
-            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
+                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,9 +431,12 @@ def train(args: TrainArgs):
                     next(probe_mod.parameters()).grad is None
                 ), "Probe model shouldn't have grads at this point"
 
-            pred = model(
-                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
-            )
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
+                pred = model(
+                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+                )
 
             loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
 
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
                 # Use xformer's analyze profile trace to get actual measurement
                 FLOPS = (
                     get_num_flop_per_token(
-                        model_param_count - args.model.vocab_size * args.model.dim,
-                        args.model.n_layers,
-                        args.model.dim,
+                        model_param_count - model_args.vocab_size * model_args.dim,
+                        model_args.n_layers,
+                        model_args.dim,
                         args.data.seq_len,
                     )
                     * wps