From abb4f7e6a4f6e675d12f260aa54ad9da9c238dcb Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 4 Mar 2025 01:42:53 +0000 Subject: [PATCH] Let process start before yielding preloaded prefetch buffer, avoid needlessly losing buffer in edge cases Summary: Test Plan: --- .../data/iterators/multiprocess_iterator.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index 8ff9c51..4306eea 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -186,7 +186,11 @@ class MultiprocessIterator(StatefulIterator): logging.debug( "Main thread: Emptying the batch_queue until batch.is_final=True is found." ) - self.prefetch_buffer = [] + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + buffer = self.prefetch_buffer + else: + buffer = [] + self.prefetch_buffer = buffer final_batch_received = False while True: try: @@ -250,12 +254,14 @@ class MultiprocessIterator(StatefulIterator): "Attempted to get approximate state, but queue was erroniously empty." ) self.received_approximate_state_event.set() + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + buffer = [b.to_python_dict() for b in self.prefetch_buffer] + else: + buffer = [] return MultiprocessIteratorState( base_iterator_state=base_iterator_state, n_batches_to_prefetch=self.n_batches_to_prefetch, - serialized_prefetch_buffer=json.dumps( - [b.to_python_dict() for b in self.prefetch_buffer] - ), + serialized_prefetch_buffer=json.dumps(buffer), persist_type=self.persist_type, ) @@ -270,9 +276,12 @@ class MultiprocessIterator(StatefulIterator): "State will be invalid if shutdown was forced before state persisted." ) if self.producer is None: - serialized_prefetch_buffer = json.dumps( - [b.to_python_dict() for b in self.prefetch_buffer] - ) + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + serialized_prefetch_buffer = json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ) + else: + serialized_prefetch_buffer = json.dumps([]) return MultiprocessIteratorState( base_iterator_state=self.base_iterator.get_state(), n_batches_to_prefetch=self.n_batches_to_prefetch, @@ -293,12 +302,6 @@ class MultiprocessIterator(StatefulIterator): "Iterator may be invalid if shutdown was forced before state persisted." ) logging.info("Main thread: Creating MP iterator") - # First yield from the stored prefetch buffer. - if self.prefetch_buffer is not None: - while len(self.prefetch_buffer) > 0: - item = self.prefetch_buffer.pop(0) - yield item - self.prefetch_buffer = None assert ( self.producer is None @@ -338,6 +341,13 @@ class MultiprocessIterator(StatefulIterator): logger.info("Async dataloader started") self.producer.start() + # First yield from the stored prefetch buffer. + if self.prefetch_buffer is not None: + while len(self.prefetch_buffer) > 0: + item = self.prefetch_buffer.pop(0) + yield item + self.prefetch_buffer = None + while True: if self.producer.exitcode is not None: raise RuntimeError(