Add approximate state persistence

Summary:

Test Plan:

***
More verbose multiprocess logging, fix get_state_and_recycle

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-03-05 15:01:45 -08:00 committed by GitHub
parent 9bd51df961
commit ea1fc75862
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 199 additions and 77 deletions

View file

@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
PersistType,
)
from bytelatent.data.iterators.packing_iterator import ( from bytelatent.data.iterators.packing_iterator import (
PackingArgs, PackingArgs,
PackingIterator, PackingIterator,
@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
add_bos: bool = True add_bos: bool = True
add_eos: bool = True add_eos: bool = True
load_async: bool = True load_async: bool = True
async_persist_type: PersistType = PersistType.EXACT
prefetch_size: int = 64 prefetch_size: int = 64
preprocess_dir: str | None = None preprocess_dir: str | None = None
dataset_files: list[str] | None = None dataset_files: list[str] | None = None
@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args) packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
if self.load_async: if self.load_async:
mp_iterator = MultiprocessIterator( mp_iterator = MultiprocessIterator(
packing_iterator, n_batches_to_prefetch=self.prefetch_size packing_iterator,
n_batches_to_prefetch=self.prefetch_size,
persist_type=self.async_persist_type,
) )
return mp_iterator return mp_iterator
else: else:

View file

@ -2,6 +2,7 @@
import json import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
from enum import Enum
from multiprocessing.synchronize import Event as EventClass from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full from queue import Empty, Full
@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
logger = logging.getLogger() logger = logging.getLogger()
class PersistType(str, Enum):
    """Strategy used by MultiprocessIterator to persist dataloader state.

    EXACT: stop the worker process, drain the prefetch queue, and read the
    worker's final iterator state, so resuming reproduces iteration exactly
    (tears down the worker; it must be restarted afterwards).
    APPROXIMATE: ask the still-running worker for a snapshot of its current
    state without draining the batch queue — faster, but the resume point
    is only approximate since prefetched-but-undelivered batches are not
    accounted for exactly.
    """

    EXACT = "exact"
    APPROXIMATE = "approximate"
class MultiprocessIteratorState(PydanticIteratorState): class MultiprocessIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")
base_iterator_state: PackingIteratorState base_iterator_state: PackingIteratorState
n_batches_to_prefetch: int n_batches_to_prefetch: int
serialized_prefetch_buffer: str serialized_prefetch_buffer: str
persist_type: PersistType
def build(self): def build(self):
base_iterator = self.base_iterator_state.build() base_iterator = self.base_iterator_state.build()
@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
base_iterator, base_iterator,
n_batches_to_prefetch=self.n_batches_to_prefetch, n_batches_to_prefetch=self.n_batches_to_prefetch,
prefetch_buffer=prefetch_buffer, prefetch_buffer=prefetch_buffer,
persist_type=self.persist_type,
) )
def start_work_from_state( def start_work_from_state(
batch_queue: mp.Queue, batch_queue: mp.Queue,
state_queue: mp.Queue, state_queue: mp.Queue,
approximate_state_queue: mp.Queue,
stop_event: EventClass, stop_event: EventClass,
state_dumped_event: EventClass, state_dumped_event: EventClass,
trigger_approximate_send_state_event: EventClass,
sent_approximate_state_event: EventClass,
received_approximate_state_event: EventClass,
state: IteratorState, state: IteratorState,
): ):
logging.info("Worker thread: Starting base_iterator work") logging.info("Worker thread: Starting base_iterator work")
@ -49,6 +61,25 @@ def start_work_from_state(
for item in iterator: for item in iterator:
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
if trigger_approximate_send_state_event.is_set():
logger.info("WT: trigger_approximate_send ack")
# Since this can be triggered again (but only after the state is received on mp),
# we should cleanup as soon as possible.
trigger_approximate_send_state_event.clear()
logging.info("WT: Computing approximate state")
approximate_state = stateful_iterator.get_state()
# At this state, there should always be exactly 1 slot.
# Blocking here would be a bug.
logger.info("WT: Attempting to send approximate state")
approximate_state_queue.put(
approximate_state, block=True, timeout=None
)
sent_approximate_state_event.set()
logger.info("WT: Approximate state sent")
# Same here, clear events as we no longer need them.
received_approximate_state_event.wait()
received_approximate_state_event.clear()
logger.info("WT: State received by MT, resuming batch iteration")
# Attempt to put on queue or timeout to try again (maybe main thread is busy) # Attempt to put on queue or timeout to try again (maybe main thread is busy)
batch_queue.put(item, timeout=0.1) batch_queue.put(item, timeout=0.1)
# On success, stop trying # On success, stop trying
@ -58,10 +89,10 @@ def start_work_from_state(
if stop_event.is_set(): if stop_event.is_set():
# Signal the end of output, this ensures that even if the queue takes a while to # Signal the end of output, this ensures that even if the queue takes a while to
# buffer, that the main thread receives everything (and tosses this fake batch) # buffer, that the main thread receives everything (and tosses this fake batch)
logging.debug( logging.info(
"Worker thread: Stop event detected, outputting is_final=True batch" "Worker thread: Stop event detected, outputting is_final=True batch"
) )
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full()) logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
batch_queue.put( batch_queue.put(
Batch( Batch(
x=np.zeros((1, 1)), x=np.zeros((1, 1)),
@ -72,23 +103,26 @@ def start_work_from_state(
ngram_ids=None, ngram_ids=None,
) )
) )
logging.debug( logging.info(
"Worker thread: is_final=True batch put in queue, breaking from loop." "Worker thread: is_final=True batch put in queue, breaking from loop."
) )
break break
try: try:
logging.debug("Worker thread: outputting state") logging.info("Worker thread: outputting state")
state_queue.put(stateful_iterator.get_state(), timeout=1) state_queue.put(stateful_iterator.get_state(), timeout=1)
logging.debug("Worker thread: state dump complete") logging.info("Worker thread: state dump complete")
state_dumped_event.set() state_dumped_event.set()
logging.debug("Worker thread: set state_dump_event") logging.info("Worker thread: set state_dump_event")
except Full: except Full:
raise ValueError( raise ValueError(
"Attempted to dump state into the state queue, but it was full" "Attempted to dump state into the state queue, but it was full"
) )
# Seconds the main thread waits on a state queue for the worker process to
# deliver its iterator state (both exact and approximate get_state paths)
# before raising an error.
FETCH_STATE_TIMEOUT = 120
class MultiprocessIterator(StatefulIterator): class MultiprocessIterator(StatefulIterator):
""" """
Design sketch of the multiprocess iterator: Design sketch of the multiprocess iterator:
@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator):
base_iterator: StatefulIterator, base_iterator: StatefulIterator,
*, *,
n_batches_to_prefetch: int, n_batches_to_prefetch: int,
prefetch_buffer: list | None = None prefetch_buffer: list | None = None,
persist_type: PersistType = PersistType.EXACT,
): ):
self.base_iterator = base_iterator self.base_iterator = base_iterator
self.n_batches_to_prefetch = n_batches_to_prefetch self.n_batches_to_prefetch = n_batches_to_prefetch
self.persist_type = persist_type
if prefetch_buffer is None: if prefetch_buffer is None:
prefetch_buffer = [] prefetch_buffer = []
self.prefetch_buffer = prefetch_buffer self.prefetch_buffer = prefetch_buffer
self.batch_queue = None self.batch_queue = None
self.state_queue = None self.state_queue = None
self.approximate_state_queue = None
self.producer = None self.producer = None
self.stop_iterating_event = None self.stop_iterating_event = None
self.state_dumped_event = None self.state_dumped_event = None
self.trigger_approximate_send_state_event = None
self.sent_approximate_state_event = None
self.received_approximate_state_event = None
self.force_shutdown = False self.force_shutdown = False
def shutdown(self): def shutdown(self):
@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator):
self.producer.kill() self.producer.kill()
self.force_shutdown = True self.force_shutdown = True
def _get_state_exact(self):
    """Stop the worker process and capture an exact dataloader state.

    Signals the worker to stop, drains batch_queue into prefetch_buffer
    until the worker's is_final=True sentinel batch arrives, then reads the
    worker's final iterator state from state_queue. All IPC primitives and
    the producer process are torn down, so iteration can only resume by
    rebuilding from the returned state.

    Returns:
        MultiprocessIteratorState: exact state, including the serialized
        prefetch buffer of already-produced batches.

    Raises:
        ValueError: if the worker's state does not appear on state_queue
            within FETCH_STATE_TIMEOUT seconds.
    """
    logging.info("Main thread: Sending stop iteration event")
    self.stop_iterating_event.set()
    logging.info(
        "Main thread: Emptying the batch_queue until batch.is_final=True is found."
    )
    self.prefetch_buffer = []
    final_batch_received = False
    while True:
        try:
            batch = self.batch_queue.get(timeout=1)
            if batch.is_final:
                logging.info(
                    "Main thread: is_final=True batch found, stopping fetch from batch_queue"
                )
                final_batch_received = True
                break
            self.prefetch_buffer.append(batch)
        except Empty:
            # The worker should be flushing its queue; an empty read here is
            # unusual but retried until the sentinel batch shows up.
            logging.warning("Main thread: batch_queue is abnormally empty")
    assert final_batch_received
    logging.info("Main thread: Waiting for state_dumped event")
    self.state_dumped_event.wait()
    try:
        logging.info(
            "Main thread: state_dumped_event received, waiting for state from queue"
        )
        base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
        logging.info("Main thread: received state from queue")
        assert isinstance(base_iterator_state, IteratorState)
    except Empty:
        raise ValueError(
            "Attempted to get the state, but it was unexpectedly missing"
        )
    # Rebuild the base iterator locally and dismantle all multiprocess
    # machinery; create_iter() will re-create it on next use.
    self.base_iterator = base_iterator_state.build()
    self.producer.close()
    self.producer = None
    self.batch_queue = None
    self.state_queue = None
    self.approximate_state_queue = None
    self.stop_iterating_event = None
    self.state_dumped_event = None
    self.trigger_approximate_send_state_event = None
    self.sent_approximate_state_event = None
    self.received_approximate_state_event = None
    return MultiprocessIteratorState(
        base_iterator_state=self.base_iterator.get_state(),
        n_batches_to_prefetch=self.n_batches_to_prefetch,
        serialized_prefetch_buffer=json.dumps(
            [b.to_python_dict() for b in self.prefetch_buffer]
        ),
        persist_type=self.persist_type,
    )
def _get_state_approximate(self):
    """Snapshot the worker's state without stopping batch production.

    Signals the worker (via trigger_approximate_send_state_event) to put a
    copy of its current iterator state on approximate_state_queue, waits for
    it, and acknowledges receipt so the worker resumes normal iteration.
    Unlike _get_state_exact, the producer process and queues stay alive.

    Returns:
        MultiprocessIteratorState: approximate state; the serialized
        prefetch buffer reflects only batches already held by the main
        thread, not those still in the worker's queue.

    Raises:
        ValueError: if the worker's state does not appear on
            approximate_state_queue within FETCH_STATE_TIMEOUT seconds.
    """
    logging.info("MT: Sending approximate get_state request")
    self.trigger_approximate_send_state_event.set()
    logging.info("MT: Waiting for sent_approximate_state_event")
    self.sent_approximate_state_event.wait()
    logging.info("MT: sent_approximate_state_event ack")
    try:
        logging.info("MT: waiting for approximate state in queue")
        base_iterator_state = self.approximate_state_queue.get(
            timeout=FETCH_STATE_TIMEOUT
        )
        logging.info("MT: approximate state received")
        assert isinstance(base_iterator_state, IteratorState)
        # The queue has maxsize=1 and the event handshake guarantees exactly
        # one state per request, so it must be empty again now.
        assert self.approximate_state_queue.empty()
    except Empty:
        raise ValueError(
            "Attempted to get approximate state, but queue was erroneously empty."
        )
    # Let the worker clear its events and resume producing batches.
    self.received_approximate_state_event.set()
    return MultiprocessIteratorState(
        base_iterator_state=base_iterator_state,
        n_batches_to_prefetch=self.n_batches_to_prefetch,
        serialized_prefetch_buffer=json.dumps(
            [b.to_python_dict() for b in self.prefetch_buffer]
        ),
        persist_type=self.persist_type,
    )
def get_state(self) -> MultiprocessIteratorState: def get_state(self) -> MultiprocessIteratorState:
""" """
This is slightly unusual in effectively destroying the current iterator, its necessary This is slightly unusual in effectively destroying the current iterator, its necessary
@ -162,55 +288,15 @@ class MultiprocessIterator(StatefulIterator):
base_iterator_state=self.base_iterator.get_state(), base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch, n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=serialized_prefetch_buffer, serialized_prefetch_buffer=serialized_prefetch_buffer,
persist_type=self.persist_type,
) )
else: else:
logging.debug("Main thread: Sending stop iteration event") if self.persist_type == PersistType.EXACT:
self.stop_iterating_event.set() return self._get_state_exact()
logging.debug( elif self.persist_type == PersistType.APPROXIMATE:
"Main thread: Emptying the batch_queue until batch.is_final=True is found." return self._get_state_approximate()
) else:
self.prefetch_buffer = [] raise ValueError("invalid persist_type")
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
except Empty:
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
except Empty:
raise ValueError(
"Attempted to get the state, but it was unexpectantly missing"
)
self.base_iterator = base_iterator_state.build()
self.producer.close()
self.producer = None
self.batch_queue = None
self.state_queue = None
self.stop_iterating_event = None
self.state_dumped_event = None
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
)
def create_iter(self): def create_iter(self):
if self.force_shutdown: if self.force_shutdown:
@ -236,8 +322,14 @@ class MultiprocessIterator(StatefulIterator):
# We should only ever one state, which is output at the detection of a stop event # We should only ever one state, which is output at the detection of a stop event
self.state_queue = ctx.Manager().Queue(maxsize=1) self.state_queue = ctx.Manager().Queue(maxsize=1)
# Similarly, there should only ever be one state in flight due to event signals
self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
self.stop_iterating_event = ctx.Event() self.stop_iterating_event = ctx.Event()
self.state_dumped_event = ctx.Event() self.state_dumped_event = ctx.Event()
self.trigger_approximate_send_state_event = ctx.Event()
self.sent_approximate_state_event = ctx.Event()
self.received_approximate_state_event = ctx.Event()
self.producer = mp.Process( self.producer = mp.Process(
name="blt_data_loader", name="blt_data_loader",
@ -245,8 +337,12 @@ class MultiprocessIterator(StatefulIterator):
args=( args=(
self.batch_queue, self.batch_queue,
self.state_queue, self.state_queue,
self.approximate_state_queue,
self.stop_iterating_event, self.stop_iterating_event,
self.state_dumped_event, self.state_dumped_event,
self.trigger_approximate_send_state_event,
self.sent_approximate_state_event,
self.received_approximate_state_event,
self.base_iterator.get_state(), self.base_iterator.get_state(),
), ),
) )

View file

@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import ( from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator, MultiprocessIterator,
MultiprocessIteratorState, MultiprocessIteratorState,
PersistType,
) )
from bytelatent.data.iterators.packing_iterator import PackingIteratorState from bytelatent.data.iterators.packing_iterator import PackingIteratorState
from bytelatent.distributed import ( from bytelatent.distributed import (
@ -712,9 +713,15 @@ def train(args: TrainArgs):
if every_n_steps( if every_n_steps(
train_state, args.checkpoint.dump.every, acc_step=0 train_state, args.checkpoint.dump.every, acc_step=0
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0): ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
train_state.data_loader_state, data_loader, batch_iterator = ( if (
get_state_and_refresh(data_loader) args.data.load_async
) and args.data.async_persist_type == PersistType.EXACT
):
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
else:
train_state.data_loader_state = data_loader.get_state()
saved = checkpoint.save( saved = checkpoint.save(
model, model,
optimizer, optimizer,
@ -756,9 +763,16 @@ def train(args: TrainArgs):
if preemption_flag["flag"]: if preemption_flag["flag"]:
if not saved: if not saved:
train_state.data_loader_state, data_loader, batch_iterator = ( if (
get_state_and_refresh(data_loader) args.data.load_async
) and args.data.async_persist_type == PersistType.EXACT
):
train_state.data_loader_state, data_loader, batch_iterator = (
get_state_and_refresh(data_loader)
)
else:
train_state.data_loader_state = data_loader.get_state()
checkpoint.save( checkpoint.save(
model, model,
optimizer, optimizer,
@ -769,21 +783,27 @@ def train(args: TrainArgs):
requeue_slurm_job() requeue_slurm_job()
sys.exit(0) sys.exit(0)
if not saved: if not saved:
train_state.data_loader_state, data_loader, batch_iterator = ( if (
get_state_and_refresh(data_loader) args.data.load_async
) and args.data.async_persist_type == PersistType.EXACT
checkpoint.save( ):
model, train_state.data_loader_state, data_loader, batch_iterator = (
optimizer, get_state_and_refresh(data_loader)
train_state, )
args, else:
device_mesh=world_mesh, train_state.data_loader_state = data_loader.get_state()
) checkpoint.save(
if isinstance(data_loader, MultiprocessIterator): model,
logger.info("Closing MP iterator before exiting") optimizer,
data_loader.shutdown() train_state,
gc.collect() args,
device_mesh=world_mesh,
)
if isinstance(data_loader, MultiprocessIterator):
logger.info("Closing MP iterator before exiting")
data_loader.shutdown()
gc.collect()
def main(): def main():