Initial code and scripts for training the entropy model (#34)

Summary:

Test Plan:
Pedro Rodriguez 2025-01-27 09:46:44 -08:00 committed by GitHub
parent a809259e71
commit 7622d28b74
11 changed files with 209 additions and 34 deletions

.gitignore
View file

@@ -166,3 +166,4 @@ figures/
.vscode/
.DS_Store
internal/
jobs_parallel-copy/

View file

@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
max_encoder_seq_length: int = 12288
enable_byte_ngrams: bool = False
add_patches: bool = True
tokenizer_args: TokenizerArgs = TokenizerArgs()
patcher_args: PatcherArgs = PatcherArgs()
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
looping_iterator,
patcher_args=self.patcher_args,
tokenizer_args=self.tokenizer_args,
add_patches=self.add_patches,
)
sequence_iterator = SequenceIterator(
preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
source_to_iterator=source_to_sequence_iterators,
)
tokenizer = self.tokenizer_args.build()
if self.tokenizer_args.name == "bytes":
# TODO: Check this with Artidoro
pad_id = 0
else:
pad_id = tokenizer.boe_id
packing_args = PackingArgs(
batch_size=self.batch_size,
seq_len=self.seq_len,
pad_id=tokenizer.boe_id,
pad_id=pad_id,
max_length=self.max_encoder_seq_length,
pad_to_max_length=self.pad_to_max_length,
enable_byte_ngrams=self.enable_byte_ngrams,
tokenizer_name=self.tokenizer_args.name,
)
packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
data: DataloaderArgs = DataloaderArgs()
optim: OptimArgs = OptimArgs()
model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
# This is only needed for training the entropy model
entropy_model: LMTransformerArgs | None = None
# Instead of training main model, train entropy model
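With `model` now optional and `entropy_model` added, the two training modes are selected by which field is populated. A minimal sketch of the intended configurations, expressed as override dicts; the field names come from this diff and the config below, everything else (values, completeness) is illustrative only.

# Sketch only: illustrates how the optional model/entropy_model fields are meant to be used.
# Training the main BLT model: keep `model`, leave `entropy_model` unset.
blt_overrides = {
    "model": {"vocab_size": 260},
    "entropy_model": None,
    "train_entropy_model": False,
}

# Training the byte entropy model: disable `model`, provide `entropy_model`.
entropy_overrides = {
    "model": None,
    "entropy_model": {"dim": 768, "n_layers": 14, "vocab_size": 260},
    "train_entropy_model": True,
}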

View file

@@ -26,10 +26,9 @@ model:
vocab_size: 260
dim_token: 256
patch_size: 6
tokenization_mode: "bytes"
patching_mode: "space"
tie_local_encoder_decoder_logits: false
data_loader_patching: true
patch_in_forward: false
max_encoder_seq_length: 12288
pad_to_max_length: true
patching_threshold: 3.1439168453216553

View file

@@ -0,0 +1,82 @@
# Template config: change dump_dir and the fields marked ??? (data.root_dir, data.preprocess_dir, eval.dataset_dir, eval.tasks) before running
# Evals can be activated by uncommenting their config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
dump_dir: /tmp/
name: "debug"
steps: 100_000
probe_freq: null
seed: 777
optim:
lr: 4e-04
warmup: 500
lr_min_ratio: 0.1
clip: 10.0
distributed:
fsdp_type: full_shard
model_dtype: bf16
matmul_allow_tf32: false
selective_activation_checkpointing: false
tp_size: 1
train_entropy_model: true
model: null
entropy_model:
dim: 768
n_layers: 14
n_heads: 12
max_seqlen: 8192
# vocab_size: -1
vocab_size: 260
ffn_dim_multiplier: 1.0
sliding_window: 512
attn_bias_type: "local_block_causal"
attn_impl: "xformers"
data:
s3_profile: blt
root_dir: ???
sources:
dclm_baseline_1.0: 1.0
batch_size: 2
prefetch_size: 64
# seq_len is in terms of patches and
# max_encoder_seq_length is in terms of bytes.
# For the entropy model these are the same, since 1 patch = 1 byte
seq_len: 8192
max_encoder_seq_length: 8192
load_async: true
preprocess_dir: ???
# We don't need patches for this model
add_patches: false
patcher_args:
# This doesn't matter since the byte entropy model doesn't use patching,
# so pick the most efficient mode, which here is byte
patching_mode: byte
tokenizer_args:
name: bytes
profiling:
run: false
checkpoint:
dump:
every: 500
keep: 3
eval:
every: 1000
keep: -1
logging:
freq: 10
eval_on_gpus: 8
eval:
dataset_dir: ???
tasks: ???
generator:
max_tokens: 65536
dtype: bf16
mp_size: 1
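A quick way to sanity-check that a config like the one above routes training to the entropy model is to load it and inspect the relevant keys. A small sketch; the file path is hypothetical and only PyYAML is assumed.

import yaml  # pip install pyyaml

# Hypothetical path; substitute wherever this config is saved in the repo.
with open("configs/entropy_model.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["train_entropy_model"] is True
assert cfg["model"] is None and cfg["entropy_model"] is not None
# 1 patch = 1 byte for the entropy model, so the two sequence lengths match.
assert cfg["data"]["seq_len"] == cfg["data"]["max_encoder_seq_length"] == 8192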

View file

@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
class BltSequence(BaseModel):
tokens: list[int]
mask: list[bool]
patch_lengths: list[int]
patch_lengths: list[int] | None
@dataclass
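With `patch_lengths` now optional, a byte-level (entropy-model) sequence can be represented without patch information. A minimal self-contained sketch; the class body simply mirrors the three fields shown above so the example runs on its own.

from pydantic import BaseModel

class BltSequence(BaseModel):  # mirror of the fields shown above, for illustration only
    tokens: list[int]
    mask: list[bool]
    patch_lengths: list[int] | None

# Patch-based sequence: token count must equal the sum of the patch lengths.
patched = BltSequence(tokens=[3, 7, 9, 2], mask=[True] * 4, patch_lengths=[2, 2])

# Byte-level sequence for the entropy model: no patches attached.
byte_level = BltSequence(tokens=[3, 7, 9, 2], mask=[True] * 4, patch_lengths=None)
assert byte_level.patch_lengths is None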

View file

@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
max_length: int | None
pad_to_max_length: bool
enable_byte_ngrams: bool
tokenizer_name: str
class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
)
def create_iter(self):
if self.packing_args.tokenizer_name == "bytes":
return self._create_iter_from_bytes()
else:
return self._create_iter_from_patch_lengths()
def _create_iter_from_bytes(self):
sequence_iter = self.sequence_iterator.create_iter()
batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id
seq_len = self.packing_args.seq_len
while True:
tokens: list[list[int]] = []
masks: list[list[bool]] = []
for _ in range(self.packing_args.batch_size):
sequence = next(sequence_iter)
_tokens = sequence.tokens
_mask = sequence.mask
assert (
sequence.patch_lengths is None
), "patch_lengths should not be used in byte packing"
tokens.append(_tokens)
masks.append(_mask)
x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
for i, tok_seq in enumerate(tokens):
x[i, : len(tok_seq)] = tok_seq
y[i, : len(tok_seq) - 1] = tok_seq[1:]
batch = Batch(x=x, y=y)
assert (
batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
yield batch
def _create_iter_from_patch_lengths(self):
sequence_iter = self.sequence_iterator.create_iter()
batch_size = self.packing_args.batch_size
pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
_tokens = sequence.tokens
_mask = sequence.mask
_patch_lengths = sequence.patch_lengths
assert (
_patch_lengths is not None
), "patch lengths are required for packing based on patches."
# Reminder: seq_len is in terms of patches
assert len(sequence.patch_lengths) == self.packing_args.seq_len
last_patch_length = 0
if _patch_lengths[0] > 1:
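The `_create_iter_from_bytes` branch added above builds next-byte targets by shifting each sequence left by one and right-padding with `pad_id`. A tiny standalone example of the same x/y construction, with assumed toy values.

import numpy as np

pad_id, seq_len, batch_size = 0, 6, 1
tokens = [[5, 6, 7, 8]]  # one short byte sequence

x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
for i, tok_seq in enumerate(tokens):
    x[i, : len(tok_seq)] = tok_seq          # inputs, right-padded with pad_id
    y[i, : len(tok_seq) - 1] = tok_seq[1:]  # targets = inputs shifted by one byte

# x -> [[5 6 7 8 0 0]], y -> [[6 7 8 0 0 0]]
print(x, y)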

View file

@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
for example in example_iter:
assert example.tokens is not None
assert example.mask is not None
if self.preprocess_iterator.add_patches:
assert example.patch_lengths is not None
assert len(example.tokens) == sum(example.patch_lengths)
else:
assert example.patch_lengths is None
assert len(example.tokens) != 0
assert len(example.mask) != 0
assert len(example.tokens) == len(example.mask)
assert len(example.tokens) == sum(example.patch_lengths)
tokens.extend(example.tokens)
mask.extend(example.mask)
if self.preprocess_iterator.add_patches:
patch_lengths.extend(example.patch_lengths)
else:
# This lets the rest of the code work as expected and just yield byte seqs
patch_lengths.extend([1] * len(example.tokens))
while len(patch_lengths) >= n_buffer_patches:
if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
== len(seq_mask[idx])
), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
if self.preprocess_iterator.add_patches:
yield BltSequence(
tokens=seq_tokens[idx],
mask=seq_mask[idx],
patch_lengths=seq_patch_lengths[idx],
)
else:
yield BltSequence(
tokens=seq_tokens[idx],
mask=seq_mask[idx],
patch_lengths=None,
)
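When `add_patches` is off, filling the buffer with one-byte patch lengths keeps the patch-based buffering arithmetic valid even though the yielded sequences carry no patches. A small illustration with assumed toy lengths.

tokens = [101, 102, 103, 104, 105]

# Degenerate patching used only for buffer bookkeeping: 1 patch per byte.
patch_lengths = [1] * len(tokens)
assert sum(patch_lengths) == len(tokens)

# A buffer measured in patches (n_buffer_patches) therefore counts bytes one-for-one.
n_buffer_patches = 4
print(len(patch_lengths) >= n_buffer_patches)  # True once enough bytes have arrived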

View file

@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
bpe = "bpe"
bpe_patcher = "bpe_patcher"
space = "space"
static = "static"
byte = "byte"
class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
max_patch_length: int | None = None
patch_size: float = 4.5
patching_batch_size: int = 1
data_loader_patching: bool = False
device: str = "cuda"
monotonicity: bool = False
log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
self.max_patch_length = patcher_args.max_patch_length
self.patch_size = patcher_args.patch_size
self.patching_batch_size = patcher_args.patching_batch_size
self.data_loader_patching = patcher_args.data_loader_patching
self.device = patcher_args.device
self.monotonicity = patcher_args.monotonicity
self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
scores = None
# STATIC
if self.patching_mode is None:
if self.patching_mode == PatchingModeEnum.static:
patch_lengths = torch.zeros(
(bs, math.ceil(seq_len_next_tok / self.patch_size)),
dtype=tokens.dtype,
@@ -536,6 +536,10 @@
).fill_(self.patch_size)
if seq_len_next_tok % self.patch_size != 0:
patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
elif self.patching_mode == PatchingModeEnum.byte:
patch_lengths = torch.ones(
(bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
)
# ENTROPY
elif self.patching_mode == PatchingModeEnum.entropy:
if self.log_time:
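A standalone sketch of the two non-learned modes above, using assumed toy shapes: static fills fixed-size patches (remainder in the last one), while byte emits one patch per position.

import math
import torch

bs, seq_len_next_tok, patch_size = 1, 10, 4
tokens = torch.zeros(bs, seq_len_next_tok, dtype=torch.long)

# PatchingModeEnum.static: ceil(10 / 4) = 3 patches -> [4, 4, 2]
static_lengths = torch.zeros(
    (bs, math.ceil(seq_len_next_tok / patch_size)), dtype=tokens.dtype
).fill_(patch_size)
if seq_len_next_tok % patch_size != 0:
    static_lengths[:, -1] = seq_len_next_tok % patch_size
assert static_lengths.sum().item() == seq_len_next_tok

# PatchingModeEnum.byte: one patch per byte -> ten patches of length 1
byte_lengths = torch.ones((bs, seq_len_next_tok), dtype=tokens.dtype)
assert byte_lengths.sum().item() == seq_len_next_tok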

View file

@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
n_heads: int = 8
# TODO: What is the purpose of this parameter?
weight_tying: bool = False
patch_in_forward: bool = False
# Architecture and dimensions
dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
n_layers_local_encoder: int = 8
# Tokenization and patching
tokenization_mode: str = "bpe"
patch_size: float | None = None
patching_mode: str | None = None
patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
monotonicity: bool = False
patching_batch_size: int = 1
patching_device: str = "cuda"
data_loader_patching: bool = False
max_patch_length: int | None = None
# Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
self.output.weight = self.tok_embeddings.weight
# Patcher module
if not args.data_loader_patching:
if args.patch_in_forward:
self.patcher = Patcher(
PatcherArgs(
patch_size=args.patch_size,
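The removed `data_loader_patching` flag and the new `patch_in_forward` flag are logical opposites: the guard flips from `not args.data_loader_patching` to `args.patch_in_forward`. A trivial sketch of that mapping for anyone migrating old configs; the helper name is made up for illustration.

def patch_in_forward_from_legacy(data_loader_patching: bool) -> bool:
    """Map the removed flag onto the new one (hypothetical migration helper)."""
    return not data_loader_patching

# data_loader_patching=True (old) corresponds to patch_in_forward=False (new).
assert patch_in_forward_from_legacy(True) is False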

View file

@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
# Additional args from command line
dim_token=256,
patch_size=6,
tokenization_mode="bytes",
patching_mode="space",
tie_local_encoder_decoder_logits=False,
data_loader_patching=True,
patch_in_forward=False,
max_encoder_seq_length=12288,
pad_to_max_length=True,
encoder_lm_loss=False,

View file

@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
from bytelatent.profiling import maybe_run_profiler
from bytelatent.stool import StoolArgs, launch_job
from bytelatent.transformer import (
LMTransformer,
build_fsdp_grouping_plan,
get_no_recompute_ops,
get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
def validate_train_args(args: TrainArgs, output_size: int):
if args.model.vocab_size < 0:
assert args.model is not None or args.entropy_model is not None
if args.model is not None:
logger.info(f"Setting model output size to {args.model.vocab_size}")
args.model.vocab_size = output_size
if args.entropy_model is not None:
logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
args.entropy_model.vocab_size = output_size
assert args.dump_dir, "Dump dir not set"
if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
and args.distributed.dp_replicate == get_world_size()
)
if args.model is not None:
args.model.max_seqlen = args.data.seq_len
if args.entropy_model is not None:
args.entropy_model.max_seqlen = args.data.seq_len
if args.distributed.tp_size == 1:
logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
# Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
with torch.device("meta"):
if args.train_entropy_model:
assert args.entropy_model is not None
model = LMTransformer(args.entropy_model)
model_args = args.entropy_model
else:
assert args.model is not None
model = ByteLatentTransformer(args.model)
model_args = args.model
logger.info("Model is built !")
model_param_count = get_num_params(model)
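For context on the meta-device comment above: tensors built on the meta device have shapes and dtypes but no storage, and the weights are materialized later (for example when FSDP wraps and shards the module). A generic PyTorch sketch of that pattern with a toy module, not BLT's actual wrapping code.

import torch
from torch import nn

# Parameters created under the meta device allocate no memory, so arbitrarily
# large models can be "built" before sharding decides where weights live.
with torch.device("meta"):
    toy = nn.Linear(4096, 4096)
assert toy.weight.is_meta

# Materialize empty (uninitialized) storage on a real device, then initialize.
toy = toy.to_empty(device="cpu")
nn.init.normal_(toy.weight, std=0.02)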
@@ -247,7 +263,7 @@
world_mesh,
args.model,
args.distributed,
fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
tp_parallelize=tp_parallelize,
no_recompute_ops=get_no_recompute_ops(),
)
@@ -267,7 +283,7 @@
model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
else:
with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
torch.manual_seed(args.model.seed)
torch.manual_seed(model_args.seed)
model.init_weights()
check_model_value_range(model, range=10.0, std=1.0)
@@ -342,10 +358,17 @@
batch.x,
).cuda()
batch_y = torch.from_numpy(batch.y).cuda()
if batch.patch_lengths is None:
batch_patch_lengths = None
else:
batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
if (
not args.train_entropy_model
and args.model.encoder_enable_byte_ngrams
and batch.ngram_ids is None
):
raise ValueError(
"Cannot enable byte ngrams and have batch.ngram_ids be None"
)
@@ -408,6 +431,9 @@
next(probe_mod.parameters()).grad is None
), "Probe model shouldn't have grads at this point"
if args.train_entropy_model:
pred = model(batch_x)
else:
pred = model(
batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
)
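The entropy model is a plain byte-level LM, so its forward takes only the byte ids, while the BLT forward also needs patch lengths and ngram ids. The loss itself is not shown in this hunk, but given the shift-by-one targets built in the packing iterator it is presumably the usual next-byte cross-entropy; a hedged sketch with assumed toy shapes, ignoring padding and masking details.

import torch
import torch.nn.functional as F

vocab_size, batch_size, seq_len = 260, 2, 8
pred = torch.randn(batch_size, seq_len, vocab_size)            # stand-in for model(batch_x)
batch_y = torch.randint(0, vocab_size, (batch_size, seq_len))  # shifted byte targets

# Standard next-token cross-entropy over the 260-way byte vocabulary.
loss = F.cross_entropy(pred.reshape(-1, vocab_size), batch_y.reshape(-1))
print(loss.item())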
@@ -474,9 +500,9 @@
# Use xformer's analyze profile trace to get actual measurement
FLOPS = (
get_num_flop_per_token(
model_param_count - args.model.vocab_size * args.model.dim,
args.model.n_layers,
args.model.dim,
model_param_count - model_args.vocab_size * model_args.dim,
model_args.n_layers,
model_args.dim,
args.data.seq_len,
)
* wps
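The FLOPS estimate above subtracts the output-embedding parameters from the total count and multiplies a per-token FLOP figure by the measured throughput (wps). The exact formula inside get_num_flop_per_token is not shown here; the sketch below uses a common approximation as an assumption, not a statement about this codebase.

def approx_num_flop_per_token(non_embed_params: int, n_layers: int, dim: int, seq_len: int) -> float:
    # Widely used estimate: 6 FLOPs per non-embedding parameter per token for the
    # dense matmuls, plus an attention term that grows with context length.
    return 6 * non_embed_params + 12 * n_layers * dim * seq_len

# Example with the entropy-model config above (dim=768, n_layers=14, vocab=260, seq_len=8192);
# the parameter count is a placeholder, the real value comes from get_num_params(model).
model_param_count = 100_000_000
print(approx_num_flop_per_token(model_param_count - 260 * 768, 14, 768, 8192))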