mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-10 19:59:09 +00:00
Add approximate state persistence (#73)
Summary: Test Plan: *** More verbose multiprocess logging, fix get_state_and_recycle Summary: Test Plan:
This commit is contained in:
parent
9bd51df961
commit
ea1fc75862
3 changed files with 199 additions and 77 deletions
bytelatent
|
@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
|
||||||
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
|
||||||
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
|
||||||
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
from bytelatent.data.iterators.looping_iterator import LoopingIterator
|
||||||
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
from bytelatent.data.iterators.multiprocess_iterator import (
|
||||||
|
MultiprocessIterator,
|
||||||
|
PersistType,
|
||||||
|
)
|
||||||
from bytelatent.data.iterators.packing_iterator import (
|
from bytelatent.data.iterators.packing_iterator import (
|
||||||
PackingArgs,
|
PackingArgs,
|
||||||
PackingIterator,
|
PackingIterator,
|
||||||
|
@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
|
||||||
add_bos: bool = True
|
add_bos: bool = True
|
||||||
add_eos: bool = True
|
add_eos: bool = True
|
||||||
load_async: bool = True
|
load_async: bool = True
|
||||||
|
async_persist_type: PersistType = PersistType.EXACT
|
||||||
prefetch_size: int = 64
|
prefetch_size: int = 64
|
||||||
preprocess_dir: str | None = None
|
preprocess_dir: str | None = None
|
||||||
dataset_files: list[str] | None = None
|
dataset_files: list[str] | None = None
|
||||||
|
@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
|
||||||
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
|
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
|
||||||
if self.load_async:
|
if self.load_async:
|
||||||
mp_iterator = MultiprocessIterator(
|
mp_iterator = MultiprocessIterator(
|
||||||
packing_iterator, n_batches_to_prefetch=self.prefetch_size
|
packing_iterator,
|
||||||
|
n_batches_to_prefetch=self.prefetch_size,
|
||||||
|
persist_type=self.async_persist_type,
|
||||||
)
|
)
|
||||||
return mp_iterator
|
return mp_iterator
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
from enum import Enum
|
||||||
from multiprocessing.synchronize import Event as EventClass
|
from multiprocessing.synchronize import Event as EventClass
|
||||||
from queue import Empty, Full
|
from queue import Empty, Full
|
||||||
|
|
||||||
|
@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
class PersistType(str, Enum):
|
||||||
|
EXACT = "exact"
|
||||||
|
APPROXIMATE = "approximate"
|
||||||
|
|
||||||
|
|
||||||
class MultiprocessIteratorState(PydanticIteratorState):
|
class MultiprocessIteratorState(PydanticIteratorState):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
base_iterator_state: PackingIteratorState
|
base_iterator_state: PackingIteratorState
|
||||||
n_batches_to_prefetch: int
|
n_batches_to_prefetch: int
|
||||||
serialized_prefetch_buffer: str
|
serialized_prefetch_buffer: str
|
||||||
|
persist_type: PersistType
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
base_iterator = self.base_iterator_state.build()
|
base_iterator = self.base_iterator_state.build()
|
||||||
|
@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
|
||||||
base_iterator,
|
base_iterator,
|
||||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
prefetch_buffer=prefetch_buffer,
|
prefetch_buffer=prefetch_buffer,
|
||||||
|
persist_type=self.persist_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def start_work_from_state(
|
def start_work_from_state(
|
||||||
batch_queue: mp.Queue,
|
batch_queue: mp.Queue,
|
||||||
state_queue: mp.Queue,
|
state_queue: mp.Queue,
|
||||||
|
approximate_state_queue: mp.Queue,
|
||||||
stop_event: EventClass,
|
stop_event: EventClass,
|
||||||
state_dumped_event: EventClass,
|
state_dumped_event: EventClass,
|
||||||
|
trigger_approximate_send_state_event: EventClass,
|
||||||
|
sent_approximate_state_event: EventClass,
|
||||||
|
received_approximate_state_event: EventClass,
|
||||||
state: IteratorState,
|
state: IteratorState,
|
||||||
):
|
):
|
||||||
logging.info("Worker thread: Starting base_iterator work")
|
logging.info("Worker thread: Starting base_iterator work")
|
||||||
|
@ -49,6 +61,25 @@ def start_work_from_state(
|
||||||
for item in iterator:
|
for item in iterator:
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
|
if trigger_approximate_send_state_event.is_set():
|
||||||
|
logger.info("WT: trigger_approximate_send ack")
|
||||||
|
# Since this can be triggered again (but only after the state is received on mp),
|
||||||
|
# we should cleanup as soon as possible.
|
||||||
|
trigger_approximate_send_state_event.clear()
|
||||||
|
logging.info("WT: Computing approximate state")
|
||||||
|
approximate_state = stateful_iterator.get_state()
|
||||||
|
# At this state, there should always be exactly 1 slot.
|
||||||
|
# Blocking here would be a bug.
|
||||||
|
logger.info("WT: Attempting to send approximate state")
|
||||||
|
approximate_state_queue.put(
|
||||||
|
approximate_state, block=True, timeout=None
|
||||||
|
)
|
||||||
|
sent_approximate_state_event.set()
|
||||||
|
logger.info("WT: Approximate state sent")
|
||||||
|
# Same here, clear events as we no longer need them.
|
||||||
|
received_approximate_state_event.wait()
|
||||||
|
received_approximate_state_event.clear()
|
||||||
|
logger.info("WT: State received by MT, resuming batch iteration")
|
||||||
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
|
# Attempt to put on queue or timeout to try again (maybe main thread is busy)
|
||||||
batch_queue.put(item, timeout=0.1)
|
batch_queue.put(item, timeout=0.1)
|
||||||
# On success, stop trying
|
# On success, stop trying
|
||||||
|
@ -58,10 +89,10 @@ def start_work_from_state(
|
||||||
if stop_event.is_set():
|
if stop_event.is_set():
|
||||||
# Signal the end of output, this ensures that even if the queue takes a while to
|
# Signal the end of output, this ensures that even if the queue takes a while to
|
||||||
# buffer, that the main thread receives everything (and tosses this fake batch)
|
# buffer, that the main thread receives everything (and tosses this fake batch)
|
||||||
logging.debug(
|
logging.info(
|
||||||
"Worker thread: Stop event detected, outputting is_final=True batch"
|
"Worker thread: Stop event detected, outputting is_final=True batch"
|
||||||
)
|
)
|
||||||
logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
|
logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
|
||||||
batch_queue.put(
|
batch_queue.put(
|
||||||
Batch(
|
Batch(
|
||||||
x=np.zeros((1, 1)),
|
x=np.zeros((1, 1)),
|
||||||
|
@ -72,23 +103,26 @@ def start_work_from_state(
|
||||||
ngram_ids=None,
|
ngram_ids=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logging.debug(
|
logging.info(
|
||||||
"Worker thread: is_final=True batch put in queue, breaking from loop."
|
"Worker thread: is_final=True batch put in queue, breaking from loop."
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logging.debug("Worker thread: outputting state")
|
logging.info("Worker thread: outputting state")
|
||||||
state_queue.put(stateful_iterator.get_state(), timeout=1)
|
state_queue.put(stateful_iterator.get_state(), timeout=1)
|
||||||
logging.debug("Worker thread: state dump complete")
|
logging.info("Worker thread: state dump complete")
|
||||||
state_dumped_event.set()
|
state_dumped_event.set()
|
||||||
logging.debug("Worker thread: set state_dump_event")
|
logging.info("Worker thread: set state_dump_event")
|
||||||
except Full:
|
except Full:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Attempted to dump state into the state queue, but it was full"
|
"Attempted to dump state into the state queue, but it was full"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FETCH_STATE_TIMEOUT = 120
|
||||||
|
|
||||||
|
|
||||||
class MultiprocessIterator(StatefulIterator):
|
class MultiprocessIterator(StatefulIterator):
|
||||||
"""
|
"""
|
||||||
Design sketch of the multiprocess iterator:
|
Design sketch of the multiprocess iterator:
|
||||||
|
@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator):
|
||||||
base_iterator: StatefulIterator,
|
base_iterator: StatefulIterator,
|
||||||
*,
|
*,
|
||||||
n_batches_to_prefetch: int,
|
n_batches_to_prefetch: int,
|
||||||
prefetch_buffer: list | None = None
|
prefetch_buffer: list | None = None,
|
||||||
|
persist_type: PersistType = PersistType.EXACT,
|
||||||
):
|
):
|
||||||
self.base_iterator = base_iterator
|
self.base_iterator = base_iterator
|
||||||
self.n_batches_to_prefetch = n_batches_to_prefetch
|
self.n_batches_to_prefetch = n_batches_to_prefetch
|
||||||
|
self.persist_type = persist_type
|
||||||
if prefetch_buffer is None:
|
if prefetch_buffer is None:
|
||||||
prefetch_buffer = []
|
prefetch_buffer = []
|
||||||
self.prefetch_buffer = prefetch_buffer
|
self.prefetch_buffer = prefetch_buffer
|
||||||
self.batch_queue = None
|
self.batch_queue = None
|
||||||
self.state_queue = None
|
self.state_queue = None
|
||||||
|
self.approximate_state_queue = None
|
||||||
self.producer = None
|
self.producer = None
|
||||||
self.stop_iterating_event = None
|
self.stop_iterating_event = None
|
||||||
self.state_dumped_event = None
|
self.state_dumped_event = None
|
||||||
|
self.trigger_approximate_send_state_event = None
|
||||||
|
self.sent_approximate_state_event = None
|
||||||
|
self.received_approximate_state_event = None
|
||||||
self.force_shutdown = False
|
self.force_shutdown = False
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator):
|
||||||
self.producer.kill()
|
self.producer.kill()
|
||||||
self.force_shutdown = True
|
self.force_shutdown = True
|
||||||
|
|
||||||
|
def _get_state_exact(self):
|
||||||
|
logging.info("Main thread: Sending stop iteration event")
|
||||||
|
self.stop_iterating_event.set()
|
||||||
|
logging.info(
|
||||||
|
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
|
||||||
|
)
|
||||||
|
self.prefetch_buffer = []
|
||||||
|
final_batch_received = False
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
batch = self.batch_queue.get(timeout=1)
|
||||||
|
if batch.is_final:
|
||||||
|
logging.info(
|
||||||
|
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
|
||||||
|
)
|
||||||
|
final_batch_received = True
|
||||||
|
break
|
||||||
|
self.prefetch_buffer.append(batch)
|
||||||
|
except Empty:
|
||||||
|
logging.warning("Main thread: batch_queue is abnormally empty")
|
||||||
|
assert final_batch_received
|
||||||
|
|
||||||
|
logging.info("Main thread: Waiting for state_dumped event")
|
||||||
|
self.state_dumped_event.wait()
|
||||||
|
|
||||||
|
try:
|
||||||
|
logging.info(
|
||||||
|
"Main thread: state_dumped_event received, waiting for state from queue"
|
||||||
|
)
|
||||||
|
base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
|
||||||
|
logging.info("Main thread: received state from queue")
|
||||||
|
assert isinstance(base_iterator_state, IteratorState)
|
||||||
|
except Empty:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempted to get the state, but it was unexpectantly missing"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.base_iterator = base_iterator_state.build()
|
||||||
|
self.producer.close()
|
||||||
|
self.producer = None
|
||||||
|
self.batch_queue = None
|
||||||
|
self.state_queue = None
|
||||||
|
self.approximate_state_queue = None
|
||||||
|
self.stop_iterating_event = None
|
||||||
|
self.state_dumped_event = None
|
||||||
|
self.trigger_approximate_send_state_event = None
|
||||||
|
self.sent_approximate_state_event = None
|
||||||
|
self.received_approximate_state_event = None
|
||||||
|
|
||||||
|
return MultiprocessIteratorState(
|
||||||
|
base_iterator_state=self.base_iterator.get_state(),
|
||||||
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
|
serialized_prefetch_buffer=json.dumps(
|
||||||
|
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||||
|
),
|
||||||
|
persist_type=self.persist_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_state_approximate(self):
|
||||||
|
logging.info("MT: Sending approximate get_state request")
|
||||||
|
self.trigger_approximate_send_state_event.set()
|
||||||
|
logging.info("MT: Waiting for sent_approximate_state_event")
|
||||||
|
self.sent_approximate_state_event.wait()
|
||||||
|
logging.info("MT: sent_approximate_state_event ack")
|
||||||
|
try:
|
||||||
|
logging.info("MT: waiting for approximate state in queue")
|
||||||
|
base_iterator_state = self.approximate_state_queue.get(
|
||||||
|
timeout=FETCH_STATE_TIMEOUT
|
||||||
|
)
|
||||||
|
logging.info("MT: approximate state received")
|
||||||
|
assert isinstance(base_iterator_state, IteratorState)
|
||||||
|
assert self.approximate_state_queue.empty()
|
||||||
|
except Empty:
|
||||||
|
raise ValueError(
|
||||||
|
"Attempted to get approximate state, but queue was erroniously empty."
|
||||||
|
)
|
||||||
|
self.received_approximate_state_event.set()
|
||||||
|
return MultiprocessIteratorState(
|
||||||
|
base_iterator_state=base_iterator_state,
|
||||||
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
|
serialized_prefetch_buffer=json.dumps(
|
||||||
|
[b.to_python_dict() for b in self.prefetch_buffer]
|
||||||
|
),
|
||||||
|
persist_type=self.persist_type,
|
||||||
|
)
|
||||||
|
|
||||||
def get_state(self) -> MultiprocessIteratorState:
|
def get_state(self) -> MultiprocessIteratorState:
|
||||||
"""
|
"""
|
||||||
This is slightly unusual in effectively destroying the current iterator, its necessary
|
This is slightly unusual in effectively destroying the current iterator, its necessary
|
||||||
|
@ -162,55 +288,15 @@ class MultiprocessIterator(StatefulIterator):
|
||||||
base_iterator_state=self.base_iterator.get_state(),
|
base_iterator_state=self.base_iterator.get_state(),
|
||||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
||||||
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
serialized_prefetch_buffer=serialized_prefetch_buffer,
|
||||||
|
persist_type=self.persist_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logging.debug("Main thread: Sending stop iteration event")
|
if self.persist_type == PersistType.EXACT:
|
||||||
self.stop_iterating_event.set()
|
return self._get_state_exact()
|
||||||
logging.debug(
|
elif self.persist_type == PersistType.APPROXIMATE:
|
||||||
"Main thread: Emptying the batch_queue until batch.is_final=True is found."
|
return self._get_state_approximate()
|
||||||
)
|
else:
|
||||||
self.prefetch_buffer = []
|
raise ValueError("invalid persist_type")
|
||||||
final_batch_received = False
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
batch = self.batch_queue.get(timeout=1)
|
|
||||||
if batch.is_final:
|
|
||||||
logging.debug(
|
|
||||||
"Main thread: is_final=True batch found, stopping fetch from batch_queue"
|
|
||||||
)
|
|
||||||
final_batch_received = True
|
|
||||||
break
|
|
||||||
self.prefetch_buffer.append(batch)
|
|
||||||
except Empty:
|
|
||||||
logging.warning("Main thread: batch_queue is abnormally empty")
|
|
||||||
assert final_batch_received
|
|
||||||
|
|
||||||
logging.debug("Main thread: Waiting for state_dumped event")
|
|
||||||
self.state_dumped_event.wait()
|
|
||||||
|
|
||||||
try:
|
|
||||||
base_iterator_state = self.state_queue.get(timeout=1)
|
|
||||||
assert isinstance(base_iterator_state, IteratorState)
|
|
||||||
except Empty:
|
|
||||||
raise ValueError(
|
|
||||||
"Attempted to get the state, but it was unexpectantly missing"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.base_iterator = base_iterator_state.build()
|
|
||||||
self.producer.close()
|
|
||||||
self.producer = None
|
|
||||||
self.batch_queue = None
|
|
||||||
self.state_queue = None
|
|
||||||
self.stop_iterating_event = None
|
|
||||||
self.state_dumped_event = None
|
|
||||||
|
|
||||||
return MultiprocessIteratorState(
|
|
||||||
base_iterator_state=self.base_iterator.get_state(),
|
|
||||||
n_batches_to_prefetch=self.n_batches_to_prefetch,
|
|
||||||
serialized_prefetch_buffer=json.dumps(
|
|
||||||
[b.to_python_dict() for b in self.prefetch_buffer]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_iter(self):
|
def create_iter(self):
|
||||||
if self.force_shutdown:
|
if self.force_shutdown:
|
||||||
|
@ -236,8 +322,14 @@ class MultiprocessIterator(StatefulIterator):
|
||||||
# We should only ever one state, which is output at the detection of a stop event
|
# We should only ever one state, which is output at the detection of a stop event
|
||||||
self.state_queue = ctx.Manager().Queue(maxsize=1)
|
self.state_queue = ctx.Manager().Queue(maxsize=1)
|
||||||
|
|
||||||
|
# Similarly, there should only ever be one state in flight due to event signals
|
||||||
|
self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
|
||||||
|
|
||||||
self.stop_iterating_event = ctx.Event()
|
self.stop_iterating_event = ctx.Event()
|
||||||
self.state_dumped_event = ctx.Event()
|
self.state_dumped_event = ctx.Event()
|
||||||
|
self.trigger_approximate_send_state_event = ctx.Event()
|
||||||
|
self.sent_approximate_state_event = ctx.Event()
|
||||||
|
self.received_approximate_state_event = ctx.Event()
|
||||||
|
|
||||||
self.producer = mp.Process(
|
self.producer = mp.Process(
|
||||||
name="blt_data_loader",
|
name="blt_data_loader",
|
||||||
|
@ -245,8 +337,12 @@ class MultiprocessIterator(StatefulIterator):
|
||||||
args=(
|
args=(
|
||||||
self.batch_queue,
|
self.batch_queue,
|
||||||
self.state_queue,
|
self.state_queue,
|
||||||
|
self.approximate_state_queue,
|
||||||
self.stop_iterating_event,
|
self.stop_iterating_event,
|
||||||
self.state_dumped_event,
|
self.state_dumped_event,
|
||||||
|
self.trigger_approximate_send_state_event,
|
||||||
|
self.sent_approximate_state_event,
|
||||||
|
self.received_approximate_state_event,
|
||||||
self.base_iterator.get_state(),
|
self.base_iterator.get_state(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
|
||||||
from bytelatent.data.iterators.multiprocess_iterator import (
|
from bytelatent.data.iterators.multiprocess_iterator import (
|
||||||
MultiprocessIterator,
|
MultiprocessIterator,
|
||||||
MultiprocessIteratorState,
|
MultiprocessIteratorState,
|
||||||
|
PersistType,
|
||||||
)
|
)
|
||||||
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
|
from bytelatent.data.iterators.packing_iterator import PackingIteratorState
|
||||||
from bytelatent.distributed import (
|
from bytelatent.distributed import (
|
||||||
|
@ -712,9 +713,15 @@ def train(args: TrainArgs):
|
||||||
if every_n_steps(
|
if every_n_steps(
|
||||||
train_state, args.checkpoint.dump.every, acc_step=0
|
train_state, args.checkpoint.dump.every, acc_step=0
|
||||||
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
|
||||||
train_state.data_loader_state, data_loader, batch_iterator = (
|
if (
|
||||||
get_state_and_refresh(data_loader)
|
args.data.load_async
|
||||||
)
|
and args.data.async_persist_type == PersistType.EXACT
|
||||||
|
):
|
||||||
|
train_state.data_loader_state, data_loader, batch_iterator = (
|
||||||
|
get_state_and_refresh(data_loader)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_state.data_loader_state = data_loader.get_state()
|
||||||
saved = checkpoint.save(
|
saved = checkpoint.save(
|
||||||
model,
|
model,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -756,9 +763,16 @@ def train(args: TrainArgs):
|
||||||
|
|
||||||
if preemption_flag["flag"]:
|
if preemption_flag["flag"]:
|
||||||
if not saved:
|
if not saved:
|
||||||
train_state.data_loader_state, data_loader, batch_iterator = (
|
if (
|
||||||
get_state_and_refresh(data_loader)
|
args.data.load_async
|
||||||
)
|
and args.data.async_persist_type == PersistType.EXACT
|
||||||
|
):
|
||||||
|
train_state.data_loader_state, data_loader, batch_iterator = (
|
||||||
|
get_state_and_refresh(data_loader)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train_state.data_loader_state = data_loader.get_state()
|
||||||
|
|
||||||
checkpoint.save(
|
checkpoint.save(
|
||||||
model,
|
model,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -769,21 +783,27 @@ def train(args: TrainArgs):
|
||||||
requeue_slurm_job()
|
requeue_slurm_job()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if not saved:
|
if not saved:
|
||||||
train_state.data_loader_state, data_loader, batch_iterator = (
|
if (
|
||||||
get_state_and_refresh(data_loader)
|
args.data.load_async
|
||||||
)
|
and args.data.async_persist_type == PersistType.EXACT
|
||||||
checkpoint.save(
|
):
|
||||||
model,
|
train_state.data_loader_state, data_loader, batch_iterator = (
|
||||||
optimizer,
|
get_state_and_refresh(data_loader)
|
||||||
train_state,
|
)
|
||||||
args,
|
else:
|
||||||
device_mesh=world_mesh,
|
train_state.data_loader_state = data_loader.get_state()
|
||||||
)
|
checkpoint.save(
|
||||||
if isinstance(data_loader, MultiprocessIterator):
|
model,
|
||||||
logger.info("Closing MP iterator before exiting")
|
optimizer,
|
||||||
data_loader.shutdown()
|
train_state,
|
||||||
gc.collect()
|
args,
|
||||||
|
device_mesh=world_mesh,
|
||||||
|
)
|
||||||
|
if isinstance(data_loader, MultiprocessIterator):
|
||||||
|
logger.info("Closing MP iterator before exiting")
|
||||||
|
data_loader.shutdown()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
Loading…
Add table
Reference in a new issue