Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

This commit is contained in:
Pedro Rodriguez 2025-02-13 11:58:23 -08:00 committed by GitHub
parent 85c2f28f26
commit 8c61ab5e67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 77 additions and 33 deletions

View file

@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
return shard_number
def maybe_truncate_string(text: str, max_length: int) -> str:
    """Return *text* shortened for display, marking any cut with an ellipsis.

    Strings of at most ``max_length`` characters are returned unchanged.
    Longer strings are cut to their first ``max_length`` characters and
    suffixed with ``"..."``, so a truncated result can be up to
    ``max_length + 3`` characters long.

    Args:
        text: The string to shorten.
        max_length: Maximum number of characters of ``text`` to keep.
            Negative values are treated as 0.

    Returns:
        ``text`` itself, or its truncated prefix followed by ``"..."``.
    """
    if len(text) <= max_length:
        return text
    # Clamp the limit: a negative slice bound would count from the end of the
    # string and keep an arbitrary prefix instead of truncating everything.
    keep = max(max_length, 0)
    return text[:keep] + "..."
class ArrowFileIterator(StatefulIterator):
def __init__(
self,
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
yield out
def _set_row_num(self, target_row_num: int):
logger.info(
f"Setting arrow position to {target_row_num} for {self.dataset_files}"
)
data_str = maybe_truncate_string(str(self.dataset_files), 200)
logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
if target_row_num is None or target_row_num == 0:
self.row_num = 0
self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
else:
curr_remaining -= len(batch)
self.row_num = target_row_num
data_str = maybe_truncate_string(str(self.dataset_files), 200)
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
f"Finished setting arrow position to {target_row_num} for {data_str}"
)