From f74aa7bd1a18d180cd06a42a63524a8ba868c592 Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez <me@pedro.ai> Date: Mon, 3 Mar 2025 23:32:29 +0000 Subject: [PATCH] Correctly reset batch iterator at each arrow create_iter call. Summary: Test Plan: --- bytelatent/data/iterators/arrow_iterator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 34f58d3..ee57472 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator): self.dataset = pa.dataset.dataset( self.dataset_files, format=self.file_format, filesystem=filesystem ) - self.batch_iterator = self.dataset.to_batches( - batch_size=self.arrow_batch_size - ) self.iter_id += 1 if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume @@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator): if (self.row_num - 1) % self.num_workers == self.worker_id: yield out + self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size) for batch in self.batch_iterator: batch_columns = batch.to_pydict() if self.file_format == "arrow":