This includes fixes that make checkpointing and reloading work correctly. (#35)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

It also batches in a first set of changes for fixing eval code

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-01-27 16:56:42 -08:00 committed by GitHub
parent 7622d28b74
commit 7044771a12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 221 additions and 237 deletions

View file

@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
self.producer = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.force_shutdown = False
def shutdown(self):
if self.producer is not None:
# This properly shuts things down
self.producer.kill()
self.force_shutdown = True
def get_state(self) -> MultiprocessIteratorState:
"""
@ -135,6 +142,10 @@ class MultiprocessIterator(StatefulIterator):
to halt the background process and allow it to write the state to the main loop
in order to not lose data
"""
if self.force_shutdown:
raise ValueError(
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
@ -187,6 +198,10 @@ class MultiprocessIterator(StatefulIterator):
)
def create_iter(self):
if self.force_shutdown:
raise ValueError(
"Iterator may be invalid if shutdown was forced before state persisted."
)
logging.info("Main thread: Creating MP iterator")
# First yield from the stored prefetch buffer.
if self.prefetch_buffer is not None: