Add approximate state persistence

Summary:

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-03-01 00:15:32 +00:00
parent 08b8c7cd05
commit 967b23fd05

View file

@ -39,8 +39,12 @@ class MultiprocessIteratorState(PydanticIteratorState):
def start_work_from_state( def start_work_from_state(
batch_queue: mp.Queue, batch_queue: mp.Queue,
state_queue: mp.Queue, state_queue: mp.Queue,
approximate_state_queue: mp.Queue,
stop_event: EventClass, stop_event: EventClass,
state_dumped_event: EventClass, state_dumped_event: EventClass,
trigger_approximate_send_state_event: EventClass,
sent_approximate_state_event: EventClass,
received_approximate_state_event: EventClass,
state: IteratorState, state: IteratorState,
): ):
logging.info("Worker thread: Starting base_iterator work") logging.info("Worker thread: Starting base_iterator work")
@ -49,6 +53,24 @@ def start_work_from_state(
for item in iterator: for item in iterator:
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
if trigger_approximate_send_state_event.is_set():
logger.debug("WT: trigger_approximate_send ack")
# Since this can be triggered again (but only after the state is received on mp),
# we should cleanup as soon as possible.
trigger_approximate_send_state_event.clear()
approximate_state = stateful_iterator.get_state()
# At this state, there should always be exactly 1 slot.
# Blocking here would be a bug.
logger.debug("WT: Attempting to send approximate state")
approximate_state_queue.put(
approximate_state, block=True, timeout=None
)
sent_approximate_state_event.set()
logger.debug("WT: Approximate state sent")
# Same here, clear events as we no longer need them.
received_approximate_state_event.wait()
received_approximate_state_event.clear()
logger.debug("WT: State received by MT, resuming batch iteration")
# Attempt to put on queue or timeout to try again (maybe main thread is busy) # Attempt to put on queue or timeout to try again (maybe main thread is busy)
batch_queue.put(item, timeout=0.1) batch_queue.put(item, timeout=0.1)
# On success, stop trying # On success, stop trying
@ -133,9 +155,13 @@ class MultiprocessIterator(StatefulIterator):
self.prefetch_buffer = prefetch_buffer self.prefetch_buffer = prefetch_buffer
self.batch_queue = None self.batch_queue = None
self.state_queue = None self.state_queue = None
self.approximate_state_queue = None
self.producer = None self.producer = None
self.stop_iterating_event = None self.stop_iterating_event = None
self.state_dumped_event = None self.state_dumped_event = None
self.trigger_approximate_send_state_event = None
self.sent_approximate_state_event = None
self.received_approximate_state_event = None
self.force_shutdown = False self.force_shutdown = False
def shutdown(self): def shutdown(self):
@ -144,7 +170,84 @@ class MultiprocessIterator(StatefulIterator):
self.producer.kill() self.producer.kill()
self.force_shutdown = True self.force_shutdown = True
def get_state(self) -> MultiprocessIteratorState: def _get_state_exact(self):
logging.debug("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.debug(
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
)
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
except Empty:
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
except Empty:
raise ValueError(
"Attempted to get the state, but it was unexpectantly missing"
)
self.base_iterator = base_iterator_state.build()
self.producer.close()
self.producer = None
self.batch_queue = None
self.state_queue = None
self.approximate_state_queue = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.trigger_approximate_send_state_event = None
self.sent_approximate_state_event = None
self.received_approximate_state_event = None
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
)
def _get_state_approximate(self):
logging.debug("MT: Sending approximate get_state request")
self.trigger_approximate_send_state_event.set()
logging.debug("MT: Waiting for sent_approximate_state_event")
self.sent_approximate_state_event.wait()
logging.debug("MT: sent_approximate_state_event ack")
try:
base_iterator_state = self.approximate_state_queue.get(timeout=1)
logging.debug("MT: approximate state received")
assert isinstance(base_iterator_state, IteratorState)
assert self.approximate_state_queue.empty()
except Empty:
raise ValueError(
"Attempted to get approximate state, but queue was erroniously empty."
)
self.received_approximate_state_event.set()
return MultiprocessIteratorState(
base_iterator_state=base_iterator_state,
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
)
def get_state(self, exact: bool = True) -> MultiprocessIteratorState:
""" """
This is slightly unusual in effectively destroying the current iterator, its necessary This is slightly unusual in effectively destroying the current iterator, its necessary
to halt the background process and allow it to write the state to the main loop to halt the background process and allow it to write the state to the main loop
@ -164,53 +267,10 @@ class MultiprocessIterator(StatefulIterator):
serialized_prefetch_buffer=serialized_prefetch_buffer, serialized_prefetch_buffer=serialized_prefetch_buffer,
) )
else: else:
logging.debug("Main thread: Sending stop iteration event") if exact:
self.stop_iterating_event.set() return self._get_state_exact()
logging.debug( else:
"Main thread: Emptying the batch_queue until batch.is_final=True is found." return self._get_state_approximate()
)
self.prefetch_buffer = []
final_batch_received = False
while True:
try:
batch = self.batch_queue.get(timeout=1)
if batch.is_final:
logging.debug(
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
)
final_batch_received = True
break
self.prefetch_buffer.append(batch)
except Empty:
logging.warning("Main thread: batch_queue is abnormally empty")
assert final_batch_received
logging.debug("Main thread: Waiting for state_dumped event")
self.state_dumped_event.wait()
try:
base_iterator_state = self.state_queue.get(timeout=1)
assert isinstance(base_iterator_state, IteratorState)
except Empty:
raise ValueError(
"Attempted to get the state, but it was unexpectantly missing"
)
self.base_iterator = base_iterator_state.build()
self.producer.close()
self.producer = None
self.batch_queue = None
self.state_queue = None
self.stop_iterating_event = None
self.state_dumped_event = None
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
)
def create_iter(self): def create_iter(self):
if self.force_shutdown: if self.force_shutdown:
@ -236,8 +296,14 @@ class MultiprocessIterator(StatefulIterator):
# We should only ever one state, which is output at the detection of a stop event # We should only ever one state, which is output at the detection of a stop event
self.state_queue = ctx.Manager().Queue(maxsize=1) self.state_queue = ctx.Manager().Queue(maxsize=1)
# Similarly, there should only ever be one state in flight due to event signals
self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
self.stop_iterating_event = ctx.Event() self.stop_iterating_event = ctx.Event()
self.state_dumped_event = ctx.Event() self.state_dumped_event = ctx.Event()
self.trigger_approximate_send_state_event = ctx.Event()
self.sent_approximate_state_event = ctx.Event()
self.received_approximate_state_event = ctx.Event()
self.producer = mp.Process( self.producer = mp.Process(
name="blt_data_loader", name="blt_data_loader",
@ -245,8 +311,12 @@ class MultiprocessIterator(StatefulIterator):
args=( args=(
self.batch_queue, self.batch_queue,
self.state_queue, self.state_queue,
self.approximate_state_queue,
self.stop_iterating_event, self.stop_iterating_event,
self.state_dumped_event, self.state_dumped_event,
self.trigger_approximate_send_state_event,
self.sent_approximate_state_event,
self.received_approximate_state_event,
self.base_iterator.get_state(), self.base_iterator.get_state(),
), ),
) )