Mirror of https://github.com/facebookresearch/blt.git (synced 2025-02-22 21:12:15 +00:00)

Commit 8c61ab5e67 (parent 85c2f28f26)
Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
@@ -1,10 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-import json
 import logging
 import os
 from typing import Any

-import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -21,3 +21,13 @@ class IteratorState(Generic[C]):
     @abc.abstractmethod
     def build(self) -> StatefulIterator[T, C]:
         pass
+
+
+def get_state_and_refresh(iterator: StatefulIterator):
+    # Re-init dataloader and iterator is necessary since get_state()
+    # on mp iterator shuts down MP to correctly persist state and it needs
+    # to be restarted.
+    state = iterator.get_state()
+    data_loader = state.build()
+    py_iterator = data_loader.create_iter()
+    return state, data_loader, py_iterator
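Calling get_state() on a multiprocess iterator shuts the worker down so the snapshot is consistent, which is why this helper immediately rebuilds the loader and its Python iterator. A minimal sketch of the intended call pattern, assuming a loader built from the repo's iterator classes (save_fn is a placeholder, not part of the diff):

from bytelatent.data.iterators.abstract_iterator import StatefulIterator, get_state_and_refresh

def checkpoint_and_resume(data_loader: StatefulIterator, save_fn):
    # Snapshot the iterator (this stops the MP worker), persist the state,
    # then hand back the rebuilt loader and iterator so training can continue.
    state, data_loader, py_iterator = get_state_and_refresh(data_loader)
    save_fn(state)
    return data_loader, py_iterator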
@@ -60,6 +60,13 @@ def shard_sort_key(file: str):
     return shard_number


+def maybe_truncate_string(text: str, max_length: int):
+    if len(text) <= max_length:
+        return text
+    else:
+        return text[:max_length] + "..."
+
+
 class ArrowFileIterator(StatefulIterator):
     def __init__(
         self,
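maybe_truncate_string only keeps the log lines below readable when self.dataset_files is a long list of shard paths; its behaviour, illustratively:

assert maybe_truncate_string("abc", max_length=5) == "abc"
assert maybe_truncate_string("abcdefgh", max_length=5) == "abcde..."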
@@ -235,9 +242,8 @@ class ArrowFileIterator(StatefulIterator):
                 yield out

     def _set_row_num(self, target_row_num: int):
-        logger.info(
-            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
-        )
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
+        logger.info(f"Setting arrow position to {target_row_num} for {data_str}")
         if target_row_num is None or target_row_num == 0:
             self.row_num = 0
             self.dataset = None
@@ -285,6 +291,7 @@ class ArrowFileIterator(StatefulIterator):
                     else:
                         curr_remaining -= len(batch)
         self.row_num = target_row_num
+        data_str = maybe_truncate_string(str(self.dataset_files), 200)
         logger.info(
-            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
+            f"Finished setting arrow position to {target_row_num} for {data_str}"
         )
@@ -54,9 +54,10 @@ def start_work_from_state(
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.info(
+            logging.debug(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
+            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -67,14 +68,17 @@ def start_work_from_state(
                     ngram_ids=None,
                 )
             )
+            logging.debug(
+                "Worker thread: is_final=True batch put in queue, breaking from loop."
+            )
             break

         try:
-            logging.info("Worker thread: outputting state")
-            state_queue.put(iterator.get_state(), timeout=1)
-            logging.info("Worker thread: state dump complete")
+            logging.debug("Worker thread: outputting state")
+            state_queue.put(stateful_iterator.get_state(), timeout=1)
+            logging.debug("Worker thread: state dump complete")
             state_dumped_event.set()
-            logging.info("Worker thread: set state_dump_event")
+            logging.debug("Worker thread: set state_dump_event")
         except Full:
             raise ValueError(
                 "Attempted to dump state into the state queue, but it was full"
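These two hunks are the worker side of the shutdown handshake: on stop_event the worker pushes one sentinel batch with is_final=True so the consumer knows where real data ends, and only then dumps the iterator state and sets state_dumped_event. A stripped-down sketch of that ordering, using plain dict batches instead of the blt classes (all names here are illustrative):

def worker_loop(batch_queue, state_queue, stop_event, state_dumped_event, make_batch, get_state):
    while True:
        if stop_event.is_set():
            # Sentinel: everything after this is not real data.
            batch_queue.put({"is_final": True})
            break
        batch_queue.put({"is_final": False, "data": make_batch()})
    # The state is dumped only after the sentinel is queued, so the consumer
    # can always drain the batch queue before waiting on this event.
    state_queue.put(get_state(), timeout=1)
    state_dumped_event.set()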
@@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
             )
         else:
-            logging.info("Main thread: Sending stop iteration event")
+            logging.debug("Main thread: Sending stop iteration event")
             self.stop_iterating_event.set()
-            logging.info("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
+            logging.debug(
+                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+            )
             self.prefetch_buffer = []
             final_batch_received = False
             while True:
                 try:
                     batch = self.batch_queue.get(timeout=1)
                     if batch.is_final:
+                        logging.debug(
+                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                        )
                         final_batch_received = True
                         break
                     self.prefetch_buffer.append(batch)
@@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
                     logging.warning("Main thread: batch_queue is abnormally empty")
             assert final_batch_received

+            logging.debug("Main thread: Waiting for state_dumped event")
+            self.state_dumped_event.wait()
+
             try:
                 base_iterator_state = self.state_queue.get(timeout=1)
                 assert isinstance(base_iterator_state, IteratorState)
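The reordering across these two hunks looks like the core of the fix: previously the main thread waited on state_dumped_event before draining batch_queue, but the worker only dumps its state after it has managed to put the final batch, so a full batch queue could block both sides. The new order drains the queue until is_final=True is seen and only then waits for the state dump. Consumer-side sketch, under the same illustrative names as the worker sketch above:

def stop_and_collect_state(batch_queue, state_queue, stop_event, state_dumped_event):
    stop_event.set()
    prefetch_buffer = []
    while True:
        batch = batch_queue.get(timeout=1)   # keeps the worker from blocking on put()
        if batch["is_final"]:
            break
        prefetch_buffer.append(batch)
    state_dumped_event.wait()                # safe: the worker has already passed its put()
    return state_queue.get(timeout=1), prefetch_buffer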
@@ -26,6 +26,7 @@ from torch.optim import lr_scheduler
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -35,7 +36,6 @@ from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
     dist_mean,
-    dist_mean_dict,
     dist_sum,
     get_device_mesh,
     get_is_master,
@@ -88,6 +88,13 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
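to_py_num normalizes whatever the metrics code now carries around (a tensor, a numpy array, or an already-plain number) into a Python scalar right at the logging boundary; a couple of illustrative calls:

import numpy as np
import torch

assert to_py_num(torch.tensor(1.5)) == 1.5
assert to_py_num(np.array(2.0)) == 2.0
assert to_py_num(3) == 3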
@@ -603,20 +610,20 @@ def train(args: TrainArgs):
             # step: Metric at a step
             # interval: Metric averaged/summed across all steps since the last log interval.
             # Typically, this is 10
-            step_loss_per_gpu = loss.item()
-            step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
-            interval_loss_per_gpu = np.mean(step_losses).item()
-            interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+            step_loss_per_gpu = loss
+            step_loss_across_gpus = dist_mean(step_loss_per_gpu)
+            interval_loss_per_gpu = np.mean(step_losses)
+            interval_loss_across_gpus = dist_mean(interval_loss_per_gpu)

             stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
-            interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+            interval_total_tok_loss_per_gpu = stacked_tok_loss.sum()
             interval_total_tok_loss_across_gpus = dist_sum(
                 interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
-            ).item()
-            interval_total_n_bytes_per_gpu = n_bytes.item()
+            )
+            interval_total_n_bytes_per_gpu = n_bytes
             interval_total_n_bytes_across_gpus = dist_sum(
                 n_bytes, reduce_dtype=torch.bfloat16
-            ).item()
+            )

             interval_bpb_per_gpu = (
                 interval_total_tok_loss_per_gpu
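With the .item() calls removed, the per-step and per-interval losses stay as tensors (or numpy values) through dist_mean/dist_sum and the bpb arithmetic; conversion to plain Python numbers now happens only once, in the metrics dict below, via to_py_num. The pattern in miniature (illustrative sketch; dist_mean and to_py_num are the repo's helpers, the rest is made up):

def interval_metrics(loss: torch.Tensor, step_losses: list[float]) -> dict:
    step_loss_per_gpu = loss                               # stays a tensor here
    step_loss_across_gpus = dist_mean(step_loss_per_gpu)   # reduction also on tensors
    interval_loss_per_gpu = np.mean(step_losses)
    return {
        "step_per_gpu": to_py_num(step_loss_per_gpu),
        "step_across_gpu": to_py_num(step_loss_across_gpus),
        "interval_per_gpu": to_py_num(interval_loss_per_gpu),
    }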
@@ -645,18 +652,20 @@ def train(args: TrainArgs):
                 },
                 "memory": gpu_mem_stats._asdict(),
                 "loss": {
-                    "step_per_gpu": step_loss_per_gpu,
-                    "step_across_gpu": step_loss_across_gpus,
-                    "interval_per_gpu": interval_loss_per_gpu,
-                    "interval_across_gpu": interval_loss_across_gpus,
+                    "step_per_gpu": to_py_num(step_loss_per_gpu),
+                    "step_across_gpu": to_py_num(step_loss_across_gpus),
+                    "interval_per_gpu": to_py_num(interval_loss_per_gpu),
+                    "interval_across_gpu": to_py_num(interval_loss_across_gpus),
                 },
                 "bpb": {
-                    "interval_per_gpu": interval_bpb_per_gpu,
-                    "interval_across_gpus": interval_bpb_across_gpus,
+                    "interval_per_gpu": to_py_num(interval_bpb_per_gpu),
+                    "interval_across_gpus": to_py_num(interval_bpb_across_gpus),
                 },
                 "n_bytes": {
-                    "interval_per_gpu": interval_total_n_bytes_per_gpu,
-                    "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                    "interval_per_gpu": to_py_num(interval_total_n_bytes_per_gpu),
+                    "interval_across_gpus": to_py_num(
+                        interval_total_n_bytes_across_gpus
+                    ),
                 },
             }

@@ -676,8 +685,8 @@ def train(args: TrainArgs):
             logger.info(
                 f"step: {train_state.step}"
                 f" acc: {train_state.acc_step}"
-                f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
-                f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                f" loss_gpu: {round(to_py_num(interval_loss_per_gpu), 4):>7}"
+                f" loss_avg: {round(to_py_num(interval_loss_across_gpus), 4):>7}"
                 f" bpb_gpu: {interval_bpb_per_gpu:3f}"
                 f" bpb_avg: {interval_bpb_across_gpus:3f}"
                 f" grad: {grad_norm:.2e}"
@@ -702,6 +711,9 @@ def train(args: TrainArgs):
             if every_n_steps(
                 train_state, args.checkpoint.dump.every, acc_step=0
             ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 saved = checkpoint.save(
                     model,
                     optimizer,
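The same three-line block appears at each of the three save sites in train.py (this hunk and the two below): refresh the dataloader right before checkpoint.save so the multiprocess iterator's state actually lands in train_state, then keep training from the rebuilt iterator. An illustrative wrapper for that repeated pattern (not code from the diff):

def save_with_loader_state(train_state, data_loader, do_save):
    # Snapshot + rebuild the iterator, stash its state on train_state, then save.
    train_state.data_loader_state, data_loader, batch_iterator = (
        get_state_and_refresh(data_loader)
    )
    saved = do_save()   # stands in for the checkpoint.save(model, optimizer, ...) call
    return saved, data_loader, batch_iterator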
@@ -743,6 +755,9 @@ def train(args: TrainArgs):

         if preemption_flag["flag"]:
             if not saved:
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
                 checkpoint.save(
                     model,
                     optimizer,
@@ -754,6 +769,9 @@ def train(args: TrainArgs):
             sys.exit(0)

     if not saved:
+        train_state.data_loader_state, data_loader, batch_iterator = (
+            get_state_and_refresh(data_loader)
+        )
         checkpoint.save(
             model,
             optimizer,