diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
index d0d1a84..b678dca 100644
--- a/bytelatent/data/iterators/multiprocess_iterator.py
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -190,7 +190,11 @@ class MultiprocessIterator(StatefulIterator):
         logging.info(
             "Main thread: Emptying the batch_queue until batch.is_final=True is found."
         )
-        self.prefetch_buffer = []
+        if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
+            buffer = self.prefetch_buffer
+        else:
+            buffer = []
+        self.prefetch_buffer = buffer
         final_batch_received = False
         while True:
             try:
@@ -261,12 +265,14 @@ class MultiprocessIterator(StatefulIterator):
                 "Attempted to get approximate state, but queue was erroniously empty."
             )
         self.received_approximate_state_event.set()
+        if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
+            buffer = [b.to_python_dict() for b in self.prefetch_buffer]
+        else:
+            buffer = []
         return MultiprocessIteratorState(
             base_iterator_state=base_iterator_state,
             n_batches_to_prefetch=self.n_batches_to_prefetch,
-            serialized_prefetch_buffer=json.dumps(
-                [b.to_python_dict() for b in self.prefetch_buffer]
-            ),
+            serialized_prefetch_buffer=json.dumps(buffer),
             persist_type=self.persist_type,
         )
 
@@ -281,9 +287,12 @@ class MultiprocessIterator(StatefulIterator):
                 "State will be invalid if shutdown was forced before state persisted."
             )
         if self.producer is None:
-            serialized_prefetch_buffer = json.dumps(
-                [b.to_python_dict() for b in self.prefetch_buffer]
-            )
+            if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
+                serialized_prefetch_buffer = json.dumps(
+                    [b.to_python_dict() for b in self.prefetch_buffer]
+                )
+            else:
+                serialized_prefetch_buffer = json.dumps([])
             return MultiprocessIteratorState(
                 base_iterator_state=self.base_iterator.get_state(),
                 n_batches_to_prefetch=self.n_batches_to_prefetch,
@@ -304,12 +313,6 @@ class MultiprocessIterator(StatefulIterator):
                 "Iterator may be invalid if shutdown was forced before state persisted."
             )
         logging.info("Main thread: Creating MP iterator")
-        # First yield from the stored prefetch buffer.
-        if self.prefetch_buffer is not None:
-            while len(self.prefetch_buffer) > 0:
-                item = self.prefetch_buffer.pop(0)
-                yield item
-            self.prefetch_buffer = None
 
         assert (
             self.producer is None
@@ -349,6 +352,13 @@ class MultiprocessIterator(StatefulIterator):
         logger.info("Async dataloader started")
         self.producer.start()
 
+        # First yield from the stored prefetch buffer.
+        if self.prefetch_buffer is not None:
+            while len(self.prefetch_buffer) > 0:
+                item = self.prefetch_buffer.pop(0)
+                yield item
+            self.prefetch_buffer = None
+
         while True:
             if self.producer.exitcode is not None:
                 raise RuntimeError(