mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-07 10:19:08 +00:00
Update iterate_data
Summary: Test Plan:
This commit is contained in:
parent
c110f6be2a
commit
ae3b2cd8eb
1 changed file with 15 additions and 6 deletions
|
@ -8,10 +8,16 @@ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
|||
from bytelatent.logger import init_logger
|
||||
|
||||
|
||||
def main(state_file: str):
|
||||
def main(
|
||||
state_file: str,
|
||||
steps: int = 3_000,
|
||||
io_thread_count: int = 2,
|
||||
cpu_count: int = 2,
|
||||
log_freq: int = 100,
|
||||
):
|
||||
init_logger()
|
||||
pyarrow.set_io_thread_count(4)
|
||||
pyarrow.set_cpu_count(4)
|
||||
pyarrow.set_io_thread_count(io_thread_count)
|
||||
pyarrow.set_cpu_count(cpu_count)
|
||||
with open(state_file) as f:
|
||||
train_state = json.load(f)
|
||||
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
|
||||
|
@ -20,10 +26,13 @@ def main(state_file: str):
|
|||
packing_iterator = packing_iterator_state.build()
|
||||
print("iter")
|
||||
batch_iter = packing_iterator.create_iter()
|
||||
batch = None
|
||||
print("looping")
|
||||
for i in track(range(1_000)):
|
||||
batch = next(batch_iter)
|
||||
for i in track(range(steps)):
|
||||
_ = next(batch_iter)
|
||||
if i % log_freq == 0:
|
||||
print(pyarrow.default_memory_pool())
|
||||
print(i)
|
||||
print(pyarrow.default_memory_pool())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Add table
Reference in a new issue