Add approximate state persistence

Summary:

Adds a PersistType enum (exact, approximate) and threads it from DataloaderArgs.async_persist_type into MultiprocessIterator. Exact persistence keeps the existing behavior of stopping the background worker and folding the prefetch buffer into the saved state; approximate persistence asks the live worker for a snapshot of its current position over a dedicated single-slot queue and a small set of handshake events, so batch production continues uninterrupted.

Test Plan:
Pedro Rodriguez 2025-03-04 01:02:34 +00:00
parent c727844e9d
commit 3e1df4ea4d
2 changed files with 141 additions and 50 deletions

View file

@@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.multiprocess_iterator import (
MultiprocessIterator,
PersistType,
)
from bytelatent.data.iterators.packing_iterator import (
PackingArgs,
PackingIterator,
@@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
add_bos: bool = True
add_eos: bool = True
load_async: bool = True
async_persist_type: PersistType = PersistType.EXACT
prefetch_size: int = 64
preprocess_dir: str | None = None
dataset_files: list[str] | None = None
@@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
if self.load_async:
mp_iterator = MultiprocessIterator(
packing_iterator, n_batches_to_prefetch=self.prefetch_size
packing_iterator,
n_batches_to_prefetch=self.prefetch_size,
persist_type=self.async_persist_type,
)
return mp_iterator
else:

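A minimal sketch (not part of the diff) of how the new knob might be set when building the dataloader config. DataloaderArgs is the class edited above; its import path is not visible in this view, and any fields not listed are assumed to keep their defaults.

from bytelatent.data.iterators.multiprocess_iterator import PersistType

# Opt the async dataloader into approximate persistence; PersistType.EXACT stays
# the default and preserves the previous behavior.
dataloader_args = DataloaderArgs(
    load_async=True,
    async_persist_type=PersistType.APPROXIMATE,
    prefetch_size=64,
)
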
View file

@@ -2,6 +2,7 @@
import json
import logging
import multiprocessing as mp
from enum import Enum
from multiprocessing.synchronize import Event as EventClass
from queue import Empty, Full
@@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
logger = logging.getLogger()
class PersistType(str, Enum):
EXACT = "exact"
APPROXIMATE = "approximate"
class MultiprocessIteratorState(PydanticIteratorState):
model_config = ConfigDict(extra="forbid")
base_iterator_state: PackingIteratorState
n_batches_to_prefetch: int
serialized_prefetch_buffer: str
persist_type: PersistType
def build(self):
base_iterator = self.base_iterator_state.build()
@@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
base_iterator,
n_batches_to_prefetch=self.n_batches_to_prefetch,
prefetch_buffer=prefetch_buffer,
persist_type=self.persist_type,
)
def start_work_from_state(
batch_queue: mp.Queue,
state_queue: mp.Queue,
approximate_state_queue: mp.Queue,
stop_event: EventClass,
state_dumped_event: EventClass,
trigger_approximate_send_state_event: EventClass,
sent_approximate_state_event: EventClass,
received_approximate_state_event: EventClass,
state: IteratorState,
):
logging.info("Worker thread: Starting base_iterator work")
@@ -49,6 +61,24 @@ def start_work_from_state(
for item in iterator:
while not stop_event.is_set():
try:
if trigger_approximate_send_state_event.is_set():
logger.debug("WT: trigger_approximate_send ack")
# Since this can be triggered again (but only after the main thread has received the state),
# we should clear the trigger as soon as possible.
trigger_approximate_send_state_event.clear()
approximate_state = stateful_iterator.get_state()
# At this point the queue should always have exactly one free slot,
# so blocking here would be a bug.
logger.debug("WT: Attempting to send approximate state")
approximate_state_queue.put(
approximate_state, block=True, timeout=None
)
sent_approximate_state_event.set()
logger.debug("WT: Approximate state sent")
# Same here, clear events as we no longer need them.
received_approximate_state_event.wait()
received_approximate_state_event.clear()
logger.debug("WT: State received by MT, resuming batch iteration")
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
batch_queue.put(item, timeout=0.1)
# On success, stop trying
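
The comments above describe a three-event handshake: the main process raises a trigger event, the worker clears it, puts a single snapshot on a one-slot queue and raises a "sent" event, then waits for a "received" event before producing more batches. A self-contained sketch of that ordering, with illustrative names rather than the library's:

import multiprocessing as mp
import time


def worker(trigger, sent, received, state_queue, stop):
    step = 0
    while not stop.is_set():
        step += 1  # stands in for producing one batch
        if trigger.is_set():
            trigger.clear()                  # allow a later snapshot request
            state_queue.put({"step": step})  # one-slot queue: a single snapshot in flight
            sent.set()
            received.wait()                  # pause until the main process has taken it
            received.clear()
        time.sleep(0.01)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    trigger, sent, received, stop = ctx.Event(), ctx.Event(), ctx.Event(), ctx.Event()
    state_queue = ctx.Manager().Queue(maxsize=1)
    producer = ctx.Process(target=worker, args=(trigger, sent, received, state_queue, stop))
    producer.start()

    time.sleep(0.1)
    trigger.set()                 # request an approximate snapshot
    sent.wait()
    sent.clear()
    snapshot = state_queue.get()  # e.g. {"step": 7}
    received.set()                # the worker resumes immediately
    print(snapshot)

    stop.set()
    producer.join()
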
@@ -124,18 +154,24 @@ class MultiprocessIterator(StatefulIterator):
base_iterator: StatefulIterator,
*,
n_batches_to_prefetch: int,
prefetch_buffer: list | None = None
prefetch_buffer: list | None = None,
persist_type: PersistType = PersistType.EXACT,
):
self.base_iterator = base_iterator
self.n_batches_to_prefetch = n_batches_to_prefetch
self.persist_type = persist_type
if prefetch_buffer is None:
prefetch_buffer = []
self.prefetch_buffer = prefetch_buffer
self.batch_queue = None
self.state_queue = None
self.approximate_state_queue = None
self.producer = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.trigger_approximate_send_state_event = None
self.sent_approximate_state_event = None
self.received_approximate_state_event = None
self.force_shutdown = False
def shutdown(self):
@@ -144,26 +180,7 @@ class MultiprocessIterator(StatefulIterator):
self.producer.kill()
self.force_shutdown = True
def get_state(self) -> MultiprocessIteratorState:
"""
This is slightly unusual in effectively destroying the current iterator, its necessary
to halt the background process and allow it to write the state to the main loop
in order to not lose data
"""
if self.force_shutdown:
raise ValueError(
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
)
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=serialized_prefetch_buffer,
)
else:
def _get_state_exact(self):
logging.debug("Main thread: Sending stop iteration event")
self.stop_iterating_event.set()
logging.debug(
@@ -201,8 +218,12 @@ class MultiprocessIterator(StatefulIterator):
self.producer = None
self.batch_queue = None
self.state_queue = None
self.approximate_state_queue = None
self.stop_iterating_event = None
self.state_dumped_event = None
self.trigger_approximate_send_state_event = None
self.sent_approximate_state_event = None
self.received_approximate_state_event = None
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
@@ -210,8 +231,62 @@ class MultiprocessIterator(StatefulIterator):
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
persist_type=self.persist_type,
)
def _get_state_approximate(self):
logging.debug("MT: Sending approximate get_state request")
self.trigger_approximate_send_state_event.set()
logging.debug("MT: Waiting for sent_approximate_state_event")
self.sent_approximate_state_event.wait()
logging.debug("MT: sent_approximate_state_event ack")
try:
base_iterator_state = self.approximate_state_queue.get(timeout=1)
logging.debug("MT: approximate state received")
assert isinstance(base_iterator_state, IteratorState)
assert self.approximate_state_queue.empty()
except Empty:
raise ValueError(
"Attempted to get approximate state, but queue was erroniously empty."
)
self.received_approximate_state_event.set()
return MultiprocessIteratorState(
base_iterator_state=base_iterator_state,
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
),
persist_type=self.persist_type,
)
def get_state(self) -> MultiprocessIteratorState:
"""
In the exact case, this is slightly unusual in that it effectively destroys the current
iterator: the background process must be halted so it can hand its state back to the
main process without losing data. In the approximate case, the worker stays alive and
only a snapshot of its current position is returned.
"""
if self.force_shutdown:
raise ValueError(
"State will be invalid if shutdown was forced before state persisted."
)
if self.producer is None:
serialized_prefetch_buffer = json.dumps(
[b.to_python_dict() for b in self.prefetch_buffer]
)
return MultiprocessIteratorState(
base_iterator_state=self.base_iterator.get_state(),
n_batches_to_prefetch=self.n_batches_to_prefetch,
serialized_prefetch_buffer=serialized_prefetch_buffer,
persist_type=self.persist_type,
)
else:
if self.persist_type == PersistType.EXACT:
return self._get_state_exact()
elif self.persist_type == PersistType.APPROXIMATE:
return self._get_state_approximate()
else:
raise ValueError("invalid persist_type")
def create_iter(self):
if self.force_shutdown:
raise ValueError(
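
Once get_state() returns, the MultiprocessIteratorState can be serialized for a checkpoint and rebuilt later. A hedged sketch of that round trip, assuming mp_iterator is the MultiprocessIterator built above and using pydantic's JSON helpers (the state classes here are pydantic models; the repo's actual checkpoint plumbing may differ):

state = mp_iterator.get_state()      # exact or approximate, per persist_type
payload = state.model_dump_json()    # serialize for the checkpoint

# ...later, when resuming...
restored = MultiprocessIteratorState.model_validate_json(payload)
mp_iterator = restored.build()       # rebuilds the base iterator and prefetch buffer
batches = mp_iterator.create_iter()  # spawns a fresh worker and resumes prefetching
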
@@ -236,8 +311,14 @@ class MultiprocessIterator(StatefulIterator):
# We should only ever have one state in this queue; it is emitted when the stop event is detected
self.state_queue = ctx.Manager().Queue(maxsize=1)
# Similarly, there should only ever be one state in flight due to event signals
self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
self.stop_iterating_event = ctx.Event()
self.state_dumped_event = ctx.Event()
self.trigger_approximate_send_state_event = ctx.Event()
self.sent_approximate_state_event = ctx.Event()
self.received_approximate_state_event = ctx.Event()
self.producer = mp.Process(
name="blt_data_loader",
@@ -245,8 +326,12 @@ class MultiprocessIterator(StatefulIterator):
args=(
self.batch_queue,
self.state_queue,
self.approximate_state_queue,
self.stop_iterating_event,
self.state_dumped_event,
self.trigger_approximate_send_state_event,
self.sent_approximate_state_event,
self.received_approximate_state_event,
self.base_iterator.get_state(),
),
)
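
Taken together, a plausible end-to-end use of the two modes (illustrative only; base_iterator stands for any StatefulIterator, such as the PackingIterator built in the first file):

mp_it = MultiprocessIterator(
    base_iterator,
    n_batches_to_prefetch=64,
    persist_type=PersistType.APPROXIMATE,
)
batches = mp_it.create_iter()  # spawns the "blt_data_loader" worker process

for step, batch in enumerate(batches):
    ...  # training step
    if step % 1000 == 0:
        # APPROXIMATE: the worker keeps running; the snapshot records the worker's
        # position, which can be ahead of what the main loop has consumed, so a
        # resume may skip batches that were prefetched but not yet used.
        state = mp_it.get_state()
    if step >= 10_000:
        break

# EXACT (the default) instead stops the worker and folds the prefetched batches into
# the returned state, so nothing is skipped on resume, at the cost of tearing the
# worker down; create_iter() must be called again before iterating further.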