mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-23 01:59:08 +00:00
Add approximate state persistence
Summary: Test Plan:
This commit is contained in:
parent
08b8c7cd05
commit
967b23fd05
1 changed files with 118 additions and 48 deletions
|
@ -39,8 +39,12 @@ class MultiprocessIteratorState(PydanticIteratorState):
|
|||
def start_work_from_state(
|
||||
batch_queue: mp.Queue,
|
||||
state_queue: mp.Queue,
|
||||
approximate_state_queue: mp.Queue,
|
||||
stop_event: EventClass,
|
||||
state_dumped_event: EventClass,
|
||||
trigger_approximate_send_state_event: EventClass,
|
||||
sent_approximate_state_event: EventClass,
|
||||
received_approximate_state_event: EventClass,
|
||||
state: IteratorState,
|
||||
):
|
||||
logging.info("Worker thread: Starting base_iterator work")
|
||||
|
@ -49,6 +53,24 @@ def start_work_from_state(
|
|||
for item in iterator:
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if trigger_approximate_send_state_event.is_set():
|
||||
logger.debug("WT: trigger_approximate_send ack")
|
||||
# Since this can be triggered again (but only after the state is received on mp),
|
||||
# we should cleanup as soon as possible.
|
||||
trigger_approximate_send_state_event.clear()
|
||||
approximate_state = stateful_iterator.get_state()
|
||||
# At this state, there should always be exactly 1 slot.
|
||||
# Blocking here would be a bug.
|
||||
logger.debug("WT: Attempting to send approximate state")
|
||||
approximate_state_queue.put(
|
||||
approximate_state, block=True, timeout=None
|
||||
)
|
||||
sent_approximate_state_event.set()
|
||||
logger.debug("WT: Approximate state sent")
|
||||
# Same here, clear events as we no longer need them.
|
||||
received_approximate_state_event.wait()
|
||||
received_approximate_state_event.clear()
|
||||
logger.debug("WT: State received by MT, resuming batch iteration")
|
||||
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
|
||||
batch_queue.put(item, timeout=0.1)
|
||||
# On success, stop trying
|
||||
|
@ -133,9 +155,13 @@ class MultiprocessIterator(StatefulIterator):
|
|||
self.prefetch_buffer = prefetch_buffer
|
||||
self.batch_queue = None
|
||||
self.state_queue = None
|
||||
self.approximate_state_queue = None
|
||||
self.producer = None
|
||||
self.stop_iterating_event = None
|
||||
self.state_dumped_event = None
|
||||
self.trigger_approximate_send_state_event = None
|
||||
self.sent_approximate_state_event = None
|
||||
self.received_approximate_state_event = None
|
||||
self.force_shutdown = False
|
||||
|
||||
def shutdown(self):
|
||||
|
@ -144,7 +170,84 @@ class MultiprocessIterator(StatefulIterator):
|
|||
self.producer.kill()
|
||||
self.force_shutdown = True
|
||||
|
||||
def get_state(self) -> MultiprocessIteratorState:
|
||||
def _get_state_exact(self):
|
||||
logging.debug("Main thread: Sending stop iteration event")
|
||||
self.stop_iterating_event.set()
|
||||
logging.debug(
|
||||
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
|
||||
)
|
||||
self.prefetch_buffer = []
|
||||
final_batch_received = False
|
||||
while True:
|
||||
try:
|
||||
batch = self.batch_queue.get(timeout=1)
|
||||
if batch.is_final:
|
||||
logging.debug(
|
||||
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
|
||||
)
|
||||
final_batch_received = True
|
||||
break
|
||||
self.prefetch_buffer.append(batch)
|
||||
except Empty:
|
||||
logging.warning("Main thread: batch_queue is abnormally empty")
|
||||
assert final_batch_received
|
||||
|
||||
logging.debug("Main thread: Waiting for state_dumped event")
|
||||
self.state_dumped_event.wait()
|
||||
|
||||
try:
|
||||
base_iterator_state = self.state_queue.get(timeout=1)
|
||||
assert isinstance(base_iterator_state, IteratorState)
|
||||
except Empty:
|
||||
raise ValueError(
|
||||
"Attempted to get the state, but it was unexpectantly missing"
|
||||
)
|
||||
|
||||
self.base_iterator = base_iterator_state.build()
|
||||
self.producer.close()
|
||||
self.producer = None
|
||||
self.batch_queue = None
|
||||
self.state_queue = None
|
||||
self.approximate_state_queue = None
|
||||
self.stop_iterating_event = None
|
||||
self.state_dumped_event = None
|
||||
self.trigger_approximate_send_state_event = None
|
||||
self.sent_approximate_state_event = None
|
||||
self.received_approximate_state_event = None
|
||||
|
||||
return MultiprocessIteratorState(
|
||||
base_iterator_state=self.base_iterator.get_state(),
|
||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||
serialized_prefetch_buffer=json.dumps(
|
||||
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||
),
|
||||
)
|
||||
|
||||
def _get_state_approximate(self):
|
||||
logging.debug("MT: Sending approximate get_state request")
|
||||
self.trigger_approximate_send_state_event.set()
|
||||
logging.debug("MT: Waiting for sent_approximate_state_event")
|
||||
self.sent_approximate_state_event.wait()
|
||||
logging.debug("MT: sent_approximate_state_event ack")
|
||||
try:
|
||||
base_iterator_state = self.approximate_state_queue.get(timeout=1)
|
||||
logging.debug("MT: approximate state received")
|
||||
assert isinstance(base_iterator_state, IteratorState)
|
||||
assert self.approximate_state_queue.empty()
|
||||
except Empty:
|
||||
raise ValueError(
|
||||
"Attempted to get approximate state, but queue was erroniously empty."
|
||||
)
|
||||
self.received_approximate_state_event.set()
|
||||
return MultiprocessIteratorState(
|
||||
base_iterator_state=base_iterator_state,
|
||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||
serialized_prefetch_buffer=json.dumps(
|
||||
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||
),
|
||||
)
|
||||
|
||||
def get_state(self, exact: bool = True) -> MultiprocessIteratorState:
|
||||
"""
|
||||
This is slightly unusual in effectively destroying the current iterator, its necessary
|
||||
to halt the background process and allow it to write the state to the main loop
|
||||
|
@ -164,53 +267,10 @@ class MultiprocessIterator(StatefulIterator):
|
|||
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
||||
)
|
||||
else:
|
||||
logging.debug("Main thread: Sending stop iteration event")
|
||||
self.stop_iterating_event.set()
|
||||
logging.debug(
|
||||
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
|
||||
)
|
||||
self.prefetch_buffer = []
|
||||
final_batch_received = False
|
||||
while True:
|
||||
try:
|
||||
batch = self.batch_queue.get(timeout=1)
|
||||
if batch.is_final:
|
||||
logging.debug(
|
||||
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
|
||||
)
|
||||
final_batch_received = True
|
||||
break
|
||||
self.prefetch_buffer.append(batch)
|
||||
except Empty:
|
||||
logging.warning("Main thread: batch_queue is abnormally empty")
|
||||
assert final_batch_received
|
||||
|
||||
logging.debug("Main thread: Waiting for state_dumped event")
|
||||
self.state_dumped_event.wait()
|
||||
|
||||
try:
|
||||
base_iterator_state = self.state_queue.get(timeout=1)
|
||||
assert isinstance(base_iterator_state, IteratorState)
|
||||
except Empty:
|
||||
raise ValueError(
|
||||
"Attempted to get the state, but it was unexpectantly missing"
|
||||
)
|
||||
|
||||
self.base_iterator = base_iterator_state.build()
|
||||
self.producer.close()
|
||||
self.producer = None
|
||||
self.batch_queue = None
|
||||
self.state_queue = None
|
||||
self.stop_iterating_event = None
|
||||
self.state_dumped_event = None
|
||||
|
||||
return MultiprocessIteratorState(
|
||||
base_iterator_state=self.base_iterator.get_state(),
|
||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||
serialized_prefetch_buffer=json.dumps(
|
||||
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||
),
|
||||
)
|
||||
if exact:
|
||||
return self._get_state_exact()
|
||||
else:
|
||||
return self._get_state_approximate()
|
||||
|
||||
def create_iter(self):
|
||||
if self.force_shutdown:
|
||||
|
@ -236,8 +296,14 @@ class MultiprocessIterator(StatefulIterator):
|
|||
# We should only ever one state, which is output at the detection of a stop event
|
||||
self.state_queue = ctx.Manager().Queue(maxsize=1)
|
||||
|
||||
# Similarly, there should only ever be one state in flight due to event signals
|
||||
self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
|
||||
|
||||
self.stop_iterating_event = ctx.Event()
|
||||
self.state_dumped_event = ctx.Event()
|
||||
self.trigger_approximate_send_state_event = ctx.Event()
|
||||
self.sent_approximate_state_event = ctx.Event()
|
||||
self.received_approximate_state_event = ctx.Event()
|
||||
|
||||
self.producer = mp.Process(
|
||||
name="blt_data_loader",
|
||||
|
@ -245,8 +311,12 @@ class MultiprocessIterator(StatefulIterator):
|
|||
args=(
|
||||
self.batch_queue,
|
||||
self.state_queue,
|
||||
self.approximate_state_queue,
|
||||
self.stop_iterating_event,
|
||||
self.state_dumped_event,
|
||||
self.trigger_approximate_send_state_event,
|
||||
self.sent_approximate_state_event,
|
||||
self.received_approximate_state_event,
|
||||
self.base_iterator.get_state(),
|
||||
),
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue