From ae3b2cd8eb8df4c98a91a01ae199b1f466a0a95c Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez <me@pedro.ai> Date: Thu, 13 Mar 2025 00:23:54 +0000 Subject: [PATCH] Update iterate_data Summary: Test Plan: --- bytelatent/iterate_data.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/bytelatent/iterate_data.py b/bytelatent/iterate_data.py index bdb8f22..79738ff 100644 --- a/bytelatent/iterate_data.py +++ b/bytelatent/iterate_data.py @@ -8,10 +8,16 @@ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator from bytelatent.logger import init_logger -def main(state_file: str): +def main( + state_file: str, + steps: int = 3_000, + io_thread_count: int = 2, + cpu_count: int = 2, + log_freq: int = 100, +): init_logger() - pyarrow.set_io_thread_count(4) - pyarrow.set_cpu_count(4) + pyarrow.set_io_thread_count(io_thread_count) + pyarrow.set_cpu_count(cpu_count) with open(state_file) as f: train_state = json.load(f) dl_state = MultiprocessIteratorState(**train_state["data_loader_state"]) @@ -20,10 +26,13 @@ def main(state_file: str): packing_iterator = packing_iterator_state.build() print("iter") batch_iter = packing_iterator.create_iter() - batch = None print("looping") - for i in track(range(1_000)): - batch = next(batch_iter) + for i in track(range(steps)): + _ = next(batch_iter) + if i % log_freq == 0: + print(pyarrow.default_memory_pool()) + print(i) + print(pyarrow.default_memory_pool()) if __name__ == "__main__":