Mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-13 15:49:42 +00:00)
Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
parent 85c2f28f26
commit 8c61ab5e67
5 changed files with 77 additions and 33 deletions
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number
 
 
+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
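The helper added above only bounds the length of logged strings, which matters when self.dataset_files holds hundreds of shard paths. A minimal, self-contained sketch of how it behaves; the shard paths below are made up for illustration:

def maybe_truncate_string(text: str, max_length: int):
    if len(text) <= max_length:
        return text
    else:
        return text[:max_length] + "..."

# A long repr of many shard paths is cut to max_length characters plus "...".
files = [f"/data/shard_{i:05d}.arrow" for i in range(1000)]
print(maybe_truncate_string(str(files), 200))   # prints 203 characters
print(maybe_truncate_string("short", 200))      # unchanged: "short"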
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
             yield out
 
     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
                 else:
                     curr_remaining -= len(batch)
         self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
         )
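The method these hunks touch, _set_row_num, is what makes the checkpointing in the commit title work: on resume it fast-forwards the Arrow reader by discarding whole batches until it reaches the saved row; only its logging changes above. The skipping itself follows the count-down pattern visible in the context lines (curr_remaining -= len(batch)). A rough, self-contained sketch of that pattern, not the repository's actual code; skip_rows and its batches argument are invented names, and any iterable of sliceable batches works:

def skip_rows(batches, target_row_num: int):
    # Drop whole batches that fall entirely before the checkpointed row,
    # then trim the batch that contains it and yield everything after.
    curr_remaining = target_row_num
    for batch in batches:
        if len(batch) > curr_remaining:
            yield batch[curr_remaining:]
            curr_remaining = 0
        else:
            curr_remaining -= len(batch)


# Resuming at row 4 skips the first batch and trims the second.
print(list(skip_rows([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 4)))
# [[5, 6], [7, 8, 9]]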