diff --git a/bytelatent/args.py b/bytelatent/args.py
index 4a67231..bad4d17 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -138,7 +138,8 @@ class DataloaderArgs(BaseModel):
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
     entropy_model_name: str | None = "transformer_100m"
-    arrow_batch_size: int = 100
+    # Be very careful with increasing, increases memory usage by that factor per rank, per data source
+    arrow_batch_size: int = 20
     buffer_size: int = 64
     file_format: str = "arrow"
 
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index ee57472..bfb1c17 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -226,7 +226,13 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
-        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
+        self.batch_iterator = self.dataset.to_batches(
+            batch_size=self.arrow_batch_size,
+            # We have large files in GBs, no need to readahead
+            fragment_readahead=1,
+            # Don't readahead in case batches are huge (e.g., books)
+            batch_readahead=1,
+        )
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":
diff --git a/bytelatent/data/iterators/sequence_iterator.py b/bytelatent/data/iterators/sequence_iterator.py
index 0a492be..3b5214d 100644
--- a/bytelatent/data/iterators/sequence_iterator.py
+++ b/bytelatent/data/iterators/sequence_iterator.py
@@ -10,6 +10,9 @@ from bytelatent.data.iterators.abstract_iterator import (
     PydanticIteratorState,
     StatefulIterator,
 )
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.preprocess_iterator import (
     PreprocessIterator,
     PreprocessIteratorState,
@@ -40,6 +43,21 @@ class SequenceIteratorState(PydanticIteratorState):
         )
 
 
+def get_datafile(
+    iterator: PreprocessIterator | ArrowFileIterator | LoopingIterator | LimitIterator,
+):
+    if isinstance(iterator, ArrowFileIterator):
+        return f"file={iterator.file_path} n_shards={len(iterator.dataset_files) if iterator.dataset_files is not None else None}"
+    elif isinstance(iterator, PreprocessIterator):
+        return get_datafile(iterator.arrow_iterator)
+    elif isinstance(iterator, LoopingIterator):
+        return get_datafile(iterator.file_iterator)
+    elif isinstance(iterator, LimitIterator):
+        return get_datafile(iterator.base_iterator)
+    else:
+        raise NotImplementedError()
+
+
 class SequenceIterator(StatefulIterator):
     def __init__(
         self,
@@ -74,6 +92,10 @@ class SequenceIterator(StatefulIterator):
         tokens: list[int] = []
         mask: list[bool] = []
         first = True
+        logger.info(
+            "Starting first buffer for: %s",
+            get_datafile(self.preprocess_iterator),
+        )
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
@@ -97,7 +119,10 @@ class SequenceIterator(StatefulIterator):
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
                     first = False
-                    logger.info("First buffer complete")
+                    logger.info(
+                        "First buffer complete for: %s",
+                        get_datafile(self.preprocess_iterator),
+                    )
 
                 x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                     self.buffer_size, self.output_seq_len
diff --git a/bytelatent/iterate_data.py b/bytelatent/iterate_data.py
new file mode 100644
index 0000000..bdb8f22
--- /dev/null
+++ b/bytelatent/iterate_data.py
@@ -0,0 +1,30 @@
+import json
+
+import pyarrow
+import typer
+from rich.progress import track
+
+from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIteratorState
+from bytelatent.logger import init_logger
+
+
+def main(state_file: str):
+    init_logger()
+    pyarrow.set_io_thread_count(4)
+    pyarrow.set_cpu_count(4)
+    with open(state_file) as f:
+        train_state = json.load(f)
+        dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
+        packing_iterator_state = dl_state.base_iterator_state
+        print("building")
+        packing_iterator = packing_iterator_state.build()
+        print("iter")
+        batch_iter = packing_iterator.create_iter()
+        batch = None
+        print("looping")
+        for i in track(range(1_000)):
+            batch = next(batch_iter)
+
+
+if __name__ == "__main__":
+    typer.run(main)
diff --git a/bytelatent/train.py b/bytelatent/train.py
index ad2fd9b..5a8f937 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -13,6 +13,7 @@ from timeit import default_timer as timer
 from typing import Any, TypeVar
 
 import numpy as np
+import pyarrow
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -266,6 +267,8 @@ def compute_loss(p, y, mask, scale):
 
 def train(args: TrainArgs):
     with ExitStack() as context_stack:
+        pyarrow.set_io_thread_count(4)
+        pyarrow.set_cpu_count(4)
         tokenizer = args.data.tokenizer_args.build()
         validate_train_args(
             args,