From c727844e9d15700cf1b7c170a477d8a60e71c27c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez <par@meta.com> Date: Mon, 3 Mar 2025 16:59:02 -0800 Subject: [PATCH] Correctly reset batch iterator at each arrow create_iter call. (#74) Summary: Test Plan: --- bytelatent/data/iterators/arrow_iterator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 34f58d3..ee57472 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator): self.dataset = pa.dataset.dataset( self.dataset_files, format=self.file_format, filesystem=filesystem ) - self.batch_iterator = self.dataset.to_batches( - batch_size=self.arrow_batch_size - ) self.iter_id += 1 if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume @@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator): if (self.row_num - 1) % self.num_workers == self.worker_id: yield out + self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size) for batch in self.batch_iterator: batch_columns = batch.to_pydict() if self.file_format == "arrow":