diff --git a/bytelatent/args.py b/bytelatent/args.py
index 11c6548..4a67231 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
-from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    PersistType,
+)
 from bytelatent.data.iterators.packing_iterator import (
     PackingArgs,
     PackingIterator,
@@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
     add_bos: bool = True
     add_eos: bool = True
     load_async: bool = True
+    async_persist_type: PersistType = PersistType.EXACT
     prefetch_size: int = 64
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
@@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
             mp_iterator = MultiprocessIterator(
-                packing_iterator, n_batches_to_prefetch=self.prefetch_size
+                packing_iterator,
+                n_batches_to_prefetch=self.prefetch_size,
+                persist_type=self.async_persist_type,
             )
             return mp_iterator
         else:
diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py
index b4df945..d0d1a84 100644
--- a/bytelatent/data/iterators/multiprocess_iterator.py
+++ b/bytelatent/data/iterators/multiprocess_iterator.py
@@ -2,6 +2,7 @@
 import json
 import logging
 import multiprocessing as mp
+from enum import Enum
 from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 
@@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 logger = logging.getLogger()
 
 
+class PersistType(str, Enum):
+    EXACT = "exact"
+    APPROXIMATE = "approximate"
+
+
 class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
     serialized_prefetch_buffer: str
+    persist_type: PersistType
 
     def build(self):
         base_iterator = self.base_iterator_state.build()
@@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
             base_iterator,
             n_batches_to_prefetch=self.n_batches_to_prefetch,
             prefetch_buffer=prefetch_buffer,
+            persist_type=self.persist_type,
         )
 
 
 def start_work_from_state(
     batch_queue: mp.Queue,
     state_queue: mp.Queue,
+    approximate_state_queue: mp.Queue,
     stop_event: EventClass,
     state_dumped_event: EventClass,
+    trigger_approximate_send_state_event: EventClass,
+    sent_approximate_state_event: EventClass,
+    received_approximate_state_event: EventClass,
     state: IteratorState,
 ):
     logging.info("Worker thread: Starting base_iterator work")
@@ -49,6 +61,25 @@ def start_work_from_state(
     for item in iterator:
         while not stop_event.is_set():
             try:
+                if trigger_approximate_send_state_event.is_set():
+                    logger.info("WT: trigger_approximate_send ack")
+                    # Since this can be triggered again (but only after the state is received on mp),
+                    # we should cleanup as soon as possible.
+                    trigger_approximate_send_state_event.clear()
+                    logging.info("WT: Computing approximate state")
+                    approximate_state = stateful_iterator.get_state()
+                    # At this state, there should always be exactly 1 slot.
+                    # Blocking here would be a bug.
+                    logger.info("WT: Attempting to send approximate state")
+                    approximate_state_queue.put(
+                        approximate_state, block=True, timeout=None
+                    )
+                    sent_approximate_state_event.set()
+                    logger.info("WT: Approximate state sent")
+                    # Same here, clear events as we no longer need them.
+                    received_approximate_state_event.wait()
+                    received_approximate_state_event.clear()
+                    logger.info("WT: State received by MT, resuming batch iteration")
                 # Attempt to put on queue or timeout to try again (maybe main thread is busy)
                 batch_queue.put(item, timeout=0.1)
                 # On success, stop trying
@@ -58,10 +89,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.debug(
+            logging.info(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
-            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
+            logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -72,23 +103,26 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
-            logging.debug(
+            logging.info(
                 "Worker thread: is_final=True batch put in queue, breaking from loop."
             )
             break
 
     try:
-        logging.debug("Worker thread: outputting state")
+        logging.info("Worker thread: outputting state")
         state_queue.put(stateful_iterator.get_state(), timeout=1)
-        logging.debug("Worker thread: state dump complete")
+        logging.info("Worker thread: state dump complete")
         state_dumped_event.set()
-        logging.debug("Worker thread: set state_dump_event")
+        logging.info("Worker thread: set state_dump_event")
     except Full:
         raise ValueError(
             "Attempted to dump state into the state queue, but it was full"
         )
 
 
+FETCH_STATE_TIMEOUT = 120
+
+
 class MultiprocessIterator(StatefulIterator):
     """
     Design sketch of the multiprocess iterator:
@@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator):
         base_iterator: StatefulIterator,
         *,
         n_batches_to_prefetch: int,
-        prefetch_buffer: list | None = None
+        prefetch_buffer: list | None = None,
+        persist_type: PersistType = PersistType.EXACT,
     ):
         self.base_iterator = base_iterator
         self.n_batches_to_prefetch = n_batches_to_prefetch
+        self.persist_type = persist_type
         if prefetch_buffer is None:
             prefetch_buffer = []
         self.prefetch_buffer = prefetch_buffer
         self.batch_queue = None
         self.state_queue = None
+        self.approximate_state_queue = None
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
         self.force_shutdown = False
 
     def shutdown(self):
@@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator):
             self.producer.kill()
             self.force_shutdown = True
 
+    def _get_state_exact(self):
+        logging.info("Main thread: Sending stop iteration event")
+        self.stop_iterating_event.set()
+        logging.info(
+            "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+        )
+        self.prefetch_buffer = []
+        final_batch_received = False
+        while True:
+            try:
+                batch = self.batch_queue.get(timeout=1)
+                if batch.is_final:
+                    logging.info(
+                        "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                    )
+                    final_batch_received = True
+                    break
+                self.prefetch_buffer.append(batch)
+            except Empty:
+                logging.warning("Main thread: batch_queue is abnormally empty")
+        assert final_batch_received
+
+        logging.info("Main thread: Waiting for state_dumped event")
+        self.state_dumped_event.wait()
+
+        try:
+            logging.info(
+                "Main thread: state_dumped_event received, waiting for state from queue"
+            )
+            base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
+            logging.info("Main thread: received state from queue")
+            assert isinstance(base_iterator_state, IteratorState)
+        except Empty:
+            raise ValueError(
+                "Attempted to get the state, but it was unexpectantly missing"
+            )
+
+        self.base_iterator = base_iterator_state.build()
+        self.producer.close()
+        self.producer = None
+        self.batch_queue = None
+        self.state_queue = None
+        self.approximate_state_queue = None
+        self.stop_iterating_event = None
+        self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
+
+        return MultiprocessIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
+    def _get_state_approximate(self):
+        logging.info("MT: Sending approximate get_state request")
+        self.trigger_approximate_send_state_event.set()
+        logging.info("MT: Waiting for sent_approximate_state_event")
+        self.sent_approximate_state_event.wait()
+        logging.info("MT: sent_approximate_state_event ack")
+        try:
+            logging.info("MT: waiting for approximate state in queue")
+            base_iterator_state = self.approximate_state_queue.get(
+                timeout=FETCH_STATE_TIMEOUT
+            )
+            logging.info("MT: approximate state received")
+            assert isinstance(base_iterator_state, IteratorState)
+            assert self.approximate_state_queue.empty()
+        except Empty:
+            raise ValueError(
+                "Attempted to get approximate state, but queue was erroniously empty."
+            )
+        self.received_approximate_state_event.set()
+        return MultiprocessIteratorState(
+            base_iterator_state=base_iterator_state,
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
     def get_state(self) -> MultiprocessIteratorState:
         """
         This is slightly unusual in effectively destroying the current iterator, its necessary
@@ -162,55 +288,15 @@ class MultiprocessIterator(StatefulIterator):
                 base_iterator_state=self.base_iterator.get_state(),
                 n_batches_to_prefetch=self.n_batches_to_prefetch,
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
+                persist_type=self.persist_type,
             )
         else:
-            logging.debug("Main thread: Sending stop iteration event")
-            self.stop_iterating_event.set()
-            logging.debug(
-                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
-            )
-            self.prefetch_buffer = []
-            final_batch_received = False
-            while True:
-                try:
-                    batch = self.batch_queue.get(timeout=1)
-                    if batch.is_final:
-                        logging.debug(
-                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
-                        )
-                        final_batch_received = True
-                        break
-                    self.prefetch_buffer.append(batch)
-                except Empty:
-                    logging.warning("Main thread: batch_queue is abnormally empty")
-            assert final_batch_received
-
-            logging.debug("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
-
-            try:
-                base_iterator_state = self.state_queue.get(timeout=1)
-                assert isinstance(base_iterator_state, IteratorState)
-            except Empty:
-                raise ValueError(
-                    "Attempted to get the state, but it was unexpectantly missing"
-                )
-
-            self.base_iterator = base_iterator_state.build()
-            self.producer.close()
-            self.producer = None
-            self.batch_queue = None
-            self.state_queue = None
-            self.stop_iterating_event = None
-            self.state_dumped_event = None
-
-            return MultiprocessIteratorState(
-                base_iterator_state=self.base_iterator.get_state(),
-                n_batches_to_prefetch=self.n_batches_to_prefetch,
-                serialized_prefetch_buffer=json.dumps(
-                    [b.to_python_dict() for b in self.prefetch_buffer]
-                ),
-            )
+            if self.persist_type == PersistType.EXACT:
+                return self._get_state_exact()
+            elif self.persist_type == PersistType.APPROXIMATE:
+                return self._get_state_approximate()
+            else:
+                raise ValueError("invalid persist_type")
 
     def create_iter(self):
         if self.force_shutdown:
@@ -236,8 +322,14 @@ class MultiprocessIterator(StatefulIterator):
         # We should only ever one state, which is output at the detection of a stop event
         self.state_queue = ctx.Manager().Queue(maxsize=1)
 
+        # Similarly, there should only ever be one state in flight due to event signals
+        self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
+
         self.stop_iterating_event = ctx.Event()
         self.state_dumped_event = ctx.Event()
+        self.trigger_approximate_send_state_event = ctx.Event()
+        self.sent_approximate_state_event = ctx.Event()
+        self.received_approximate_state_event = ctx.Event()
 
         self.producer = mp.Process(
             name="blt_data_loader",
@@ -245,8 +337,12 @@ class MultiprocessIterator(StatefulIterator):
             args=(
                 self.batch_queue,
                 self.state_queue,
+                self.approximate_state_queue,
                 self.stop_iterating_event,
                 self.state_dumped_event,
+                self.trigger_approximate_send_state_event,
+                self.sent_approximate_state_event,
+                self.received_approximate_state_event,
                 self.base_iterator.get_state(),
             ),
         )
diff --git a/bytelatent/train.py b/bytelatent/train.py
index eb1c700..ad2fd9b 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
+    PersistType,
 )
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
@@ -712,9 +713,15 @@ def train(args: TrainArgs):
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
-                train_state.data_loader_state, data_loader, batch_iterator = (
-                    get_state_and_refresh(data_loader)
-                )
+                if (
+                    args.data.load_async
+                    and args.data.async_persist_type == PersistType.EXACT
+                ):
+                    train_state.data_loader_state, data_loader, batch_iterator = (
+                        get_state_and_refresh(data_loader)
+                    )
+                else:
+                    train_state.data_loader_state = data_loader.get_state()
                 saved = checkpoint.save(
                     model,
                     optimizer,
@@ -756,9 +763,16 @@ def train(args: TrainArgs):
 
             if preemption_flag["flag"]:
                 if not saved:
-                    train_state.data_loader_state, data_loader, batch_iterator = (
-                        get_state_and_refresh(data_loader)
-                    )
+                    if (
+                        args.data.load_async
+                        and args.data.async_persist_type == PersistType.EXACT
+                    ):
+                        train_state.data_loader_state, data_loader, batch_iterator = (
+                            get_state_and_refresh(data_loader)
+                        )
+                    else:
+                        train_state.data_loader_state = data_loader.get_state()
+
                     checkpoint.save(
                         model,
                         optimizer,
@@ -769,21 +783,27 @@ def train(args: TrainArgs):
                 requeue_slurm_job()
                 sys.exit(0)
 
-    if not saved:
-        train_state.data_loader_state, data_loader, batch_iterator = (
-            get_state_and_refresh(data_loader)
-        )
-        checkpoint.save(
-            model,
-            optimizer,
-            train_state,
-            args,
-            device_mesh=world_mesh,
-        )
-    if isinstance(data_loader, MultiprocessIterator):
-        logger.info("Closing MP iterator before exiting")
-        data_loader.shutdown()
-    gc.collect()
+        if not saved:
+            if (
+                args.data.load_async
+                and args.data.async_persist_type == PersistType.EXACT
+            ):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
+            else:
+                train_state.data_loader_state = data_loader.get_state()
+            checkpoint.save(
+                model,
+                optimizer,
+                train_state,
+                args,
+                device_mesh=world_mesh,
+            )
+        if isinstance(data_loader, MultiprocessIterator):
+            logger.info("Closing MP iterator before exiting")
+            data_loader.shutdown()
+        gc.collect()
 
 
 def main():