diff --git a/bytelatent/data/iterators/multiprocess_iterator.py b/bytelatent/data/iterators/multiprocess_iterator.py index b4df945..39cd8e8 100644 --- a/bytelatent/data/iterators/multiprocess_iterator.py +++ b/bytelatent/data/iterators/multiprocess_iterator.py @@ -39,8 +39,12 @@ class MultiprocessIteratorState(PydanticIteratorState): def start_work_from_state( batch_queue: mp.Queue, state_queue: mp.Queue, + approximate_state_queue: mp.Queue, stop_event: EventClass, state_dumped_event: EventClass, + trigger_approximate_send_state_event: EventClass, + sent_approximate_state_event: EventClass, + received_approximate_state_event: EventClass, state: IteratorState, ): logging.info("Worker thread: Starting base_iterator work") @@ -49,6 +53,24 @@ def start_work_from_state( for item in iterator: while not stop_event.is_set(): try: + if trigger_approximate_send_state_event.is_set(): + logger.debug("WT: trigger_approximate_send ack") + # Since this can be triggered again (but only after the state is received on mp), + # we should cleanup as soon as possible. + trigger_approximate_send_state_event.clear() + approximate_state = stateful_iterator.get_state() + # At this state, there should always be exactly 1 slot. + # Blocking here would be a bug. + logger.debug("WT: Attempting to send approximate state") + approximate_state_queue.put( + approximate_state, block=True, timeout=None + ) + sent_approximate_state_event.set() + logger.debug("WT: Approximate state sent") + # Same here, clear events as we no longer need them. + received_approximate_state_event.wait() + received_approximate_state_event.clear() + logger.debug("WT: State received by MT, resuming batch iteration") # Attempt to put on queue or timeout to try again (maybe main thread is busy) batch_queue.put(item, timeout=0.1) # On success, stop trying @@ -133,9 +155,13 @@ class MultiprocessIterator(StatefulIterator): self.prefetch_buffer = prefetch_buffer self.batch_queue = None self.state_queue = None + self.approximate_state_queue = None self.producer = None self.stop_iterating_event = None self.state_dumped_event = None + self.trigger_approximate_send_state_event = None + self.sent_approximate_state_event = None + self.received_approximate_state_event = None self.force_shutdown = False def shutdown(self): @@ -144,7 +170,84 @@ class MultiprocessIterator(StatefulIterator): self.producer.kill() self.force_shutdown = True - def get_state(self) -> MultiprocessIteratorState: + def _get_state_exact(self): + logging.debug("Main thread: Sending stop iteration event") + self.stop_iterating_event.set() + logging.debug( + "Main thread: Emptying the batch_queue until batch.is_final=True is found." + ) + self.prefetch_buffer = [] + final_batch_received = False + while True: + try: + batch = self.batch_queue.get(timeout=1) + if batch.is_final: + logging.debug( + "Main thread: is_final=True batch found, stopping fetch from batch_queue" + ) + final_batch_received = True + break + self.prefetch_buffer.append(batch) + except Empty: + logging.warning("Main thread: batch_queue is abnormally empty") + assert final_batch_received + + logging.debug("Main thread: Waiting for state_dumped event") + self.state_dumped_event.wait() + + try: + base_iterator_state = self.state_queue.get(timeout=1) + assert isinstance(base_iterator_state, IteratorState) + except Empty: + raise ValueError( + "Attempted to get the state, but it was unexpectantly missing" + ) + + self.base_iterator = base_iterator_state.build() + self.producer.close() + self.producer = None + self.batch_queue = None + self.state_queue = None + self.approximate_state_queue = None + self.stop_iterating_event = None + self.state_dumped_event = None + self.trigger_approximate_send_state_event = None + self.sent_approximate_state_event = None + self.received_approximate_state_event = None + + return MultiprocessIteratorState( + base_iterator_state=self.base_iterator.get_state(), + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ), + ) + + def _get_state_approximate(self): + logging.debug("MT: Sending approximate get_state request") + self.trigger_approximate_send_state_event.set() + logging.debug("MT: Waiting for sent_approximate_state_event") + self.sent_approximate_state_event.wait() + logging.debug("MT: sent_approximate_state_event ack") + try: + base_iterator_state = self.approximate_state_queue.get(timeout=1) + logging.debug("MT: approximate state received") + assert isinstance(base_iterator_state, IteratorState) + assert self.approximate_state_queue.empty() + except Empty: + raise ValueError( + "Attempted to get approximate state, but queue was erroniously empty." + ) + self.received_approximate_state_event.set() + return MultiprocessIteratorState( + base_iterator_state=base_iterator_state, + n_batches_to_prefetch=self.n_batches_to_prefetch, + serialized_prefetch_buffer=json.dumps( + [b.to_python_dict() for b in self.prefetch_buffer] + ), + ) + + def get_state(self, exact: bool = True) -> MultiprocessIteratorState: """ This is slightly unusual in effectively destroying the current iterator, its necessary to halt the background process and allow it to write the state to the main loop @@ -164,53 +267,10 @@ class MultiprocessIterator(StatefulIterator): serialized_prefetch_buffer=serialized_prefetch_buffer, ) else: - logging.debug("Main thread: Sending stop iteration event") - self.stop_iterating_event.set() - logging.debug( - "Main thread: Emptying the batch_queue until batch.is_final=True is found." - ) - self.prefetch_buffer = [] - final_batch_received = False - while True: - try: - batch = self.batch_queue.get(timeout=1) - if batch.is_final: - logging.debug( - "Main thread: is_final=True batch found, stopping fetch from batch_queue" - ) - final_batch_received = True - break - self.prefetch_buffer.append(batch) - except Empty: - logging.warning("Main thread: batch_queue is abnormally empty") - assert final_batch_received - - logging.debug("Main thread: Waiting for state_dumped event") - self.state_dumped_event.wait() - - try: - base_iterator_state = self.state_queue.get(timeout=1) - assert isinstance(base_iterator_state, IteratorState) - except Empty: - raise ValueError( - "Attempted to get the state, but it was unexpectantly missing" - ) - - self.base_iterator = base_iterator_state.build() - self.producer.close() - self.producer = None - self.batch_queue = None - self.state_queue = None - self.stop_iterating_event = None - self.state_dumped_event = None - - return MultiprocessIteratorState( - base_iterator_state=self.base_iterator.get_state(), - n_batches_to_prefetch=self.n_batches_to_prefetch, - serialized_prefetch_buffer=json.dumps( - [b.to_python_dict() for b in self.prefetch_buffer] - ), - ) + if exact: + return self._get_state_exact() + else: + return self._get_state_approximate() def create_iter(self): if self.force_shutdown: @@ -236,8 +296,14 @@ class MultiprocessIterator(StatefulIterator): # We should only ever one state, which is output at the detection of a stop event self.state_queue = ctx.Manager().Queue(maxsize=1) + # Similarly, there should only ever be one state in flight due to event signals + self.approximate_state_queue = ctx.Manager().Queue(maxsize=1) + self.stop_iterating_event = ctx.Event() self.state_dumped_event = ctx.Event() + self.trigger_approximate_send_state_event = ctx.Event() + self.sent_approximate_state_event = ctx.Event() + self.received_approximate_state_event = ctx.Event() self.producer = mp.Process( name="blt_data_loader", @@ -245,8 +311,12 @@ class MultiprocessIterator(StatefulIterator): args=( self.batch_queue, self.state_queue, + self.approximate_state_queue, self.stop_iterating_event, self.state_dumped_event, + self.trigger_approximate_send_state_event, + self.sent_approximate_state_event, + self.received_approximate_state_event, self.base_iterator.get_state(), ), )