Fix multiprocessing dataloader checkpointing and use it in the train script (#50)
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

This commit is contained in:
Pedro Rodriguez 2025-02-13 11:58:23 -08:00 committed by GitHub
parent 85c2f28f26
commit 8c61ab5e67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 77 additions and 33 deletions

View file

@ -54,9 +54,10 @@ def start_work_from_state(
if stop_event.is_set():
# Signal the end of output, this ensures that even if the queue takes a while to
# buffer, that the main thread receives everything (and tosses this fake batch)
logging.info(
logging.debug(
"Worker thread: Stop event detected, outputting is_final=True batch"
)
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
batch_queue.put(
Batch(
x=np.zeros((1, 1)),
@ -67,14 +68,17 @@ def start_work_from_state(
ngram_ids=None,
)
)
logging.debug(
"Worker thread: is_final=True batch put in queue, breaking from loop."
)
break
try:
logging.info("Worker thread: outputting state")
state_queue.put(iterator.get_state(), timeout=1)
logging.info("Worker thread: state dump complete")
logging.debug("Worker thread: outputting state")
state_queue.put(stateful_iterator.get_state(), timeout=1)
logging.debug("Worker thread: state dump complete")
state_dumped_event.set()
logging.info("Worker thread: set state_dump_event")
logging.debug("Worker thread: set state_dump_event")
except Full:
raise ValueError(
"Attempted to dump state into the state queue, but it was full"
@ -156,16 +160,20 @@ class MultiprocessIterator(StatefulIterator):
serialized_prefetch_buffer=serialized_prefetch_buffer,
)
else:
logging.info("Main thread: Sending stop iteration event")
logging.debug("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.info("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
logging.debug(
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
)
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
@ -173,6 +181,9 @@ class MultiprocessIterator(StatefulIterator):
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)