From 38cc67a9535f2a19ae3dee526114ee77d291d76f Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Tue, 11 Feb 2025 22:56:25 +0000
Subject: [PATCH] Fix multiprocessing dataloader checkpointing and use it in
 the train script

Summary:

Test Plan:
---
 .../data/iterators/abstract_iterator.py      | 10 +++++++
 bytelatent/data/iterators/arrow_iterator.py  |  4 +--
 .../data/iterators/multiprocess_iterator.py  | 27 +++++++++++++------
 bytelatent/train.py                          | 10 +++++++
 4 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py
index 7fb442b..8ac7f19 100644
--- a/bytelatent/data/iterators/abstract_iterator.py
+++ b/bytelatent/data/iterators/abstract_iterator.py
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the multiprocess iterator shuts down the worker process to persist state
+    # correctly, so the iterator needs to be rebuilt and restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 1c68d3a..7e43360 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator):
 
     def _set_row_num(self, target_row_num: int):
         logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
@@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator):
                     curr_remaining -= len(batch)
             self.row_num = target_row_num
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
index 49d99ac..33bde94 100644
--- a/bytelatent/data/iterators/multiprocess_iterator.py
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break
         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
             )
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received
 
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
diff --git a/bytelatent/train.py b/bytelatent/train.py
index ed84233..8b667f1 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -699,6 +700,9 @@ def train(args: TrainArgs):
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+            train_state.data_loader_state, data_loader, batch_iterator = (
+                get_state_and_refresh(data_loader)
+            )
             saved = checkpoint.save(
                 model,
                 optimizer,
@@ -740,6 +744,9 @@ def train(args: TrainArgs):
 
     if preemption_flag["flag"]:
         if not saved:
+            train_state.data_loader_state, data_loader, batch_iterator = (
+                get_state_and_refresh(data_loader)
+            )
             checkpoint.save(
                 model,
                 optimizer,
@@ -751,6 +758,9 @@ def train(args: TrainArgs):
             sys.exit(0)
 
     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,
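
Note: the sketch below is not part of the patch; it only illustrates the save/restore
round trip that get_state_and_refresh enables, assuming a StatefulIterator named
data_loader has already been built and that checkpoint writing happens elsewhere.

    from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh

    # Capture a serializable state. For MultiprocessIterator, get_state() shuts down
    # the worker process, so the old loader/iterator must not be reused afterwards;
    # get_state_and_refresh returns freshly rebuilt ones to continue training with.
    state, data_loader, batch_iterator = get_state_and_refresh(data_loader)

    # Persist `state` alongside the model/optimizer checkpoint (elided here), then
    # keep iterating on the rebuilt iterator.
    batch = next(batch_iterator)

    # On resume, the same state object rebuilds an equivalent iterator at the
    # position that was saved.
    restored_loader = state.build()
    restored_iter = restored_loader.create_iter()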