diff --git a/bytelatent/args.py b/bytelatent/args.py
index 11c6548..4a67231 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
-from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    PersistType,
+)
 from bytelatent.data.iterators.packing_iterator import (
     PackingArgs,
     PackingIterator,
@@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
     add_bos: bool = True
     add_eos: bool = True
     load_async: bool = True
+    async_persist_type: PersistType = PersistType.EXACT
     prefetch_size: int = 64
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
@@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
             mp_iterator = MultiprocessIterator(
-                packing_iterator, n_batches_to_prefetch=self.prefetch_size
+                packing_iterator,
+                n_batches_to_prefetch=self.prefetch_size,
+                persist_type=self.async_persist_type,
             )
             return mp_iterator
         else:
diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
index b4df945..d0d1a84 100644
--- a/bytelatent/data/iterators/multiprocess_iterator.py
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -2,6 +2,7 @@ import json
 import logging
 import multiprocessing as mp
+from enum import Enum
 from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
@@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 logger = logging.getLogger()
 
 
+class PersistType(str, Enum):
+    EXACT = "exact"
+    APPROXIMATE = "approximate"
+
+
 class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
     serialized_prefetch_buffer: str
+    persist_type: PersistType
 
     def build(self):
         base_iterator = self.base_iterator_state.build()
@@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
             base_iterator,
             n_batches_to_prefetch=self.n_batches_to_prefetch,
             prefetch_buffer=prefetch_buffer,
+            persist_type=self.persist_type,
         )
 
 
 def start_work_from_state(
     batch_queue: mp.Queue,
     state_queue: mp.Queue,
+    approximate_state_queue: mp.Queue,
     stop_event: EventClass,
     state_dumped_event: EventClass,
+    trigger_approximate_send_state_event: EventClass,
+    sent_approximate_state_event: EventClass,
+    received_approximate_state_event: EventClass,
     state: IteratorState,
 ):
     logging.info("Worker thread: Starting base_iterator work")
@@ -49,6 +61,25 @@ def start_work_from_state(
     for item in iterator:
         while not stop_event.is_set():
             try:
+                if trigger_approximate_send_state_event.is_set():
+                    logger.info("WT: trigger_approximate_send ack")
+                    # Since this can be triggered again (but only after the state
+                    # is received by the main process), clean up as soon as possible.
+                    trigger_approximate_send_state_event.clear()
+                    logging.info("WT: Computing approximate state")
+                    approximate_state = stateful_iterator.get_state()
+                    # At this point, there should always be exactly one free slot.
+                    # Blocking here would be a bug.
+ logger.info("WT: Attempting to send approximate state") + approximate_state_queue.put( + approximate_state, block=True, timeout=None + ) + sent_approximate_state_event.set() + logger.info("WT: Approximate state sent") + # Same here, clear events as we no longer need them. + received_approximate_state_event.wait() + received_approximate_state_event.clear() + logger.info("WT: State received by MT, resuming batch iteration") # Attempt to put on queue or timeout to try again (maybe main thread is busy) batch_queue.put(item, timeout=0.1) # On success, stop trying @@ -58,10 +89,10 @@ def start_work_from_state( if stop_event.is_set(): # Signal the end of output, this ensures that even if the queue takes a while to # buffer, that the main thread receives everything (and tosses this fake batch) - logging.debug( + logging.info( "Worker thread: Stop event detected, outputting is_final=True batch" ) - logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) + logging.info("Worker thread: batch_queue full=%s", batch_queue.full()) batch_queue.put( Batch( x=np.zeros((1, 1)), @@ -72,23 +103,26 @@ def start_work_from_state( ngram_ids=None, ) ) - logging.debug( + logging.info( "Worker thread: is_final=True batch put in queue, breaking from loop." ) break try: - logging.debug("Worker thread: outputting state") + logging.info("Worker thread: outputting state") state_queue.put(stateful_iterator.get_state(), timeout=1) - logging.debug("Worker thread: state dump complete") + logging.info("Worker thread: state dump complete") state_dumped_event.set() - logging.debug("Worker thread: set state_dump_event") + logging.info("Worker thread: set state_dump_event") except Full: raise ValueError( "Attempted to dump state into the state queue, but it was full" ) +FETCH_STATE_TIMEOUT = 120 + + class MultiprocessIterator(StatefulIterator): """ Design sketch of the multiprocess iterator: @@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator): base_iterator: StatefulIterator, *, n_batches_to_prefetch: int, - prefetch_buffer: list | None = None + prefetch_buffer: list | None = None, + persist_type: PersistType = PersistType.EXACT, ): self.base_iterator = base_iterator self.n_batches_to_prefetch = n_batches_to_prefetch + self.persist_type = persist_type if prefetch_buffer is None: prefetch_buffer = [] self.prefetch_buffer = prefetch_buffer self.batch_queue = None self.state_queue = None + self.approximate_state_queue = None self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.trigger_approximate_send_state_event = None + self.sent_approximate_state_event = None + self.received_approximate_state_event = None self.force_shutdown = False def shutdown(self): @@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator): self.producer.kill() self.force_shutdown = True + def _get_state_exact(self): + logging.info("Main thread: Sending stop iteration event") + self.stop_iterating_event.set() + logging.info( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." 
+        )
+        self.prefetch_buffer = []
+        final_batch_received = False
+        while True:
+            try:
+                batch = self.batch_queue.get(timeout=1)
+                if batch.is_final:
+                    logging.info(
+                        "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                    )
+                    final_batch_received = True
+                    break
+                self.prefetch_buffer.append(batch)
+            except Empty:
+                logging.warning("Main thread: batch_queue is abnormally empty")
+        assert final_batch_received
+
+        logging.info("Main thread: Waiting for state_dumped event")
+        self.state_dumped_event.wait()
+
+        try:
+            logging.info(
+                "Main thread: state_dumped_event received, waiting for state from queue"
+            )
+            base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
+            logging.info("Main thread: received state from queue")
+            assert isinstance(base_iterator_state, IteratorState)
+        except Empty:
+            raise ValueError(
+                "Attempted to get the state, but it was unexpectedly missing"
+            )
+
+        self.base_iterator = base_iterator_state.build()
+        self.producer.close()
+        self.producer = None
+        self.batch_queue = None
+        self.state_queue = None
+        self.approximate_state_queue = None
+        self.stop_iterating_event = None
+        self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
+
+        return MultiprocessIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
+    def _get_state_approximate(self):
+        logging.info("MT: Sending approximate get_state request")
+        self.trigger_approximate_send_state_event.set()
+        logging.info("MT: Waiting for sent_approximate_state_event")
+        self.sent_approximate_state_event.wait()
+        logging.info("MT: sent_approximate_state_event ack")
+        try:
+            logging.info("MT: waiting for approximate state in queue")
+            base_iterator_state = self.approximate_state_queue.get(
+                timeout=FETCH_STATE_TIMEOUT
+            )
+            logging.info("MT: approximate state received")
+            assert isinstance(base_iterator_state, IteratorState)
+            assert self.approximate_state_queue.empty()
+        except Empty:
+            raise ValueError(
+                "Attempted to get approximate state, but queue was erroneously empty."
+            )
+        self.received_approximate_state_event.set()
+        return MultiprocessIteratorState(
+            base_iterator_state=base_iterator_state,
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
     def get_state(self) -> MultiprocessIteratorState:
         """
         This is slightly unusual in effectively destroying the current iterator, its necessary
@@ -162,55 +288,15 @@ class MultiprocessIterator(StatefulIterator):
                 base_iterator_state=self.base_iterator.get_state(),
                 n_batches_to_prefetch=self.n_batches_to_prefetch,
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
+                persist_type=self.persist_type,
             )
         else:
-            logging.debug("Main thread: Sending stop iteration event")
-            self.stop_iterating_event.set()
-            logging.debug(
-                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
-            )
-            self.prefetch_buffer = []
-            final_batch_received = False
-            while True:
-                try:
-                    batch = self.batch_queue.get(timeout=1)
-                    if batch.is_final:
-                        logging.debug(
-                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
-                        )
-                        final_batch_received = True
-                        break
-                    self.prefetch_buffer.append(batch)
-                except Empty:
-                    logging.warning("Main thread: batch_queue is abnormally empty")
-            assert final_batch_received
-
-            logging.debug("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
-
-            try:
-                base_iterator_state = self.state_queue.get(timeout=1)
-                assert isinstance(base_iterator_state, IteratorState)
-            except Empty:
-                raise ValueError(
-                    "Attempted to get the state, but it was unexpectantly missing"
-                )
-
-            self.base_iterator = base_iterator_state.build()
-            self.producer.close()
-            self.producer = None
-            self.batch_queue = None
-            self.state_queue = None
-            self.stop_iterating_event = None
-            self.state_dumped_event = None
-
-            return MultiprocessIteratorState(
-                base_iterator_state=self.base_iterator.get_state(),
-                n_batches_to_prefetch=self.n_batches_to_prefetch,
-                serialized_prefetch_buffer=json.dumps(
-                    [b.to_python_dict() for b in self.prefetch_buffer]
-                ),
-            )
+            if self.persist_type == PersistType.EXACT:
+                return self._get_state_exact()
+            elif self.persist_type == PersistType.APPROXIMATE:
+                return self._get_state_approximate()
+            else:
+                raise ValueError("invalid persist_type")
 
     def create_iter(self):
         if self.force_shutdown:
@@ -236,8 +322,14 @@ class MultiprocessIterator(StatefulIterator):
         # We should only ever one state, which is output at the detection of a stop event
         self.state_queue = ctx.Manager().Queue(maxsize=1)
+        # Similarly, there should only ever be one state in flight due to event signals
+        self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
+
         self.stop_iterating_event = ctx.Event()
         self.state_dumped_event = ctx.Event()
+        self.trigger_approximate_send_state_event = ctx.Event()
+        self.sent_approximate_state_event = ctx.Event()
+        self.received_approximate_state_event = ctx.Event()
@@ -245,8 +337,12 @@ class MultiprocessIterator(StatefulIterator):
         self.producer = mp.Process(
             name="blt_data_loader",
             target=start_work_from_state,
             args=(
                 self.batch_queue,
                 self.state_queue,
+                self.approximate_state_queue,
                 self.stop_iterating_event,
                 self.state_dumped_event,
+                self.trigger_approximate_send_state_event,
+                self.sent_approximate_state_event,
+                self.received_approximate_state_event,
                 self.base_iterator.get_state(),
             ),
         )
diff --git a/bytelatent/train.py b/bytelatent/train.py
index eb1c700..ad2fd9b 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
+    PersistType,
 )
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
@@ -712,9 +713,15 @@ def train(args: TrainArgs):
            if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
-                train_state.data_loader_state, data_loader, batch_iterator = (
-                    get_state_and_refresh(data_loader)
-                )
+                if (
+                    args.data.load_async
+                    and args.data.async_persist_type == PersistType.EXACT
+                ):
+                    train_state.data_loader_state, data_loader, batch_iterator = (
+                        get_state_and_refresh(data_loader)
+                    )
+                else:
+                    train_state.data_loader_state = data_loader.get_state()
                 saved = checkpoint.save(
                     model,
                    optimizer,
@@ -756,9 +763,16 @@ def train(args: TrainArgs):
 
         if preemption_flag["flag"]:
             if not saved:
-                train_state.data_loader_state, data_loader, batch_iterator = (
-                    get_state_and_refresh(data_loader)
-                )
+                if (
+                    args.data.load_async
+                    and args.data.async_persist_type == PersistType.EXACT
+                ):
+                    train_state.data_loader_state, data_loader, batch_iterator = (
+                        get_state_and_refresh(data_loader)
+                    )
+                else:
+                    train_state.data_loader_state = data_loader.get_state()
+
             checkpoint.save(
                 model,
                 optimizer,
@@ -769,21 +783,27 @@ def train(args: TrainArgs):
             requeue_slurm_job()
             sys.exit(0)
 
-    if not saved:
-        train_state.data_loader_state, data_loader, batch_iterator = (
-            get_state_and_refresh(data_loader)
-        )
-        checkpoint.save(
-            model,
-            optimizer,
-            train_state,
-            args,
-            device_mesh=world_mesh,
-        )
-        if isinstance(data_loader, MultiprocessIterator):
-            logger.info("Closing MP iterator before exiting")
-            data_loader.shutdown()
-        gc.collect()
+
+        if not saved:
+            if (
+                args.data.load_async
+                and args.data.async_persist_type == PersistType.EXACT
+            ):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
+            else:
+                train_state.data_loader_state = data_loader.get_state()
+            checkpoint.save(
+                model,
+                optimizer,
+                train_state,
+                args,
+                device_mesh=world_mesh,
+            )
+        if isinstance(data_loader, MultiprocessIterator):
+            logger.info("Closing MP iterator before exiting")
+            data_loader.shutdown()
+        gc.collect()
 
 def main():
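
The core of this patch is the three-event handshake that lets the main process snapshot the worker's iterator state without tearing the worker down. Below is a minimal, self-contained sketch of that handshake, using a toy counter in place of the real packing-iterator state; names such as `toy_worker` are illustrative and not part of the bytelatent API.

```python
import multiprocessing as mp
import time


def toy_worker(state_queue, trigger_event, sent_event, received_event, stop_event):
    step = 0
    while not stop_event.is_set():
        step += 1  # stand-in for producing one batch
        time.sleep(0.001)
        if trigger_event.is_set():
            # Clear the trigger first: a new request can only arrive after the
            # main process has acknowledged this one.
            trigger_event.clear()
            # maxsize=1 plus the ack protocol guarantees a free slot here.
            state_queue.put({"step": step})
            sent_event.set()
            # Block until the main process confirms receipt, so the snapshot
            # cannot race with further iteration.
            received_event.wait()
            received_event.clear()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    state_queue = ctx.Manager().Queue(maxsize=1)
    trigger_event = ctx.Event()
    sent_event = ctx.Event()
    received_event = ctx.Event()
    stop_event = ctx.Event()
    worker = ctx.Process(
        target=toy_worker,
        args=(state_queue, trigger_event, sent_event, received_event, stop_event),
    )
    worker.start()

    time.sleep(0.05)   # let the worker make some progress
    trigger_event.set()  # request an approximate snapshot
    sent_event.wait()
    snapshot = state_queue.get(timeout=120)
    received_event.set()  # let the worker resume iterating
    print("approximate state:", snapshot)

    stop_event.set()
    worker.join()
```

Note the asymmetry this models: the approximate path leaves the producer process alive, whereas the exact path stops it and rebuilds it from the dumped state.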
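By contrast, the exact path signals a stop, drains the prefetch queue until an is_final sentinel arrives, and only then reads the authoritative state. A toy version of that teardown follows, again with hypothetical names and a plain dict standing in for Batch:

```python
import multiprocessing as mp
from queue import Full


def toy_producer(batch_queue, state_queue, stop_event, state_dumped_event):
    step = 0
    while not stop_event.is_set():
        try:
            batch_queue.put({"step": step, "is_final": False}, timeout=0.1)
            step += 1
        except Full:
            pass  # main process is busy; retry until stopped
    # Sentinel tells the main process the queue is fully flushed.
    batch_queue.put({"is_final": True})
    state_queue.put({"step": step})
    state_dumped_event.set()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    batch_queue = ctx.Manager().Queue(maxsize=4)
    state_queue = ctx.Manager().Queue(maxsize=1)
    stop_event = ctx.Event()
    state_dumped_event = ctx.Event()
    producer = ctx.Process(
        target=toy_producer,
        args=(batch_queue, state_queue, stop_event, state_dumped_event),
    )
    producer.start()

    stop_event.set()  # begin the exact-persistence teardown
    prefetch_buffer = []
    while True:
        batch = batch_queue.get(timeout=10)
        if batch["is_final"]:
            break
        prefetch_buffer.append(batch)  # re-serialized alongside the state

    state_dumped_event.wait()
    state = state_queue.get(timeout=1)
    producer.join()
    print(len(prefetch_buffer), "buffered batches; resume from", state)
```

This cost asymmetry is why train.py only calls get_state_and_refresh (which performs the drain and then rebuilds the producer) when the exact mode is configured, and otherwise just calls data_loader.get_state().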