Fix multiprocessing dataloader checkpointing and use it in the train script

Summary:

Test Plan:
Pedro Rodriguez 2025-02-12 18:09:26 +00:00
parent 4cee32ea8c
commit bd3cf61bb9
5 changed files with 41 additions and 13 deletions

View file

@@ -1,10 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import logging
import os
from typing import Any
import fsspec
import numpy as np
import yaml
from omegaconf import OmegaConf

View file

@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
@abc.abstractmethod
def build(self) -> StatefulIterator[T, C]:
pass


def get_state_and_refresh(iterator: StatefulIterator):
    # Re-initializing the dataloader and iterator is necessary since get_state()
    # on the MP iterator shuts down multiprocessing in order to persist its state
    # correctly, so the iterator needs to be restarted afterwards.
    state = iterator.get_state()
    data_loader = state.build()
    py_iterator = data_loader.create_iter()
    return state, data_loader, py_iterator
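To make the new helper concrete, here is a minimal, self-contained sketch of the contract it relies on: get_state() returns a state object whose build() reconstructs the loader, and create_iter() yields a fresh Python iterator. The RangeIterator classes below are illustrative toys, not the real bytelatent iterators; only the import of get_state_and_refresh matches the actual module.

```python
# Toy stand-ins for a StatefulIterator and its IteratorState; illustrative only.
from dataclasses import dataclass

from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh


@dataclass
class RangeIteratorState:
    current: int
    end: int

    def build(self) -> "RangeIterator":
        # Rebuild the data loader from the persisted position.
        return RangeIterator(self.current, self.end)


class RangeIterator:
    def __init__(self, start: int, end: int):
        self.position = start
        self.end = end

    def get_state(self) -> RangeIteratorState:
        # On the real multiprocess iterator this also shuts down the worker,
        # which is why callers must rebuild the loader afterwards.
        return RangeIteratorState(self.position, self.end)

    def create_iter(self):
        while self.position < self.end:
            self.position += 1
            yield self.position - 1


data_loader = RangeIterator(0, 10)
batch_iterator = data_loader.create_iter()
next(batch_iterator), next(batch_iterator)  # consume 0 and 1

# The returned state can now be checkpointed, and iteration continues from
# the rebuilt loader/iterator exactly where it left off.
state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
assert next(batch_iterator) == 2
```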

View file

@@ -236,7 +236,7 @@ class ArrowFileIterator(StatefulIterator):
def _set_row_num(self, target_row_num: int):
logger.info(
f"Setting arrow position to {target_row_num} for {self.dataset_files}"
f"Setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
)
if target_row_num is None or target_row_num == 0:
self.row_num = 0
@@ -286,5 +286,5 @@ class ArrowFileIterator(StatefulIterator):
curr_remaining -= len(batch)
self.row_num = target_row_num
logger.info(
f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
f"Finished setting arrow position to {target_row_num} for {str(self.dataset_files)[:200]}"
)
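As a small, hypothetical illustration of why the [:200] slice helps: with many shards, the full repr of dataset_files would dominate every log line, while a bounded prefix still identifies the dataset.

```python
# Hypothetical shard list; not real data.
dataset_files = [f"shard_{i:05d}.arrow" for i in range(10_000)]
print(len(str(dataset_files)))   # over 200,000 characters if logged in full
print(str(dataset_files)[:200])  # bounded, still human-identifiable prefix
```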

View file

@@ -54,9 +54,10 @@ def start_work_from_state(
if stop_event.is_set():
# Signal the end of output. This ensures that even if the queue takes a while
# to buffer, the main thread receives everything (and tosses this fake batch).
logging.info(
logging.debug(
"Worker thread: Stop event detected, outputting is_final=True batch"
)
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
batch_queue.put(
Batch(
x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
ngram_ids=None,
)
)
logging.debug(
"Worker thread: is_final=True batch put in queue, breaking from loop."
)
break
try:
logging.info("Worker thread: outputting state")
state_queue.put(iterator.get_state(), timeout=1)
logging.info("Worker thread: state dump complete")
logging.debug("Worker thread: outputting state")
state_queue.put(stateful_iterator.get_state(), timeout=1)
logging.debug("Worker thread: state dump complete")
state_dumped_event.set()
logging.info("Worker thread: set state_dump_event")
logging.debug("Worker thread: set state_dump_event")
except Full:
raise ValueError(
"Attempted to dump state into the state queue, but it was full"
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
serialized_prefetch_buffer=serialized_prefetch_buffer,
)
else:
logging.info("Main thread: Sending stop iteration event")
logging.debug("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.info("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
logging.debug(
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
)
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
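The consumer side of the same handshake, as the MultiprocessIterator hunk above implements it: set the stop event, drain the batch queue into a prefetch buffer until the sentinel arrives, then wait for the state-dumped event before reading the state queue. Again a simplified, illustrative sketch (pairs with the toy worker above), not the real get_state() implementation.

```python
# Simplified consumer-side counterpart to the worker sketch above.
import queue


def drain_and_collect_state(batch_queue, state_queue, stop_event, state_dumped_event):
    stop_event.set()                     # ask the worker to wind down
    prefetch_buffer = []
    final_batch_received = False
    while True:
        try:
            batch = batch_queue.get(timeout=1)
        except queue.Empty:
            continue                     # worker may still be flushing batches
        if batch.is_final:
            final_batch_received = True  # sentinel: nothing real follows
            break
        prefetch_buffer.append(batch)
    assert final_batch_received
    state_dumped_event.wait()            # the state is only valid after this point
    worker_state = state_queue.get(timeout=1)
    return prefetch_buffer, worker_state
```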

View file

@@ -25,6 +25,7 @@ from torch.optim import lr_scheduler
from bytelatent.args import TrainArgs, parse_args
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
MultiprocessIteratorState,
@@ -34,7 +35,6 @@ from bytelatent.distributed import (
check_model_value_range,
clean_env,
dist_mean,
dist_mean_dict,
dist_sum,
get_device_mesh,
get_is_master,
@@ -700,6 +700,9 @@ def train(args: TrainArgs):
if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
saved = checkpoint.save(
model,
optimizer,
@@ -741,6 +744,9 @@ def train(args: TrainArgs):
if preemption_flag["flag"]:
if not saved:
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
checkpoint.save(
model,
optimizer,
@@ -752,6 +758,9 @@ def train(args: TrainArgs):
sys.exit(0)
if not saved:
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
checkpoint.save(
model,
optimizer,