From f74aa7bd1a18d180cd06a42a63524a8ba868c592 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <me@pedro.ai>
Date: Mon, 3 Mar 2025 23:32:29 +0000
Subject: [PATCH] Correctly reset batch iterator at each arrow create_iter
 call.

Summary:

Test Plan:
---
 bytelatent/data/iterators/arrow_iterator.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 34f58d3..ee57472 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator):
             self.dataset = pa.dataset.dataset(
                 self.dataset_files, format=self.file_format, filesystem=filesystem
             )
-            self.batch_iterator = self.dataset.to_batches(
-                batch_size=self.arrow_batch_size
-            )
         self.iter_id += 1
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
@@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
+        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":