diff --git a/bytelatent/args.py b/bytelatent/args.py index b9144c6..a332c89 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -150,11 +150,13 @@ class DataloaderArgs(BaseModel): enable_byte_ngrams=self.enable_byte_ngrams, ) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) - mp_iterator = MultiprocessIterator( - packing_iterator, n_batches_to_prefetch=self.prefetch_size - ) - - return mp_iterator + if self.load_async: + mp_iterator = MultiprocessIterator( + packing_iterator, n_batches_to_prefetch=self.prefetch_size + ) + return mp_iterator + else: + return packing_iterator class TrainArgs(BaseModel):