# Copyright (c) Meta Platforms, Inc. and affiliates.
import os
import re
from logging import getLogger
from typing import Any, Generator

import fsspec
import pyarrow as pa

# pyarrow needs the initialization from this import
import pyarrow.dataset  # pyright: ignore
import s3fs
from pydantic import BaseModel, ConfigDict

from bytelatent import ByteLatentError
from bytelatent.data.data_types import BltExample
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator

logger = getLogger(__name__)


class ArrowFileIteratorState(BaseModel, IteratorState):
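    """Serializable snapshot of an ArrowFileIterator: everything needed to rebuild it and resume at row_num."""
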
    model_config = ConfigDict(extra="forbid")
    file_path: str | None
    row_num: int
    num_workers: int
    worker_id: int
    preprocess_dir: str | None
    dataset_files: list[str] | None
    entropy_model_name: str | None
    arrow_batch_size: int = 100
    s3_profile: str | None
    filesystem_type: str | None = None

    def build(self) -> "ArrowFileIterator":
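        """Rebuild the iterator from this state and fast-forward it to the saved row."""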
        arrow_file = ArrowFileIterator(
            file_path=self.file_path,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
        )
        if self.row_num != 0:
            arrow_file._set_row_num(self.row_num)
        return arrow_file


def shard_sort_key(file: str):
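    """Sort key that orders shard files by the numeric index in their `.shard_N.arrow` suffix."""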
    assert isinstance(file, str)
    match = re.search(r".+\.shard_([0-9]+)\.arrow", file)
    shard_number = int(match.group(1))
    return shard_number


class ArrowFileIterator(StatefulIterator):
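    """Stateful iterator over Arrow shard files (local or S3) that yields BltExample rows."""
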
    def __init__(
        self,
        *,
        file_path: str | None,
        worker_id: int,
        num_workers: int,
        preprocess_dir: str | None,
        entropy_model_name: str | None,
        arrow_batch_size: int,
        dataset_files: list[str] | None = None,
        s3_profile: str | None = None,
        filesystem_type: str | None = None,
    ):
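        """Either locate preprocessed shards via file_path/preprocess_dir/entropy_model_name, or read dataset_files directly."""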
        assert 0 <= worker_id < num_workers, (worker_id, num_workers)
        if file_path is None and dataset_files is None:
            raise ByteLatentError("file_path and dataset_files cannot both be None")
        self.row_num = 0
        self.iter_id = 0
        self.batch_iterator = None
        self.batch_to_consume = None
        self.dataset = None
        self.file_path = file_path
        self.worker_id = worker_id
        self.num_workers = num_workers
        self.preprocess_dir = preprocess_dir
        self.entropy_model_name = entropy_model_name
        self.arrow_batch_size = arrow_batch_size
        self.s3_profile = s3_profile
        self.filesystem_type = filesystem_type
        self.fs = None
        if self.filesystem_type is not None:
            if self.filesystem_type == "file":
                self.fs = fsspec.filesystem("file")
            elif self.filesystem_type == "s3":
                self.fs = fsspec.filesystem("s3", profile=s3_profile)

        if dataset_files is None:
            # Prepare arrow shards
            jsonl_file = file_path
            parts = re.match(
                r"(.+)\.chunk\.[0-9]+\.jsonl", os.path.basename(jsonl_file)
            )
            assert parts is not None
            dataset = parts.group(1)
            data_dir = os.path.join(preprocess_dir, dataset, entropy_model_name)
            data_dir_with_glob = os.path.join(
                data_dir, f"{os.path.basename(jsonl_file)}.shard_*.arrow"
            )
            if self.fs is None:
                self.fs = get_fs(data_dir_with_glob, s3_profile=s3_profile)
                if isinstance(self.fs, s3fs.S3FileSystem):
                    self.filesystem_type = "s3"
                else:
                    self.filesystem_type = "file"
            shard_files = self.fs.glob(data_dir_with_glob)

            for s in shard_files:
                complete_file = os.path.join(
                    data_dir, f"{os.path.basename(s)}.complete"
                )
                if not self.fs.exists(complete_file):
                    raise ValueError(f"Missing .complete for input file: {s}")

            shard_files = sorted(shard_files, key=shard_sort_key)
            if len(shard_files) == 0:
                raise ByteLatentError(
                    f"Zero shard_files found corresponding to: {file_path} using preprocess_dir={preprocess_dir} and entropy_model_name={entropy_model_name}, so the search path is data_dir={data_dir} for matches to {os.path.basename(jsonl_file)}.shard_*.arrow"
                )
            self.dataset_files = [f for f in shard_files]
        else:
            self.preprocess_dir = None
            self.dataset_files = dataset_files
            if dataset_files[0].startswith("s3://"):
                for f in dataset_files:
                    assert f.startswith("s3://")
            if self.fs is None:
                self.fs = get_fs(dataset_files[0], s3_profile=s3_profile)
                if isinstance(self.fs, s3fs.S3FileSystem):
                    self.filesystem_type = "s3"
                else:
                    self.filesystem_type = "file"

    def get_state(self) -> ArrowFileIteratorState:
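        """Capture the current read position and configuration so iteration can be resumed later."""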
        return ArrowFileIteratorState(
            file_path=self.file_path,
            row_num=self.row_num,
            worker_id=self.worker_id,
            num_workers=self.num_workers,
            preprocess_dir=self.preprocess_dir,
            entropy_model_name=self.entropy_model_name,
            arrow_batch_size=self.arrow_batch_size,
            dataset_files=self.dataset_files,
            s3_profile=self.s3_profile,
            filesystem_type=self.filesystem_type,
        )

    def create_iter(
        self,
    ) -> Generator[BltExample, Any, None]:
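        """Yield BltExample rows, distributing rows round-robin across num_workers by row number."""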
        if self.dataset is None:
            if isinstance(self.fs, s3fs.core.S3FileSystem):
                filesystem = self.fs
            else:
                filesystem = None
            self.dataset = pa.dataset.dataset(
                self.dataset_files, format="arrow", filesystem=filesystem
            )
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
        self.iter_id += 1
        if self.batch_to_consume is not None:
            batch_columns: dict[str, list] = self.batch_to_consume
            self.batch_to_consume = None
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

        for batch in self.batch_iterator:
            batch_columns = batch.to_pydict()
            sample_ids = batch_columns["sample_id"]
            texts = batch_columns["text"]
            entropies = batch_columns["entropies"]
            for i in range(len(sample_ids)):
                out = BltExample(
                    sample_id=sample_ids[i],
                    entropies=entropies[i],
                    text=texts[i],
                    tokens=None,
                    mask=None,
                    patch_lengths=None,
                )
                self.row_num += 1
                if (self.row_num - 1) % self.num_workers == self.worker_id:
                    yield out

    def _set_row_num(self, target_row_num: int):
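        """Fast-forward the Arrow dataset so iteration resumes at target_row_num; any partially consumed batch is stashed in batch_to_consume."""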
        logger.info(
            f"Setting arrow position to {target_row_num} for {self.dataset_files}"
        )
        if target_row_num is None or target_row_num == 0:
            self.row_num = 0
            self.dataset = None
            self.batch_iterator = None
            self.batch_to_consume = None
        else:
            if isinstance(self.fs, s3fs.core.S3FileSystem):
                filesystem = self.fs
            else:
                filesystem = None
            self.dataset = pa.dataset.dataset(
                self.dataset_files, format="arrow", filesystem=filesystem
            )
            self.batch_iterator = self.dataset.to_batches(
                batch_size=self.arrow_batch_size
            )
            curr_remaining = target_row_num
            for batch in self.batch_iterator:
                if len(batch) > curr_remaining:
                    # The target row falls inside this batch,
                    # so keep the remainder for the next create_iter call
                    batch_columns: dict[str, list] = batch.to_pydict()
                    batch_columns["sample_id"] = batch_columns["sample_id"][
                        curr_remaining:
                    ]
                    batch_columns["entropies"] = batch_columns["entropies"][
                        curr_remaining:
                    ]
                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
                    self.batch_to_consume = batch_columns
                    break
                elif len(batch) == curr_remaining:
                    # We are exactly at the end of the batch,
                    # so the next batch is the right spot
                    break
                else:
                    curr_remaining -= len(batch)
            self.row_num = target_row_num
        logger.info(
            f"Finished setting arrow position to {target_row_num} for {self.dataset_files}"
        )


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


def find_and_sanitize_chunks(
    dataset_path: str,
    world_size: int,
    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
    s3_profile: str | None = None,
):
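    """Glob chunk files under dataset_path and keep at most world_size of them; if fewer, world_size must be a multiple of the chunk count."""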
    fs = get_fs(dataset_path, s3_profile=s3_profile)
    path_with_glob = os.path.join(dataset_path, file_pattern)
    dataset_chunks = fs.glob(path_with_glob)
    n_chunks = len(dataset_chunks)
    assert n_chunks > 0, f"No valid chunks in {dataset_path}"

    if n_chunks > world_size:
        n_discard = n_chunks - world_size
        dataset_chunks = dataset_chunks[:world_size]
    else:
        assert (
            world_size % n_chunks == 0
        ), "World size should be a multiple of number of chunks"

    return dataset_chunks