Mirror of https://github.com/facebookresearch/blt.git, synced 2025-02-22 21:12:15 +00:00
Fix multiprocessing dataloader checkpointing and use it in the train script
Summary:

Test Plan:

parent 4cee32ea8c
commit bd3cf61bb9
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import json
 import logging
 import os
 from typing import Any

 import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the mp iterator shuts down MP to correctly persist state, so it needs
+    # to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
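The helper above is the heart of the fix: get_state() on the multiprocessing iterator tears down the worker, so the caller must rebuild both the dataloader and its Python iterator from the returned state. For orientation, here is a minimal toy pair satisfying the same contract; RangeState and RangeIterator are invented for illustration and are not part of the codebase:

from bytelatent.data.iterators.abstract_iterator import (
    IteratorState,
    StatefulIterator,
    get_state_and_refresh,
)


class RangeState(IteratorState):
    # Hypothetical toy state: records a position and knows how to rebuild
    # an iterator that resumes from it.
    def __init__(self, position: int):
        self.position = position

    def build(self) -> "RangeIterator":
        return RangeIterator(start=self.position)


class RangeIterator(StatefulIterator):
    # Hypothetical toy iterator standing in for MultiprocessIterator.
    def __init__(self, start: int = 0):
        self.position = start

    def get_state(self) -> RangeState:
        # The real mp iterator shuts down its worker process here, which is
        # exactly why get_state_and_refresh() must rebuild everything after.
        return RangeState(self.position)

    def create_iter(self):
        while True:
            yield self.position
            self.position += 1


state, data_loader, py_iterator = get_state_and_refresh(RangeIterator())
assert next(py_iterator) == 0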
@@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator):
 
     def _set_row_num(self, target_row_num: int):
         logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
@@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator):
                 curr_remaining -= len(batch)
         self.row_num = target_row_num
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
         )
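Both call sites now clamp the logged repr at 200 characters: self.dataset_files can hold thousands of shard paths, and logging it whole made every position reset dump the full list. A quick sketch of the effect, with invented file names:

# Hypothetical shard list; real runs can reference thousands of arrow files.
dataset_files = [f"shard_{i:05d}.arrow" for i in range(10_000)]
message = f"Setting arrow position to 7 for {str(dataset_files)[:200]}"
print(len(str(dataset_files)))  # hundreds of kilobytes of repr
print(len(message))             # bounded: at most ~200 chars plus the prefix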
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output; this ensures that even if the queue takes a while to
             # buffer, the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break
 
         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
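The zero-filled Batch with is_final=True acts as a sentinel. The worker cannot simply stop: batches already sitting in the bounded queue must still reach the main thread, so the worker appends one recognizable fake batch and the consumer drains until it sees it. A stripped-down sketch of the same protocol, with all names invented rather than taken from the module:

import multiprocessing as mp


def producer(batch_queue: mp.Queue, stop_event) -> None:
    n = 0
    while not stop_event.is_set():
        batch_queue.put(("batch", n))  # real payload
        n += 1
    # Sentinel: everything queued before this is valid and must be kept.
    batch_queue.put(("final", None))


def drain(batch_queue: mp.Queue) -> list:
    buffered = []
    while True:
        kind, payload = batch_queue.get(timeout=1)
        if kind == "final":
            break  # toss the fake batch, keep everything before it
        buffered.append(payload)
    return buffered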
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
+            final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
+                    if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
+                        final_batch_received = True
+                        break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
+            assert final_batch_received
 
+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
 
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
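The reordering across these two hunks is the actual bug fix: the old code waited on state_dumped_event before draining batch_queue, but a worker blocked on a full batch_queue can never reach its state dump, so shutdown could hang. Draining to the sentinel first guarantees the worker makes progress. A condensed sketch of the fixed sequence, with attribute names taken from this diff and the surrounding class scaffolding assumed:

def shut_down_and_get_state(self):
    # 1. Ask the worker to stop producing.
    self.stop_iterating_event.set()

    # 2. Drain batches first so the worker is never stuck on a full queue;
    #    everything before the sentinel becomes the new prefetch buffer.
    self.prefetch_buffer = []
    final_batch_received = False
    while True:
        batch = self.batch_queue.get(timeout=1)
        if batch.is_final:
            final_batch_received = True
            break
        self.prefetch_buffer.append(batch)
    assert final_batch_received

    # 3. Only now wait for the worker to persist its iterator state.
    self.state_dumped_event.wait()
    return self.state_queue.get(timeout=1)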
@@ -25,6 +25,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
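With get_state_and_refresh imported, the state stored in train_state.data_loader_state is an IteratorState, so a resumed run can rebuild its dataloader the same way the helper does internally. A hedged sketch of that resume path; the surrounding checkpoint-loading code is assumed, not shown in this diff:

# After loading a checkpoint, train_state.data_loader_state holds the
# IteratorState captured by get_state_and_refresh() at save time.
data_loader = train_state.data_loader_state.build()
batch_iterator = data_loader.create_iter()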
@@ -34,7 +35,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -700,6 +700,9 @@ def train(args: TrainArgs):
                 if every_n_steps(
                     train_state, args.checkpoint.dump.every, acc_step=0
                 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                    train_state.data_loader_state, data_loader, batch_iterator = (
+                        get_state_and_refresh(data_loader)
+                    )
                     saved = checkpoint.save(
                         model,
                         optimizer,
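The dataloader refresh is gated on the same cadence as checkpoint dumps and evals. every_n_steps is not shown in this diff; a plausible reading, offered only as an assumption, is a modulus check on the optimizer step with acc_step=0 restricting the trigger to the first gradient-accumulation micro-step:

def every_n_steps(train_state, freq, acc_step: int = 0) -> bool:
    # Assumed sketch, not the repository's definition: fire once every
    # `freq` optimizer steps, only on the requested accumulation micro-step.
    return (
        freq is not None
        and train_state.step % freq == 0
        and train_state.acc_step == acc_step
    )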
@@ -741,6 +744,9 @@ def train(args: TrainArgs):
 
         if preemption_flag["flag"]:
             if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 checkpoint.save(
                     model,
                     optimizer,
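preemption_flag lets the loop save exactly once, with fresh dataloader state, before exiting when the job is being preempted. The flag's wiring is outside this diff; a hedged sketch of how such a flag is typically set from a signal handler (handler name and signal choice are assumptions):

import signal

preemption_flag = {"flag": False}


def _handle_preemption(signum, frame):
    # Only record the signal here; the training loop checks the flag at a
    # safe point, refreshes the dataloader state, saves, then exits.
    preemption_flag["flag"] = True


signal.signal(signal.SIGTERM, _handle_preemption)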
@@ -752,6 +758,9 @@ def train(args: TrainArgs):
         sys.exit(0)
 
     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,