From a828594625ed27e24b01634e55ddf9453e9072ca Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez <me@pedro.ai> Date: Wed, 5 Mar 2025 23:02:22 +0000 Subject: [PATCH] Let process start before yielding preloaded prefetch buffer, avoid needlessly losing buffer in edge cases Summary: Test Plan: --- .../data/iterators/multiprocess_iterator.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index d0d1a84..b678dca 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -190,7 +190,11 @@ class MultiprocessIterator(StatefulIterator): logging.info( "Main thread: Emptying the batch_queue until batch.is_final=True is found." ) - self.prefetch_buffer = [] + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + buffer = self.prefetch_buffer + else: + buffer = [] + self.prefetch_buffer = buffer final_batch_received = False while True: try: @@ -261,12 +265,14 @@ class MultiprocessIterator(StatefulIterator): "Attempted to get approximate state, but queue was erroniously empty." ) self.received_approximate_state_event.set() + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + buffer = [b.to_python_dict() for b in self.prefetch_buffer] + else: + buffer = [] return MultiprocessIteratorState( base_iterator_state=base_iterator_state, n_batches_to_prefetch=self.n_batches_to_prefetch, - serialized_prefetch_buffer=json.dumps( - [b.to_python_dict() for b in self.prefetch_buffer] - ), + serialized_prefetch_buffer=json.dumps(buffer), persist_type=self.persist_type, ) @@ -281,9 +287,12 @@ class MultiprocessIterator(StatefulIterator): "State will be invalid if shutdown was forced before state persisted." ) if self.producer is None: - serialized_prefetch_buffer = json.dumps( - [b.to_python_dict() for b in self.prefetch_buffer] - ) + if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0: + serialized_prefetch_buffer = json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ) + else: + serialized_prefetch_buffer = json.dumps([]) return MultiprocessIteratorState( base_iterator_state=self.base_iterator.get_state(), n_batches_to_prefetch=self.n_batches_to_prefetch, @@ -304,12 +313,6 @@ class MultiprocessIterator(StatefulIterator): "Iterator may be invalid if shutdown was forced before state persisted." ) logging.info("Main thread: Creating MP iterator") - # First yield from the stored prefetch buffer. - if self.prefetch_buffer is not None: - while len(self.prefetch_buffer) > 0: - item = self.prefetch_buffer.pop(0) - yield item - self.prefetch_buffer = None assert ( self.producer is None @@ -349,6 +352,13 @@ class MultiprocessIterator(StatefulIterator): logger.info("Async dataloader started") self.producer.start() + # First yield from the stored prefetch buffer. + if self.prefetch_buffer is not None: + while len(self.prefetch_buffer) > 0: + item = self.prefetch_buffer.pop(0) + yield item + self.prefetch_buffer = None + while True: if self.producer.exitcode is not None: raise RuntimeError(