mirror of
https://github.com/facebookresearch/blt.git
synced 2025-04-10 19:59:09 +00:00
Update iterate_data
Summary: Test Plan:
This commit is contained in:
parent
c110f6be2a
commit
ae3b2cd8eb
1 changed file with 15 additions and 6 deletions
|
@ -8,10 +8,16 @@ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
|
||||||
from bytelatent.logger import init_logger
|
from bytelatent.logger import init_logger
|
||||||
|
|
||||||
|
|
||||||
def main(state_file: str):
|
def main(
|
||||||
|
state_file: str,
|
||||||
|
steps: int = 3_000,
|
||||||
|
io_thread_count: int = 2,
|
||||||
|
cpu_count: int = 2,
|
||||||
|
log_freq: int = 100,
|
||||||
|
):
|
||||||
init_logger()
|
init_logger()
|
||||||
pyarrow.set_io_thread_count(4)
|
pyarrow.set_io_thread_count(io_thread_count)
|
||||||
pyarrow.set_cpu_count(4)
|
pyarrow.set_cpu_count(cpu_count)
|
||||||
with open(state_file) as f:
|
with open(state_file) as f:
|
||||||
train_state = json.load(f)
|
train_state = json.load(f)
|
||||||
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
|
dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
|
||||||
|
@ -20,10 +26,13 @@ def main(state_file: str):
|
||||||
packing_iterator = packing_iterator_state.build()
|
packing_iterator = packing_iterator_state.build()
|
||||||
print("iter")
|
print("iter")
|
||||||
batch_iter = packing_iterator.create_iter()
|
batch_iter = packing_iterator.create_iter()
|
||||||
batch = None
|
|
||||||
print("looping")
|
print("looping")
|
||||||
for i in track(range(1_000)):
|
for i in track(range(steps)):
|
||||||
batch = next(batch_iter)
|
_ = next(batch_iter)
|
||||||
|
if i % log_freq == 0:
|
||||||
|
print(pyarrow.default_memory_pool())
|
||||||
|
print(i)
|
||||||
|
print(pyarrow.default_memory_pool())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Add table
Reference in a new issue