From 8c61ab5e67ab044cd04176e03d45ef6845b62302 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 13 Feb 2025 11:58:23 -0800
Subject: [PATCH] Fix multiprocessing dataloader checkpointing and use it in the train script (#50)

---
 bytelatent/args.py                            |  2 -
 .../data/iterators/abstract_iterator.py       | 10 ++++
 bytelatent/data/iterators/arrow_iterator.py   | 15 +++--
 .../data/iterators/multiprocess_iterator.py   | 27 ++++++---
 bytelatent/train.py                           | 56 ++++++++++++-------
 5 files changed, 77 insertions(+), 33 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index 263e8e3..47bd0f9 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json
 import logging
 import os
 from typing import Any
 
-import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
diff --git a/bytelatent/data/iterators/abstract_iterator.py b/bytelatent/data/iterators/abstract_iterator.py
index 7fb442b..8ac7f19 100644
--- a/bytelatent/data/iterators/abstract_iterator.py
+++ b/bytelatent/data/iterators/abstract_iterator.py
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the multiprocess iterator shuts down MP to correctly persist its state,
+    # so it needs to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 1c68d3a..995cd02 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number
 
 
+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
                 yield out
 
     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
                 else:
                     curr_remaining -= len(batch)
         self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
         )
diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
index 49d99ac..33bde94 100644
--- a/bytelatent/data/iterators/multiprocess_iterator.py
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break
 
     try:
-        logging.info("Worker thread: outputting state")
-        state_queue.put(iterator.get_state(), timeout=1)
-        logging.info("Worker thread: state dump complete")
+        logging.debug("Worker thread: outputting state")
+        state_queue.put(stateful_iterator.get_state(), timeout=1)
+        logging.debug("Worker thread: state dump complete")
         state_dumped_event.set()
-        logging.info("Worker thread: set state_dump_event")
+        logging.debug("Worker thread: set state_dump_event")
     except Full:
         raise ValueError(
             "Attempted to dump state into the state queue, but it was full"
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received
 
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
diff --git a/bytelatent/train.py b/bytelatent/train.py
index 0ee87df..3669167 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
                 # step: Metric at a step
                 # interval: Metric averaged/summed across all steps since the last log interval.
                 # Typically, this is 10
-                step_loss_per_gpu = loss.item()
-                step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
-                interval_loss_per_gpu = np.mean(step_losses).item()
-                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+                step_loss_per_gpu = loss
+                step_loss_across_gpus = dist_mean(step_loss_per_gpu)
+                interval_loss_per_gpu = np.mean(step_losses)
+                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)
 
                 stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
-                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
                 interval_total_tok_loss_across_gpus = dist_sum(
                     interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
-                ).item()
-                interval_total_n_bytes_per_gpu = n_bytes.item()
+                )
+                interval_total_n_bytes_per_gpu = n_bytes
                 interval_total_n_bytes_across_gpus = dist_sum(
                     n_bytes, reduce_dtype=torch.bfloat16
-                ).item()
+                )
 
                 interval_bpb_per_gpu = (
                     interval_total_tok_loss_per_gpu
@@ -645,18 +652,20 @@
                         },
                         "memory": gpu_mem_stats._asdict(),
                         "loss": {
-                            "step_per_gpu": step_loss_per_gpu,
-                            "step_across_gpu": step_loss_across_gpus,
-                            "interval_per_gpu": interval_loss_per_gpu,
-                            "interval_across_gpu": interval_loss_across_gpus,
+                            "step_per_gpu": to_py_num(step_loss_per_gpu),
+                            "step_across_gpu": to_py_num(step_loss_across_gpus),
+                            "interval_per_gpu": to_py_num(interval_loss_per_gpu),
+                            "interval_across_gpu": to_py_num(interval_loss_across_gpus),
                         },
                         "bpb": {
-                            "interval_per_gpu": interval_bpb_per_gpu,
-                            "interval_across_gpus": interval_bpb_across_gpus,
+                            "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
+                            "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
                         },
                         "n_bytes": {
-                            "interval_per_gpu": interval_total_n_bytes_per_gpu,
-                            "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                            "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
+                            "interval_across_gpus": to_py_num(
+                                interval_total_n_bytes_across_gpus
+                            ),
                         },
                     }
 
@@ -676,8 +685,8 @@
                 logger.info(
                     f"step: {train_state.step}"
                     f" acc: {train_state.acc_step}"
-                    f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
-                    f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                    f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
+                    f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
                     f" bpb_gpu: {interval_bpb_per_gpu:3f}"
                     f" bpb_avg: {interval_bpb_across_gpus:3f}"
                     f" grad: {grad_norm:.2e}"
@@ -702,6 +711,9 @@
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 saved = checkpoint.save(
                     model,
                     optimizer,
@@ -743,6 +755,9 @@
 
         if preemption_flag["flag"]:
             if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 checkpoint.save(
                     model,
                     optimizer,
@@ -754,6 +769,9 @@
             sys.exit(0)
 
     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,
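
A minimal sketch (illustration only, not applied by this patch) of the checkpoint-time refresh pattern the new helper enables. ToyState and ToyIterator are hypothetical stand-ins for MultiprocessIteratorState/MultiprocessIterator, which really do tear down their worker inside get_state(); only get_state_and_refresh mirrors the helper added to abstract_iterator.py.

from dataclasses import dataclass


@dataclass
class ToyState:
    position: int

    def build(self) -> "ToyIterator":
        return ToyIterator(self.position)


class ToyIterator:
    def __init__(self, position: int = 0):
        self.position = position
        self._stopped = False  # stands in for the MP worker being shut down

    def create_iter(self):
        while True:
            if self._stopped:
                raise RuntimeError("iterator was shut down by get_state()")
            self.position += 1
            yield self.position

    def get_state(self) -> ToyState:
        # Persisting state tears down the "worker", like the multiprocess iterator does.
        self._stopped = True
        return ToyState(self.position)


def get_state_and_refresh(iterator):
    # Same shape as the helper added in abstract_iterator.py above.
    state = iterator.get_state()
    data_loader = state.build()
    py_iterator = data_loader.create_iter()
    return state, data_loader, py_iterator


if __name__ == "__main__":
    data_loader = ToyIterator()
    batch_iterator = data_loader.create_iter()
    for _ in range(3):
        next(batch_iterator)

    # Checkpoint time: persist the state, then keep iterating with the refreshed pair.
    state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
    print("saved position:", state.position)    # -> 3
    print("next batch:", next(batch_iterator))  # -> 4

In the toy, resuming the stale batch_iterator would raise; the real multiprocess iterator likewise cannot continue after get_state(), which is why train.py swaps in the refreshed data_loader/batch_iterator before each checkpoint.save().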