From c727844e9d15700cf1b7c170a477d8a60e71c27c Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <par@meta.com>
Date: Mon, 3 Mar 2025 16:59:02 -0800
Subject: [PATCH] Correctly reset batch iterator at each arrow create_iter
 call. (#74)

Summary:

Test Plan:
---
 bytelatent/data/iterators/arrow_iterator.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 34f58d3..ee57472 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator):
             self.dataset = pa.dataset.dataset(
                 self.dataset_files, format=self.file_format, filesystem=filesystem
             )
-            self.batch_iterator = self.dataset.to_batches(
-                batch_size=self.arrow_batch_size
-            )
         self.iter_id += 1
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
@@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator):
                 if (self.row_num - 1) % self.num_workers == self.worker_id:
                     yield out
 
+        self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size)
         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
             if self.file_format == "arrow":