blt/bytelatent/data/iterators/test_limit_iterator.py
Pedro Rodriguez fc3399ef40
Update iterator inheritance, pass file format args, limit iterator (#63)
- Create a common class to use in all inheritance for states
- Add a limit iterator that we can use in evals
- Modify ArrowFileIterator behavior to not do arrow path inference if file_format='json'
- Make EvalArgs valid
- Move testing iterators to a common directory to allow usage in multiple test files
- Make it so that SequenceIterator can take a None rng_state, to disable all rng ops (for eval mainly)

Test Plan:

- `pytest bytelatent`
- `python -m bytelatent.train config=../internal-blt/configs/entropy_model.yaml logging.wandb=null eval=null`
2025-02-21 16:21:07 -08:00


from bytelatent.data.iterators.dev_iterators import BltTestIterator
from bytelatent.data.iterators.limit_iterator import LimitIterator


def test_limit_iterator():
    total = 10

    # Limit below the total: iteration stops after `limit` examples.
    limit = 5
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == limit

    # Limit equal to the total: every example is yielded.
    limit = 10
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == limit == total

    # Limit above the total: iteration ends when the base iterator is exhausted.
    limit = 20
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == total

    # Limit of -1: no limit is applied, so every example is yielded.
    limit = -1
    base_iterator = BltTestIterator(total=total)
    limit_iterator = LimitIterator(base_iterator, limit=limit)
    iterator = limit_iterator.create_iter()
    n = 0
    for example in iterator:
        assert example.sample_id == f"test_{n}"
        n += 1
    assert n == total
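
The four cases in this test pin down the behavior expected of a limit wrapper: stop after `limit` examples, stop early if the base iterator runs out first, and apply no limit when `limit` is -1. The block below is a minimal standalone sketch of those semantics, not the `LimitIterator` implementation from `bytelatent`. `Example`, `make_examples`, `limit_examples`, and `count_with_limit` are hypothetical stand-ins introduced only for illustration, and treating every negative limit (not just -1) as "no limit" is an assumption the test does not confirm.

```python
from dataclasses import dataclass
from itertools import islice
from typing import Iterable, Iterator


@dataclass
class Example:
    # Hypothetical stand-in for the examples yielded by BltTestIterator.
    sample_id: str


def make_examples(total: int) -> Iterator[Example]:
    # Hypothetical stand-in for BltTestIterator(total=total).create_iter().
    for i in range(total):
        yield Example(sample_id=f"test_{i}")


def limit_examples(examples: Iterable[Example], limit: int) -> Iterator[Example]:
    # Assumption: any negative limit disables limiting (the test only checks -1).
    if limit < 0:
        yield from examples
    else:
        # islice stops after `limit` items or when the source is exhausted.
        yield from islice(examples, limit)


def count_with_limit(total: int, limit: int) -> int:
    # Mirrors the loop used in the test: check ids in order, return the count.
    n = 0
    for example in limit_examples(make_examples(total), limit):
        assert example.sample_id == f"test_{n}"
        n += 1
    return n
```

Calling `count_with_limit(10, limit)` for each of the limits 5, 10, 20, and -1 walks through the same four cases as the test, without repeating the loop body each time.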