From afedb16598776530b7d21c5d6997ef1126076804 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 6 Feb 2025 09:41:58 -0800
Subject: [PATCH 1/2] Update checkpointing to use fsspec (#39)

Summary:

- Make the data/checkpoint code fsspec compatible
- Still will not work with s3 saves, due to `torch.distributed.checkpoint.save`
  not working out of the box with `fsspec`. Will implement in a followup PR

Test Plan:

Run unit tests and the commands below

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

These currently won't work due to the torch distributed save, but these
should be tested at a later date

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
---
 bytelatent/args.py       |  14 +++--
 bytelatent/checkpoint.py | 115 +++++++++++++++++++++------------------
 bytelatent/logger.py     |   9 ++-
 bytelatent/metrics.py    |  15 ++++-
 bytelatent/train.py      |  31 +++++++----
 5 files changed, 112 insertions(+), 72 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index d1bac46..fc72b32 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -294,6 +294,14 @@ class TrainArgs(BaseModel):
     def dump_to_yaml_file(
         self, path: str, log_config: bool = True, sort_keys: bool = True
     ):
+        yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
+        with open(path, "w") as f:
+            if log_config:
+                logger.info("Using the following config for this run:")
+                logger.info(yaml_str)
+            f.write(yaml_str)
+
+    def dump_to_yaml_str(self, sort_keys: bool = True):
         model_dict = self.model_dump(mode="json")
         yaml_str = yaml.dump(
             model_dict,
@@ -301,8 +309,4 @@ class TrainArgs(BaseModel):
             sort_keys=sort_keys,
             default_flow_style=False,
         )
-        with open(path, "w") as f:
-            if log_config:
-                logger.info("Using the following config for this run:")
-                logger.info(yaml_str)
-            f.write(yaml_str)
+        return yaml_str
diff --git a/bytelatent/checkpoint.py b/bytelatent/checkpoint.py
index f213c84..1668c88 100644
--- a/bytelatent/checkpoint.py
+++ b/bytelatent/checkpoint.py
@@ -4,10 +4,9 @@ import json
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple

 import fsspec
+import s3fs
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):

     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
         logger.info("Consolidated !")
     return consolidate_path


 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
@@ -115,19 +117,24 @@ class CheckpointManager:
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init

-        assert self.fs.exists(
-            self.path
-        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
+        if not isinstance(self.fs, s3fs.S3FileSystem):
+            # S3 does not have a concept of directories
+            assert self.fs.exists(
+                self.path
+            ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"

         self.existing_saves = self.get_existing_saves()

-    def get_existing_saves(self) -> List[Path]:
-        folders = [
-            p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
-        ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+    def get_existing_saves(self) -> list[str]:
+        if self.fs.exists(self.path) and self.fs.isdir(self.path):
+            folders = [
+                p
+                for p in self.fs.ls(self.path)
+                if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
+            ]
+        else:
+            folders = []
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders

     def clean_up(self):
@@ -136,8 +143,9 @@ class CheckpointManager:
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
@@ -161,40 +169,39 @@ class CheckpointManager:

         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.basename(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)

         dist.barrier()

         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))

-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path

-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
             dist.barrier()
         return folder

-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
@@ -222,14 +229,14 @@ class CheckpointManager:
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")

         if dist.is_initialized():
             dist.barrier()
@@ -242,17 +249,19 @@ class CheckpointManager:
         if dist.is_initialized():
             dist.barrier()

+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )

         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")

@@ -271,7 +280,7 @@ class CheckpointManager:
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +293,12 @@ class CheckpointManager:
             # Only load train state if it's provided, the files exist and we're not loading from init path
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
             logger.info("Reloading train state")
-            with open(path / train_state_name, "r") as f:
+            with self.fs.open(os.path.join(path, train_state_name), "r") as f:
                 train_state_dict = json.load(f)
             train_state.load_state_dict(train_state_dict)
             logger.info("Train state reloaded")

-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
diff --git a/bytelatent/logger.py b/bytelatent/logger.py
index 87f04cc..6f9a397 100644
--- a/bytelatent/logger.py
+++ b/bytelatent/logger.py
@@ -6,6 +6,8 @@ import sys
 import time
 from datetime import timedelta

+import fsspec
+
 from bytelatent.distributed import get_global_rank, get_is_slurm_job


@@ -92,6 +94,7 @@ def init_logger(
     *,
     name: str | None = None,
     level: str = "INFO",
+    fs: fsspec.AbstractFileSystem | None = None,
 ):
     """
     Setup logging.
@@ -121,7 +124,11 @@ def init_logger(

     if log_file is not None and get_global_rank() == 0:
         # build file handler
-        file_handler = logging.FileHandler(log_file, "a")
+        if fs is None:
+            file_handler = logging.FileHandler(log_file, "a")
+        else:
+            file_stream = fs.open(log_file, mode="a")
+            file_handler = logging.StreamHandler(file_stream)
         file_handler.setLevel(logging.NOTSET)
         file_handler.setFormatter(LogFormatter())
         # update logger
diff --git a/bytelatent/metrics.py b/bytelatent/metrics.py
index e746e4f..fb443d7 100644
--- a/bytelatent/metrics.py
+++ b/bytelatent/metrics.py
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Union

+import fsspec
 import torch
 import torch.nn as nn
 import wandb
@@ -53,14 +54,24 @@ class LoggingArgs(BaseModel):


 class MetricLogger:
-    def __init__(self, outdir: Path, args: Any | None = None):
+    def __init__(
+        self,
+        outdir: Path,
+        # args: TrainArgs
+        args: Any | None = None,
+        fs: fsspec.AbstractFileSystem | None = None,
+    ):
         self.outdir = outdir
         self.jsonl_writer = None
+        self.fs = fs
         self.args = args

     def open(self):
         if self.jsonl_writer is None:
-            self.jsonl_writer = open(self.outdir, "a")
+            if self.fs is None:
+                self.jsonl_writer = open(self.outdir, "a")
+            else:
+                self.jsonl_writer = self.fs.open(self.outdir, "a")
         if (
             self.args is not None
             and self.args.logging.wandb is not None
diff --git a/bytelatent/train.py b/bytelatent/train.py
index bb8307a..9bfe12a 100644
--- a/bytelatent/train.py
+++ b/bytelatent/train.py
@@ -8,7 +8,6 @@ import sys
 from contextlib import ExitStack
 from copy import deepcopy
 from dataclasses import asdict, dataclass
-from pathlib import Path
 from timeit import default_timer as timer
 from typing import Any, TypeVar

@@ -18,13 +17,13 @@ import torch.nn.functional
 import torch.nn.functional as F
 import wandb
 import xformers.profiler
-from omegaconf import OmegaConf
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler

 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int):

     if args.checkpoint.path is None:
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
-        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
+        args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")

+    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
     for source in args.data.sources:
         data_path = os.path.join(args.data.root_dir, source)
-        assert os.path.exists(data_path), f"{data_path} doesn't exist"
+        assert data_fs.exists(data_path), f"{data_path} doesn't exist"

     if (
         args.distributed.dp_replicate
@@ -255,10 +255,15 @@ def train(args: TrainArgs):
         args,
         tokenizer.n_words,
     )
+    dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():
-        os.makedirs(args.dump_dir, exist_ok=True)
-        args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
-    init_logger(Path(args.dump_dir) / "train.log")
+        dump_fs.mkdirs(args.dump_dir, exist_ok=True)
+        config_yaml_str = args.dump_to_yaml_str()
+        logging.info("TrainArgs: \n%s", config_yaml_str)
+        dump_fs.write_text(
+            os.path.join(args.dump_dir, "config.yaml"), config_yaml_str
+        )
+    init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
     init_signal_handler(set_preemption_flag)  # For handling preemption signals.
     setup_env(args.env)
     setup_torch_distributed(args.distributed)
@@ -313,8 +318,11 @@ def train(args: TrainArgs):

     if args.checkpoint.init_ckpt_path:
         logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+        ckpt_fs = get_fs(
+            args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
+        )
         load_from_checkpoint(
-            args.checkpoint.init_ckpt_path, model, model_key="model"
+            ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
         )  # Put model_key="" if its directly the model checkpoint
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
@@ -352,13 +360,14 @@ def train(args: TrainArgs):
         checkpoint.load(model, optimizer, train_state, world_mesh)
     # Either load from latest checkpoint or start from scratch
     if args.probe_freq is not None:
+        # TODO: Convert this to be fsspec compatible
         if get_is_master():
-            os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
+            os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True)
         torch.distributed.barrier()
         probe = AutoProbeD(
             model,
             (
-                Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
+                os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl")
                 if (dp_rank % 128 == 0)
                 else None
             ),
@@ -370,7 +379,7 @@ def train(args: TrainArgs):
     # train loop
     model.train()
     metric_logger = context_stack.enter_context(
-        MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
+        MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs)
     )
     data_loader = train_state.data_loader_state.build()
     batch_iterator = data_loader.create_iter()

From 8d26140970e73327eae3e99f73d7d816375c1134 Mon Sep 17 00:00:00 2001
From: Pedro Rodriguez
Date: Thu, 6 Feb 2025 17:42:32 +0000
Subject: [PATCH 2/2] Allow ArrowIterator to read from json

Summary:

Test Plan:

---
 bytelatent/args.py                            | 38 +++++++-
 bytelatent/data/iterators/arrow_iterator.py   | 97 +++++++++++--------
 bytelatent/preprocess/preprocess_entropies.py | 26 +++--
 3 files changed, 104 insertions(+), 57 deletions(-)

diff --git a/bytelatent/args.py b/bytelatent/args.py
index fc72b32..263e8e3 100644
--- a/bytelatent/args.py
+++ b/bytelatent/args.py
@@ -1,8 +1,10 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import json
 import logging
 import os
 from typing import Any

+import fsspec
 import numpy as np
 import yaml
 from omegaconf import OmegaConf
@@ -10,11 +12,9 @@ from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
 from bytelatent.data.data_types import Batch
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
-from bytelatent.data.iterators.arrow_iterator import (
-    ArrowFileIterator,
-    find_and_sanitize_chunks,
-)
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
 from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
 from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
@@ -53,6 +53,33 @@ def parse_args(args_cls):
     return pydantic_args


+TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"
+
+
+def find_and_sanitize_chunks(
+    dataset_path: str,
+    world_size: int,
+    file_pattern: str,
+    s3_profile: str | None = None,
+):
+    fs = get_fs(dataset_path, s3_profile=s3_profile)
+    path_with_glob = os.path.join(dataset_path, file_pattern)
+    dataset_chunks = fs.glob(path_with_glob)
+    n_chunks = len(dataset_chunks)
+
+    if n_chunks > world_size:
+        n_discard = n_chunks - world_size
+        dataset_chunks = dataset_chunks[:world_size]
+    else:
+        assert (
+            world_size % n_chunks == 0
+        ), "World size should be a multiple of number of chunks"
+
+    assert n_chunks > 0, f"No valid chunks in {dataset_path}"
+
+    return dataset_chunks
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -62,9 +89,10 @@ def distribute_data_to_rank(
     rank: int,
     world_size: int,
     s3_profile: str | None = None,
+    file_pattern: str = TRAIN_DATA_FILE_PATTERN,
 ) -> ArrowFileIterator:
     dataset_chunks = find_and_sanitize_chunks(
-        dataset_path, world_size, s3_profile=s3_profile
+        dataset_path, world_size, file_pattern, s3_profile=s3_profile
     )
     n_workers_per_chunk = world_size // len(dataset_chunks)
     rank_to_arrow_iterator_params = []
diff --git a/bytelatent/data/iterators/arrow_iterator.py b/bytelatent/data/iterators/arrow_iterator.py
index 4e7b99e..1c68d3a 100644
--- a/bytelatent/data/iterators/arrow_iterator.py
+++ b/bytelatent/data/iterators/arrow_iterator.py
@@ -16,6 +16,7 @@ from bytelatent import ByteLatentError
 from bytelatent.data.data_types import BltExample
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
+from bytelatent.preprocess.preprocess_entropies import get_id_key, get_text

 logger = getLogger(__name__)

@@ -32,6 +33,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
     arrow_batch_size: int = 100
     s3_profile: str | None
     filesystem_type: str | None = None
+    file_format: str

     def build(self) -> "ArrowFileIterator":
         arrow_file = ArrowFileIterator(
@@ -44,6 +46,7 @@ class ArrowFileIteratorState(BaseModel, IteratorState):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )
         if self.row_num != 0:
             arrow_file._set_row_num(self.row_num)
@@ -70,6 +73,7 @@ class ArrowFileIterator(StatefulIterator):
         dataset_files: list[str] | None = None,
         s3_profile: str | None = None,
         filesystem_type: str | None = None,
+        file_format: str = "arrow",
     ):
         assert 0 <= worker_id < num_workers, (worker_id, num_workers)
         if file_path is None and dataset_files is None:
@@ -87,12 +91,16 @@ class ArrowFileIterator(StatefulIterator):
         self.arrow_batch_size = arrow_batch_size
         self.s3_profile = s3_profile
         self.filesystem_type = filesystem_type
+        self.file_format = file_format
         self.fs = None
         if self.filesystem_type is not None:
             if self.filesystem_type == "file":
                 self.fs = fsspec.filesystem("file")
             elif self.filesystem_type == "s3":
                 self.fs = fsspec.filesystem("s3", profile=s3_profile)
+            else:
+                raise ValueError("Unknown filesystem")
+            logger.info("Arrow iterator using fs=%s", self.fs)

         if dataset_files is None:
             # Prepare arrow shards
@@ -153,6 +161,7 @@ class ArrowFileIterator(StatefulIterator):
             dataset_files=self.dataset_files,
             s3_profile=self.s3_profile,
             filesystem_type=self.filesystem_type,
+            file_format=self.file_format,
         )

     def create_iter(
@@ -164,7 +173,7 @@ class ArrowFileIterator(StatefulIterator):
             else:
                 filesystem = None
             self.dataset = pa.dataset.dataset(
-                self.dataset_files, format="arrow", filesystem=filesystem
+                self.dataset_files, format=self.file_format, filesystem=filesystem
             )
             self.batch_iterator = self.dataset.to_batches(
                 batch_size=self.arrow_batch_size
@@ -173,13 +182,22 @@ class ArrowFileIterator(StatefulIterator):
         if self.batch_to_consume is not None:
             batch_columns: dict[str, list] = self.batch_to_consume
             self.batch_to_consume = None
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -191,13 +209,22 @@ class ArrowFileIterator(StatefulIterator):

         for batch in self.batch_iterator:
             batch_columns = batch.to_pydict()
-            sample_ids = batch_columns["sample_id"]
-            texts = batch_columns["text"]
-            entropies = batch_columns["entropies"]
+            if self.file_format == "arrow":
+                sample_ids = batch_columns["sample_id"]
+                texts = batch_columns["text"]
+                entropies = batch_columns["entropies"]
+            elif self.file_format == "json":
+                # This data hasn't been preprocessed to a uniform format,
+                # so we have to do it now and omit entropies
+                sample_ids = batch_columns[get_id_key(batch_columns)]
+                texts = get_text(batch_columns)
+                entropies = None
+            else:
+                raise ValueError(f"Unknown file format: {self.file_format}")
             for i in range(len(sample_ids)):
                 out = BltExample(
                     sample_id=sample_ids[i],
-                    entropies=entropies[i],
+                    entropies=entropies[i] if entropies is not None else None,
                     text=texts[i],
                     tokens=None,
                     mask=None,
@@ -231,13 +258,24 @@ class ArrowFileIterator(StatefulIterator):
             for batch in self.batch_iterator:
                 if len(batch) > curr_remaining:
                     batch_columns: dict[str, list] = batch.to_pydict()
-                    batch_columns["sample_id"] = batch_columns["sample_id"][
-                        curr_remaining:
-                    ]
-                    batch_columns["entropies"] = batch_columns["entropies"][
-                        curr_remaining:
-                    ]
-                    batch_columns["text"] = batch_columns["text"][curr_remaining:]
+                    if self.file_format == "arrow":
+                        leftover_sample_ids = batch_columns["sample_id"][
+                            curr_remaining:
+                        ]
+                        leftover_entropies = batch_columns["entropies"][curr_remaining:]
+                        leftover_texts = batch_columns["text"][curr_remaining:]
batch_columns["text"][curr_remaining:] + elif self.file_format == "json": + leftover_sample_ids = batch_columns[get_id_key(batch_columns)][ + curr_remaining: + ] + leftover_entropies = None + leftover_texts = get_text(batch_columns)[curr_remaining:] + else: + raise ValueError(f"Unknown file format: {self.file_format}") + + batch_columns["sample_id"] = leftover_sample_ids + batch_columns["entropies"] = leftover_entropies + batch_columns["text"] = leftover_texts self.batch_to_consume = batch_columns break elif len(batch) == curr_remaining: @@ -250,30 +288,3 @@ class ArrowFileIterator(StatefulIterator): logger.info( f"Finished setting arrow position to {target_row_num} for {self.dataset_files}" ) - - -TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" - - -def find_and_sanitize_chunks( - dataset_path: str, - world_size: int, - file_pattern: str = TRAIN_DATA_FILE_PATTERN, - s3_profile: str | None = None, -): - fs = get_fs(dataset_path, s3_profile=s3_profile) - path_with_glob = os.path.join(dataset_path, file_pattern) - dataset_chunks = fs.glob(path_with_glob) - n_chunks = len(dataset_chunks) - - if n_chunks > world_size: - n_discard = n_chunks - world_size - dataset_chunks = dataset_chunks[:world_size] - else: - assert ( - world_size % n_chunks == 0 - ), "World size should be a multiple of number of chunks" - - assert n_chunks > 0, f"No valid chunks in {dataset_path}" - - return dataset_chunks diff --git a/bytelatent/preprocess/preprocess_entropies.py b/bytelatent/preprocess/preprocess_entropies.py index 519da94..31a4802 100644 --- a/bytelatent/preprocess/preprocess_entropies.py +++ b/bytelatent/preprocess/preprocess_entropies.py @@ -15,29 +15,37 @@ from bytelatent.entropy_model import load_entropy_model from bytelatent.tokenizers.build_tokenizer import TokenizerArgs -def get_id_from_doc(doc: dict) -> int: +def get_id_key(doc: dict) -> int: """ We need a reliable way to ensure that samples from jsonl and arrow are the same, but there is no unique id field, so derive the best possible """ if "sample_id" in doc: - sample_id = doc["sample_id"] + return "sample_id" elif "title" in doc: - sample_id = doc["title"] + return "title" elif "qid" in doc: - sample_id = doc["qid"] + return "qid" elif "paper_id" in doc: - sample_id = doc["paper_id"] + return "paper_id" elif "path" in doc: - sample_id = doc["path"] + return "path" elif "url" in doc: - sample_id = doc["url"] + return "url" elif "id" in doc: - sample_id = doc["id"] + return "id" else: raise ValueError(f"Could not find a id key from: {doc.keys()}") - return str(sample_id) + + +def get_id_from_doc(doc: dict) -> int: + """ + We need a reliable way to ensure that samples from jsonl + and arrow are the same, but there is no unique id field, + so derive the best possible + """ + return str(doc[get_id_key(doc)]) def get_text(doc: dict):