Mirror of https://github.com/facebookresearch/blt.git

Initial codes and scripts for training entropy model (#34)

Summary:
Test Plan:

commit 7622d28b74, parent a809259e71
.gitignore (vendored, +1)

@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/
@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False

+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
             looping_iterator,
             patcher_args=self.patcher_args,
             tokenizer_args=self.tokenizer_args,
+            add_patches=self.add_patches,
         )
         sequence_iterator = SequenceIterator(
             preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
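For reference, a minimal standalone sketch of the pad-id selection above (the helper name is illustrative, not from the repo): the byte tokenizer has no dedicated padding token to reuse, so 0 is used, while other tokenizers reuse the beginning-of-example (boe) id.

def select_pad_id(tokenizer_name: str, boe_id: int) -> int:
    # Illustrative helper mirroring the branch above.
    if tokenizer_name == "bytes":
        return 0
    return boe_id

assert select_pad_id("bytes", boe_id=257) == 0
assert select_pad_id("other", boe_id=257) == 257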
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):

     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model
@@ -26,10 +26,9 @@ model:
  vocab_size: 260
  dim_token: 256
  patch_size: 6
  tokenization_mode: "bytes"
  patching_mode: "space"
  tie_local_encoder_decoder_logits: false
  data_loader_patching: true
  patch_in_forward: false
  max_encoder_seq_length: 12288
  pad_to_max_length: true
  patching_threshold: 3.1439168453216553
bytelatent/configs/entropy_model.yaml (new file, +82)

@@ -0,0 +1,82 @@
# Template config, need to change dump_dir, data.root_dir and tokenizer.path
# Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest

dump_dir: /tmp/
name: "debug"
steps: 100_000
probe_freq: null
seed: 777
optim:
  lr: 4e-04
  warmup: 500
  lr_min_ratio: 0.1
  clip: 10.0

distributed:
  fsdp_type: full_shard
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1

train_entropy_model: true
model: null
entropy_model:
  dim: 768
  n_layers: 14
  n_heads: 12
  max_seqlen: 8192
  # vocab_size: -1
  vocab_size: 260
  ffn_dim_multiplier: 1.0
  sliding_window: 512
  attn_bias_type: "local_block_causal"
  attn_impl: "xformers"

data:
  s3_profile: blt
  root_dir: ???
  sources:
    dclm_baseline_1.0: 1.0
  batch_size: 2
  prefetch_size: 64
  # seqlen is in terms of patches and
  # max_encoder_seq_length is in terms of bytes.
  # For entropy model, these are the same since 1 patch = 1 byte
  seq_len: 8192
  max_encoder_seq_length: 8192
  load_async: true
  preprocess_dir: ???
  # We don't need patches for this model
  add_patches: false
  patcher_args:
    # This doesn't matter since byte entropy model doesn't use patching,
    # so pick the most efficient, so static
    patching_mode: byte
  tokenizer_args:
    name: bytes

profiling:
  run: false

checkpoint:
  dump:
    every: 500
    keep: 3
  eval:
    every: 1000
    keep: -1

logging:
  freq: 10

eval_on_gpus: 8
eval:
  dataset_dir: ???
  tasks: ???
  generator:
    max_tokens: 65536
    dtype: bf16

  mp_size: 1
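As a quick sanity check on this template, a minimal sketch (not part of the commit, assumes PyYAML is installed and the file is saved at the path above) that loads the config and asserts the invariants its comments describe:

import yaml

with open("bytelatent/configs/entropy_model.yaml") as f:
    cfg = yaml.safe_load(f)

# Entropy-model training replaces the main BLT model entirely.
assert cfg["train_entropy_model"] is True and cfg["model"] is None

# 1 patch == 1 byte for the entropy model, so the two lengths must match.
assert cfg["data"]["seq_len"] == cfg["data"]["max_encoder_seq_length"]

# Byte-level training needs neither patches nor a BPE tokenizer.
assert cfg["data"]["add_patches"] is False
assert cfg["data"]["tokenizer_args"]["name"] == "bytes"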
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None


 @dataclass
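To illustrate what the now-optional patch_lengths field implies, here is a hedged sketch (class name and validator are hypothetical, assuming pydantic v2): when patch lengths are present they must still account for every token, and when absent the sequence is treated as plain bytes.

from pydantic import BaseModel, model_validator

class SequenceSketch(BaseModel):
    # Mirrors BltSequence above: patch_lengths may be omitted for byte-level data.
    tokens: list[int]
    mask: list[bool]
    patch_lengths: list[int] | None = None

    @model_validator(mode="after")
    def check_lengths(self):
        assert len(self.tokens) == len(self.mask)
        if self.patch_lengths is not None:
            # Every token must be covered by exactly one patch.
            assert sum(self.patch_lengths) == len(self.tokens)
        return self

# Byte-level usage: one token per byte, no patches needed.
SequenceSketch(tokens=[1, 2, 3], mask=[True, True, True], patch_lengths=None)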
@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str


 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )

     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
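A self-contained sketch of the byte-packing step above, with toy sizes and none of the iterator plumbing (function name is illustrative): sequences are written into fixed (batch_size, seq_len) arrays, inputs are padded with pad_id, and next-byte targets are the inputs shifted left by one position.

import numpy as np

def pack_byte_batch(sequences: list[list[int]], seq_len: int, pad_id: int):
    # Toy re-implementation of the packing loop above, without iterator state.
    batch_size = len(sequences)
    x = np.full((batch_size, seq_len), fill_value=pad_id)
    y = np.full((batch_size, seq_len), fill_value=pad_id)
    for i, tok_seq in enumerate(sequences):
        x[i, : len(tok_seq)] = tok_seq          # inputs
        y[i, : len(tok_seq) - 1] = tok_seq[1:]  # next-byte targets
    return x, y

x, y = pack_byte_batch([[5, 6, 7], [8, 9]], seq_len=4, pad_id=0)
assert x.tolist() == [[5, 6, 7, 0], [8, 9, 0, 0]]
assert y.tolist() == [[6, 7, 0, 0], [9, 0, 0, 0]]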
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
-            assert example.patch_lengths is not None
+            if self.preprocess_iterator.add_patches:
+                assert example.patch_lengths is not None
+                assert len(example.tokens) == sum(example.patch_lengths)
+            else:
+                assert example.patch_lengths is None
             assert len(example.tokens) != 0
             assert len(example.mask) != 0
             assert len(example.tokens) == len(example.mask)
-            assert len(example.tokens) == sum(example.patch_lengths)

             tokens.extend(example.tokens)
             mask.extend(example.mask)
-            patch_lengths.extend(example.patch_lengths)
+            if self.preprocess_iterator.add_patches:
+                patch_lengths.extend(example.patch_lengths)
+            else:
+                # This lets the rest of the code work as expected and just yield byte seqs
+                patch_lengths.extend([1] * len(example.tokens))

             while len(patch_lengths) >= n_buffer_patches:
                 if first:
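The key trick above is that when patches are disabled the iterator substitutes one patch per byte, so downstream buffer arithmetic (counted in patches) still lines up with the token count. A minimal sketch of that fallback (helper name is illustrative):

def effective_patch_lengths(tokens: list[int],
                            patch_lengths: list[int] | None,
                            add_patches: bool) -> list[int]:
    # Illustrative helper: with patching disabled, treat every byte as its own patch.
    if add_patches:
        assert patch_lengths is not None
        assert sum(patch_lengths) == len(tokens)
        return patch_lengths
    assert patch_lengths is None
    return [1] * len(tokens)

assert effective_patch_lengths([10, 11, 12], None, add_patches=False) == [1, 1, 1]
assert effective_patch_lengths([10, 11, 12], [2, 1], add_patches=True) == [2, 1]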
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                     == len(seq_mask[idx])
                 ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                 assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
-                yield BltSequence(
-                    tokens=seq_tokens[idx],
-                    mask=seq_mask[idx],
-                    patch_lengths=seq_patch_lengths[idx],
-                )
+                if self.preprocess_iterator.add_patches:
+                    yield BltSequence(
+                        tokens=seq_tokens[idx],
+                        mask=seq_mask[idx],
+                        patch_lengths=seq_patch_lengths[idx],
+                    )
+                else:
+                    yield BltSequence(
+                        tokens=seq_tokens[idx],
+                        mask=seq_mask[idx],
+                        patch_lengths=None,
+                    )
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"


 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
    max_patch_length: int | None = None
    patch_size: float = 4.5
    patching_batch_size: int = 1
    data_loader_patching: bool = False
    device: str = "cuda"
    monotonicity: bool = False
    log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
        self.max_patch_length = patcher_args.max_patch_length
        self.patch_size = patcher_args.patch_size
        self.patching_batch_size = patcher_args.patching_batch_size
        self.data_loader_patching = patcher_args.data_loader_patching
        self.device = patcher_args.device
        self.monotonicity = patcher_args.monotonicity
        self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode is None:
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:
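A small sketch of the two dataloader-free patching modes touched above (assuming the same shapes as Patcher.patch, with illustrative function names): static mode fills fixed-size patches where the last patch absorbs the remainder, while byte mode emits one length-1 patch per position, which is what the entropy model uses.

import math
import torch

def static_patch_lengths(bs: int, seq_len: int, patch_size: int) -> torch.Tensor:
    # Fixed-size patches; the last patch absorbs the remainder.
    lengths = torch.full((bs, math.ceil(seq_len / patch_size)), patch_size)
    if seq_len % patch_size != 0:
        lengths[:, -1] = seq_len % patch_size
    return lengths

def byte_patch_lengths(bs: int, seq_len: int) -> torch.Tensor:
    # One patch per byte, as used for the entropy model.
    return torch.ones((bs, seq_len), dtype=torch.long)

assert static_patch_lengths(1, 10, patch_size=4).tolist() == [[4, 4, 2]]
assert byte_patch_lengths(1, 3).tolist() == [[1, 1, 1]]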
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
    n_heads: int = 8
    # TODO: What is the purpose of this parameter?
    weight_tying: bool = False
    patch_in_forward: bool = False

    # Architecture and dimensions
    dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
    n_layers_local_encoder: int = 8

    # Tokenization and patching
    tokenization_mode: str = "bpe"
    patch_size: float | None = None
    patching_mode: str | None = None
    patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
    monotonicity: bool = False
    patching_batch_size: int = 1
    patching_device: str = "cuda"
    data_loader_patching: bool = False
    max_patch_length: int | None = None

    # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight

         # Patcher module
-        if not args.data_loader_patching:
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,
@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
        # Additional args from command line
        dim_token=256,
        patch_size=6,
        tokenization_mode="bytes",
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        data_loader_patching=True,
        patch_in_forward=False,
        max_encoder_seq_length=12288,
        pad_to_max_length=True,
        encoder_lm_loss=False,
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):


 def validate_train_args(args: TrainArgs, output_size: int):
-    if args.model.vocab_size < 0:
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
+
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size

     assert args.dump_dir, "Dump dir not set"

     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
         and args.distributed.dp_replicate == get_world_size()
     )

-    args.model.max_seqlen = args.data.seq_len
+    if args.model is not None:
+        args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len

     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):

     # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
     with torch.device("meta"):
-        model = ByteLatentTransformer(args.model)
+        if args.train_entropy_model:
+            assert args.entropy_model is not None
+            model = LMTransformer(args.entropy_model)
+            model_args = args.entropy_model
+        else:
+            assert args.model is not None
+            model = ByteLatentTransformer(args.model)
+            model_args = args.model
     logger.info("Model is built !")

     model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
         world_mesh,
         args.model,
         args.distributed,
-        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+        fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
         with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-            torch.manual_seed(args.model.seed)
+            torch.manual_seed(model_args.seed)
             model.init_weights()
         check_model_value_range(model, range=10.0, std=1.0)
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
-            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
+                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,9 +431,12 @@ def train(args: TrainArgs):
                     next(probe_mod.parameters()).grad is None
                 ), "Probe model shouldn't have grads at this point"

-            pred = model(
-                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
-            )
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
+                pred = model(
+                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+                )

             loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
                 # Use xformer's analyze profile trace to get actual measurement
                 FLOPS = (
                     get_num_flop_per_token(
-                        model_param_count - args.model.vocab_size * args.model.dim,
-                        args.model.n_layers,
-                        args.model.dim,
+                        model_param_count - model_args.vocab_size * model_args.dim,
+                        model_args.n_layers,
+                        model_args.dim,
                         args.data.seq_len,
                     )
                     * wps
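get_num_flop_per_token is defined in bytelatent.transformer and is not shown in this diff; as a rough, commonly used approximation (an assumption, not the repo's exact formula), per-token training FLOPs are about 6 times the non-embedding parameter count plus an attention term that grows with sequence length. A hedged sketch with illustrative numbers:

def approx_flop_per_token(non_embed_params: int, n_layers: int, dim: int, seq_len: int) -> float:
    # Rough standard estimate (NOT the repo's exact get_num_flop_per_token):
    # 6 FLOPs per parameter per token for dense layers, plus the attention term.
    return 6 * non_embed_params + 12 * n_layers * dim * seq_len

# Example with the entropy-model dimensions from the config above; the
# parameter count here is a placeholder, not a measured value.
flops_per_token = approx_flop_per_token(
    non_embed_params=100_000_000, n_layers=14, dim=768, seq_len=8192
)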