diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py index 34f58d3..ee57472 100644 --- a/bytelatent/data/iterators/arrow_iterator.py +++ b/bytelatent/data/iterators/arrow_iterator.py @@ -197,9 +197,6 @@ class ArrowFileIterator(StatefulIterator): self.dataset = pa.dataset.dataset( self.dataset_files, format=self.file_format, filesystem=filesystem ) - self.batch_iterator = self.dataset.to_batches( - batch_size=self.arrow_batch_size - ) self.iter_id += 1 if self.batch_to_consume is not None: batch_columns: dict[str, list] = self.batch_to_consume @@ -229,6 +226,7 @@ class ArrowFileIterator(StatefulIterator): if (self.row_num - 1) % self.num_workers == self.worker_id: yield out + self.batch_iterator = self.dataset.to_batches(batch_size=self.arrow_batch_size) for batch in self.batch_iterator: batch_columns = batch.to_pydict() if self.file_format == "arrow":