From ae3b2cd8eb8df4c98a91a01ae199b1f466a0a95c Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez <me@pedro.ai>
Date: Thu, 13 Mar 2025 00:23:54 +0000
Subject: [PATCH] Make iterate_data steps, pyarrow threads, and log frequency configurable

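Summary:
Turn the hard-coded values in bytelatent/iterate_data.py main() into
parameters: the number of loop steps (default 3_000, previously 1_000) and
the pyarrow IO-thread and CPU counts (default 2 each, previously 4).
Add a log_freq parameter (default 100): pyarrow's default memory pool is
printed every log_freq steps and once more after the loop, together with the
final step index. The unused `batch` variable is dropped.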

Test Plan:
---
 bytelatent/iterate_data.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/bytelatent/iterate_data.py b/bytelatent/iterate_data.py
index bdb8f22..79738ff 100644
--- a/bytelatent/iterate_data.py
+++ b/bytelatent/iterate_data.py
@@ -8,10 +8,16 @@ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.logger import init_logger
 
 
-def main(state_file: str):
+def main(
+    state_file: str,
+    steps: int = 3_000,
+    io_thread_count: int = 2,
+    cpu_count: int = 2,
+    log_freq: int = 100,
+):
     init_logger()
-    pyarrow.set_io_thread_count(4)
-    pyarrow.set_cpu_count(4)
+    pyarrow.set_io_thread_count(io_thread_count)
+    pyarrow.set_cpu_count(cpu_count)
     with open(state_file) as f:
         train_state = json.load(f)
         dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
@@ -20,10 +26,13 @@ def main(state_file: str):
         packing_iterator = packing_iterator_state.build()
         print("iter")
         batch_iter = packing_iterator.create_iter()
-        batch = None
         print("looping")
-        for i in track(range(1_000)):
-            batch = next(batch_iter)
+        for i in track(range(steps)):
+            _ = next(batch_iter)
+            if i % log_freq == 0:
+                print(pyarrow.default_memory_pool())
+        print(i)
+        print(pyarrow.default_memory_pool())
 
 
 if __name__ == "__main__":