Test first batch matches

Summary:
Adds a test that the first training batch produced by the data loader matches a pickled fixture, adds the file_format="arrow" field to the arrow-iterator states built in existing tests, and moves the wandb import below the project-local imports in the touched modules.

Test Plan:
Pedro Rodriguez 2025-02-12 18:07:21 +00:00
parent 22c7fe1d1c
commit 4cee32ea8c
7 changed files with 62 additions and 5 deletions

@@ -14,7 +14,6 @@ from typing import Any, Dict, Optional
 import torch
 import torch.distributed
-import wandb
 import xformers.profiler
 from lingua.args import dataclass_from_dict, dump_config, flatten_dict
 from lingua.data import (
@@ -70,6 +69,8 @@ from bytelatent.transformer import (
     tp_parallelize,
 )
+import wandb
+
 logger = logging.getLogger()

@@ -28,6 +28,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()
@@ -57,6 +58,7 @@ def test_basic_arrow_file():
         row_num=251,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():
@@ -77,6 +79,7 @@ def test_basic_arrow_file():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = rank_state.build()
     expected_ids = []
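
The three hunks above all add the same newly required field. As a point of reference, the sketch below (a sketch under assumptions, not code from this commit) shows the checkpoint-and-resume flow these fixtures feed into, using only the methods visible in this diff: build(), create_iter(), and get_state().

# Sketch of the build -> iterate -> checkpoint -> resume pattern exercised by
# test_basic_arrow_file. `state` is any iterator state like the ones built
# above, which now must include file_format="arrow".
def consume_then_resume(state, n: int = 251):
    arrow_file = state.build()
    it = arrow_file.create_iter()
    for _ in range(n):                      # read the first n rows
        next(it)
    resumed_state = arrow_file.get_state()  # assumed to record row_num == n
    resumed_file = resumed_state.build()    # rebuild and continue from row n
    return next(resumed_file.create_iter())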

@@ -0,0 +1,48 @@
+import os
+import pickle
+
+import pytest
+from omegaconf import OmegaConf
+
+from bytelatent.args import TrainArgs
+from bytelatent.constants import BLT_DATA
+
+
+def get_test_config():
+    if "BLT_INTERNAL" in os.environ:
+        internal_dir = os.environ["BLT_INTERNAL"]
+    else:
+        internal_dir = "../internal-blt/configs"
+    test_config = os.path.join(internal_dir, "tests.yaml")
+    return test_config
+
+
+@pytest.mark.skipif(
+    not os.path.exists(get_test_config()),
+    reason="Skipping since internal config is missing",
+)
+def test_first_batch_matches():
+    test_config_path = get_test_config()
+    default_cfg = OmegaConf.create(TrainArgs().model_dump())
+    file_cfg = OmegaConf.load(test_config_path)
+    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
+    merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
+    train_args = TrainArgs.model_validate(merged_cfg)
+    # MP doesn't work well with async, but it doesn't change the logic
+    train_args.data.load_async = False
+    # Test data was created by pickling the first batch in the train loop, then exiting
+    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
+        first_batch = pickle.load(f)
+    # Emulate 1-node, 8-GPU training
+    data_loader = train_args.data.build_from_rank(0, 8)
+    batch_iterator = data_loader.create_iter()
+    print("Getting first batch")
+    batch = next(batch_iterator)
+    assert (batch.x == first_batch.x).all()
+    assert (batch.y == first_batch.y).all()
+    assert (batch.mask == first_batch.mask).all()
+    assert (batch.patch_lengths == first_batch.patch_lengths).all()
+    assert batch.ngram_ids is None and first_batch.ngram_ids is None
+    assert batch.is_final == False and first_batch.is_final == False
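
The fixture loaded above is described only by the comment ("pickling first batch in train loop then exiting"). A minimal sketch of how such a fixture could be produced, where the helper name and the exact hook point in the train loop are assumptions rather than code from this commit:

import os
import pickle


# Hypothetical helper: dump the first batch seen by a given rank so that
# test_first_batch_matches can compare against it later.
def dump_first_batch(batch, fixtures_dir: str, rank: int) -> None:
    os.makedirs(fixtures_dir, exist_ok=True)
    path = os.path.join(fixtures_dir, f"first_batch_{rank}.pickle")
    with open(path, "wb") as f:
        pickle.dump(batch, f)


# Assumed usage inside the train loop, right after the first batch is fetched:
#     batch = next(batch_iterator)
#     dump_first_batch(batch, os.path.join(BLT_DATA, "fixtures"), rank=0)
#     return  # exit early; only the fixture is needed

The fixture only needs regenerating when the data pipeline's ordering changes intentionally; otherwise any drift in batch.x, batch.y, batch.mask, or batch.patch_lengths fails the test.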

@@ -11,11 +11,12 @@ from typing import Any, Union
 import fsspec
 import torch
 import torch.nn as nn
-import wandb
 from pydantic import BaseModel, ConfigDict
 from bytelatent.distributed import get_is_master
+import wandb
+
 logger = logging.getLogger()
@@ -198,9 +199,10 @@ def upload_train_to_wandb(
     import json
     from pathlib import Path
-    import wandb
     from omegaconf import OmegaConf
+    import wandb
+
     cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
     cfg = OmegaConf.to_container(cfg)
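
The second hunk ends right after the config is flattened; presumably the result is then attached to a wandb run. A minimal sketch of that step, where the function name, project name, and wandb.init arguments are assumptions and not taken from this diff:

from pathlib import Path

import wandb
from omegaconf import OmegaConf


# Hypothetical illustration; the real upload_train_to_wandb continues past
# the lines shown in the hunk above.
def upload_config(ckpt_dir: str, project: str = "blt") -> None:
    cfg = OmegaConf.load(Path(ckpt_dir) / "config.yaml")
    cfg = OmegaConf.to_container(cfg, resolve=True)
    run = wandb.init(project=project, config=cfg, job_type="upload")
    run.finish()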

@@ -7,7 +7,6 @@ import os
 from pathlib import Path
 import torch.distributed
-import wandb
 import xformers.profiler
 from pydantic import BaseModel
 from torch.profiler.profiler import profile
@@ -15,6 +14,8 @@ from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler
 from bytelatent.distributed import get_is_master
+import wandb
+
 class ProfilerArgs(BaseModel):
     run: bool = False

@@ -25,6 +25,7 @@ def test_entropy_model():
         row_num=0,
         arrow_batch_size=100,
         s3_profile=None,
+        file_format="arrow",
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(

@@ -17,7 +17,6 @@ import torch
 import torch.distributed
 import torch.nn.functional
 import torch.nn.functional as F
-import wandb
 import xformers.profiler
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
@@ -63,6 +62,8 @@ from bytelatent.transformer import (
     tp_parallelize,
 )
+import wandb
+
 logger = logging.getLogger()
 T = TypeVar("T")