Fix multiprocessing dataloader checkpointing and use it in the train script (#50)

Author: Pedro Rodriguez (committed via GitHub), 2025-02-13 11:58:23 -08:00
Parent: 85c2f28f26
Commit: 8c61ab5e67
5 changed files with 77 additions and 33 deletions

@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json
 import logging
 import os
 from typing import Any

-import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf

@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-initializing the dataloader and iterator is necessary since get_state()
+    # on the multiprocess iterator shuts down the worker process in order to
+    # persist its state correctly, so it has to be restarted afterwards.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
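A note on intended usage: because get_state() on the multiprocess iterator tears down its worker so the state can be serialized consistently, callers must switch over to the rebuilt loader and iterator this helper returns. A minimal sketch of that pattern, mirroring the train-script changes later in this commit (save_checkpoint is a hypothetical stand-in for the real checkpoint plumbing):

from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh


def checkpoint_dataloader(data_loader, save_checkpoint):
    # Capturing the state shuts down the MP worker; the helper hands back a
    # freshly built loader plus a new Python iterator to keep training with.
    state, data_loader, batch_iterator = get_state_and_refresh(data_loader)
    save_checkpoint(data_loader_state=state)
    return data_loader, batch_iterator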

@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number


+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
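For reference, the new helper simply caps how much of a string makes it into a log line; a quick illustrative example (values arbitrary):

files = [f"/data/shard_{i:04d}.arrow" for i in range(1000)]
print(maybe_truncate_string(str(files), 200))  # first 200 chars of the repr, then "..."
print(maybe_truncate_string("short", 200))     # "short", unchanged since it fits
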
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
             yield out

     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 curr_remaining -= len(batch)
         self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
         )

@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break

         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received

+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
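Two things appear to be fixed in this file: the worker now dumps state from stateful_iterator rather than the raw Python iterator, and the main thread now drains the batch_queue (until it sees the is_final=True sentinel the worker enqueues) before waiting on state_dumped_event. The previous ordering could plausibly hang: the worker has to put its sentinel batch before it ever reaches the state dump, so a main thread blocked on state_dumped_event while the queue was full would wait forever. A minimal sketch of the corrected handshake, with illustrative names (only the queue/event roles mirror MultiprocessIterator):

import queue


def drain_then_collect_state(batch_queue, state_queue, stop_event, state_dumped_event):
    # Ask the worker to stop, then empty the batch queue first so the worker is
    # never stuck on a full queue while it still has its state to dump.
    stop_event.set()
    prefetch_buffer = []
    while True:
        try:
            batch = batch_queue.get(timeout=1)
        except queue.Empty:
            continue  # the real code logs a warning here and keeps retrying
        if batch.is_final:
            break
        prefetch_buffer.append(batch)
    # Only now wait for the worker to serialize and hand over its iterator state.
    state_dumped_event.wait()
    return prefetch_buffer, state_queue.get(timeout=1)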

@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
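The logging path below now keeps the reduced metrics as tensors/arrays and only converts them with to_py_num at the point they are written out (and, where needed, rounded for the console log), instead of calling .item() eagerly after every reduction. A couple of illustrative checks of the helper's behaviour:

import numpy as np
import torch

assert to_py_num(torch.tensor(2.0)) == 2.0  # 0-d tensor -> Python float via .item()
assert to_py_num(np.array(1.25)) == 1.25    # 0-d ndarray -> Python float via .item()
assert to_py_num(7) == 7                    # plain Python numbers pass through unchanged
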
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
                 # step: Metric at a step
                 # interval: Metric averaged/summed across all steps since the last log interval.
                 # Typically, this is 10
-                step_loss_per_gpu = loss.item()
-                step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
-                interval_loss_per_gpu = np.mean(step_losses).item()
-                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+                step_loss_per_gpu = loss
+                step_loss_across_gpus = dist_mean(step_loss_per_gpu)
+                interval_loss_per_gpu = np.mean(step_losses)
+                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)

                 stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
-                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
                 interval_total_tok_loss_across_gpus = dist_sum(
                     interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
-                ).item()
-                interval_total_n_bytes_per_gpu = n_bytes.item()
+                )
+                interval_total_n_bytes_per_gpu = n_bytes
                 interval_total_n_bytes_across_gpus = dist_sum(
                     n_bytes, reduce_dtype=torch.bfloat16
-                ).item()
+                )

                 interval_bpb_per_gpu = (
                     interval_total_tok_loss_per_gpu
@@ -645,18 +652,20 @@
                     },
                     "memory": gpu_mem_stats._asdict(),
                     "loss": {
-                        "step_per_gpu": step_loss_per_gpu,
-                        "step_across_gpu": step_loss_across_gpus,
-                        "interval_per_gpu": interval_loss_per_gpu,
-                        "interval_across_gpu": interval_loss_across_gpus,
+                        "step_per_gpu": to_py_num(step_loss_per_gpu),
+                        "step_across_gpu": to_py_num(step_loss_across_gpus),
+                        "interval_per_gpu": to_py_num(interval_loss_per_gpu),
+                        "interval_across_gpu": to_py_num(interval_loss_across_gpus),
                     },
                     "bpb": {
-                        "interval_per_gpu": interval_bpb_per_gpu,
-                        "interval_across_gpus": interval_bpb_across_gpus,
+                        "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
+                        "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
                     },
                     "n_bytes": {
-                        "interval_per_gpu": interval_total_n_bytes_per_gpu,
-                        "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                        "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
+                        "interval_across_gpus": to_py_num(
+                            interval_total_n_bytes_across_gpus
+                        ),
                     },
                 }
@@ -676,8 +685,8 @@
                 logger.info(
                     f"step: {train_state.step}"
                     f" acc: {train_state.acc_step}"
-                    f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
-                    f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                    f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
+                    f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
                     f" bpb_gpu: {interval_bpb_per_gpu:3f}"
                     f" bpb_avg: {interval_bpb_across_gpus:3f}"
                     f" grad: {grad_norm:.2e}"
@@ -702,6 +711,9 @@
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 saved = checkpoint.save(
                     model,
                     optimizer,
@@ -743,6 +755,9 @@

         if preemption_flag["flag"]:
             if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 checkpoint.save(
                     model,
                     optimizer,
@@ -754,6 +769,9 @@
             sys.exit(0)

     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,