Let process start before yielding preloaded prefetch buffer, avoid needlessly losing buffer in edge cases

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-03-05 15:02:57 -08:00 committed by GitHub
parent ea1fc75862
commit 8f2cf8899d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -190,7 +190,11 @@ class MultiprocessIterator(StatefulIterator):
logging.info(
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
)
self.prefetch_buffer = []
if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
buffer = self.prefetch_buffer
else:
buffer = []
self.prefetch_buffer = buffer
final_batch_received = False
while True:
try:
@@ -261,12 +265,14 @@ class MultiprocessIterator(StatefulIterator):
"Attempted to get approximate state, but queue was erroniously empty."
)
self.received_approximate_state_event.set()
if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
buffer = [b.to_python_dict() for b in self.prefetch_buffer]
else:
buffer = []
return MultiprocessIteratorState(
base_iterator_state=base_iterator_state,
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
serialized_prefetch_buffer=json.dumps(buffer),
persist_type=self.persist_type,
)
@@ -281,9 +287,12 @@ class MultiprocessIterator(StatefulIterator):
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
)
if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
)
else:
serialized_prefetch_buffer = json.dumps([])
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
@@ -304,12 +313,6 @@ class MultiprocessIterator(StatefulIterator):
"Iterator may be invalid if shutdown was forced before state persisted."
)
logging.info("Main thread: Creating MP iterator")
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None:
while len(self.prefetch_buffer) > 0:
item = self.prefetch_buffer.pop(0)
yield item
self.prefetch_buffer = None
assert (
self.producer is None
@@ -349,6 +352,13 @@ class MultiprocessIterator(StatefulIterator):
logger.info("Async dataloader started")
self.producer.start()
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None:
while len(self.prefetch_buffer) > 0:
item = self.prefetch_buffer.pop(0)
yield item
self.prefetch_buffer = None
while True:
if self.producer.exitcode is not None:
raise RuntimeError(