Mirror of https://github.com/facebookresearch/blt.git
Initial codes and scripts for training entropy model (#34)

Summary:
Test Plan:

commit 7622d28b74 (parent a809259e71)
.gitignore (vendored), 1 line added:

@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/

@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
 
+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
 
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
             looping_iterator,
             patcher_args=self.patcher_args,
             tokenizer_args=self.tokenizer_args,
+            add_patches=self.add_patches,
         )
         sequence_iterator = SequenceIterator(
             preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model

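As a side note on the pad_id branch added above: with the raw-byte tokenizer the pad id is hard-coded to 0, otherwise the tokenizer's boe_id is reused. A standalone sketch of that rule (the helper name below is mine, not repo code):

# Hypothetical helper mirroring the pad_id branch in the diff above.
def select_pad_id(tokenizer_name: str, boe_id: int) -> int:
    # Raw-byte tokenization pads with 0; any other tokenizer keeps using boe_id.
    if tokenizer_name == "bytes":
        return 0
    return boe_id

assert select_pad_id("bytes", boe_id=257) == 0
assert select_pad_id("bpe", boe_id=257) == 257
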
@@ -26,10 +26,9 @@ model:
   vocab_size: 260
   dim_token: 256
   patch_size: 6
-  tokenization_mode: "bytes"
   patching_mode: "space"
   tie_local_encoder_decoder_logits: false
-  data_loader_patching: true
+  patch_in_forward: false
   max_encoder_seq_length: 12288
   pad_to_max_length: true
   patching_threshold: 3.1439168453216553

bytelatent/configs/entropy_model.yaml (new file, 82 lines):

@@ -0,0 +1,82 @@
+# Template config, need to change dump_dir, data.root_dir and tokenizer.path
+# Evals can be activated by uncommenting its config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+train_entropy_model: true
+model: null
+entropy_model:
+  dim: 768
+  n_layers: 14
+  n_heads: 12
+  max_seqlen: 8192
+  # vocab_size: -1
+  vocab_size: 260
+  ffn_dim_multiplier: 1.0
+  sliding_window: 512
+  attn_bias_type: "local_block_causal"
+  attn_impl: "xformers"
+
+data:
+  s3_profile: blt
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  # seqlen is in terms of patches and
+  # max_encoder_seq_length is in terms of bytes.
+  # For entropy model, these are the same since 1 patch=1 byte
+  seq_len: 8192
+  max_encoder_seq_length: 8192
+  load_async: true
+  preprocess_dir: ???
+  # We don't need patches for this model
+  add_patches: false
+  patcher_args:
+    # This doesn't matter since byte entropy model doesn't use patching,
+    # so pick the most efficient, so static
+    patching_mode: byte
+  tokenizer_args:
+    name: bytes
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: ???
+  tasks: ???
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+mp_size: 1

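The ??? entries (root_dir, preprocess_dir, dataset_dir, tasks) are mandatory-missing values, so the config cannot run as-is. A minimal sketch of filling them in, assuming the configs are loaded through OmegaConf; every path and value below is a placeholder:

from omegaconf import OmegaConf

cfg = OmegaConf.load("bytelatent/configs/entropy_model.yaml")
cfg.dump_dir = "/checkpoints/entropy_model"     # placeholder output directory
cfg.data.root_dir = "/datasets/blt"             # placeholder dataset root
cfg.data.preprocess_dir = "/datasets/blt/prep"  # placeholder preprocessed data
cfg.eval.dataset_dir = "/datasets/eval"         # placeholder eval data
cfg.eval.tasks = "all"                          # placeholder task selection
print(OmegaConf.to_yaml(cfg))
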
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None
 
 
 @dataclass

@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str
 
 
 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:

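The new _create_iter_from_bytes packs fixed-length byte rows with a plain next-byte shift between x and y. A self-contained numpy sketch of that layout with toy values (pad_id=0, seq_len=8):

import numpy as np

pad_id, seq_len = 0, 8
tok_seq = [101, 102, 103, 104, 105]           # one sequence of byte ids
x = np.full((1, seq_len), fill_value=pad_id)  # model inputs, padded to seq_len
y = np.full((1, seq_len), fill_value=pad_id)  # next-byte targets
x[0, : len(tok_seq)] = tok_seq
y[0, : len(tok_seq) - 1] = tok_seq[1:]
# x -> [[101 102 103 104 105   0   0   0]]
# y -> [[102 103 104 105   0   0   0   0]]
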
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
             for example in example_iter:
                 assert example.tokens is not None
                 assert example.mask is not None
+                if self.preprocess_iterator.add_patches:
                     assert example.patch_lengths is not None
+                    assert len(example.tokens) == sum(example.patch_lengths)
+                else:
+                    assert example.patch_lengths is None
                 assert len(example.tokens) != 0
                 assert len(example.mask) != 0
                 assert len(example.tokens) == len(example.mask)
-                assert len(example.tokens) == sum(example.patch_lengths)
 
                 tokens.extend(example.tokens)
                 mask.extend(example.mask)
+                if self.preprocess_iterator.add_patches:
                     patch_lengths.extend(example.patch_lengths)
+                else:
+                    # This lets the rest of the code work as expected and just yield byte seqs
+                    patch_lengths.extend([1] * len(example.tokens))
 
                 while len(patch_lengths) >= n_buffer_patches:
                     if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                         == len(seq_mask[idx])
                     ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                     assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
+                    if self.preprocess_iterator.add_patches:
                         yield BltSequence(
                             tokens=seq_tokens[idx],
                             mask=seq_mask[idx],
                             patch_lengths=seq_patch_lengths[idx],
                         )
+                    else:
+                        yield BltSequence(
+                            tokens=seq_tokens[idx],
+                            mask=seq_mask[idx],
+                            patch_lengths=None,
+                        )

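When add_patches is false, the iterator still buffers by patch count, so it fakes one patch per byte and only drops the patch lengths when the sequence is yielded. A toy illustration of that bookkeeping (not repo code):

tokens = [7, 8, 9, 10]
patch_lengths = [1] * len(tokens)  # one dummy "patch" per byte, used only for buffer accounting
assert sum(patch_lengths) == len(tokens)
# The BltSequence eventually yielded in this mode carries patch_lengths=None.
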
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"
 
 
 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
     max_patch_length: int | None = None
     patch_size: float = 4.5
     patching_batch_size: int = 1
-    data_loader_patching: bool = False
     device: str = "cuda"
     monotonicity: bool = False
     log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
         self.max_patch_length = patcher_args.max_patch_length
         self.patch_size = patcher_args.patch_size
         self.patching_batch_size = patcher_args.patching_batch_size
-        self.data_loader_patching = patcher_args.data_loader_patching
         self.device = patcher_args.device
         self.monotonicity = patcher_args.monotonicity
         self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode is None:
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:

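In the new byte mode every byte becomes its own patch, so the patcher simply emits a tensor of ones. A minimal torch sketch with toy sizes:

import torch

bs, seq_len_next_tok = 2, 4
tokens = torch.randint(0, 260, (bs, seq_len_next_tok))
patch_lengths = torch.ones(
    (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
)
assert patch_lengths.sum(dim=-1).tolist() == [seq_len_next_tok] * bs
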
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
+    patch_in_forward: bool = False
 
     # Architecture and dimensions
     dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_layers_local_encoder: int = 8
 
     # Tokenization and patching
-    tokenization_mode: str = "bpe"
     patch_size: float | None = None
     patching_mode: str | None = None
     patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     monotonicity: bool = False
     patching_batch_size: int = 1
     patching_device: str = "cuda"
-    data_loader_patching: bool = False
     max_patch_length: int | None = None
 
     # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
-        if not args.data_loader_patching:
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,

|
||||||
# Additional args from command line
|
# Additional args from command line
|
||||||
dim_token=256,
|
dim_token=256,
|
||||||
patch_size=6,
|
patch_size=6,
|
||||||
tokenization_mode="bytes",
|
|
||||||
patching_mode="space",
|
patching_mode="space",
|
||||||
tie_local_encoder_decoder_logits=False,
|
tie_local_encoder_decoder_logits=False,
|
||||||
data_loader_patching=True,
|
patch_in_forward=False,
|
||||||
max_encoder_seq_length=12288,
|
max_encoder_seq_length=12288,
|
||||||
pad_to_max_length=True,
|
pad_to_max_length=True,
|
||||||
encoder_lm_loss=False,
|
encoder_lm_loss=False,
|
||||||
|
|
|
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
 
 
 def validate_train_args(args: TrainArgs, output_size: int):
-    if args.model.vocab_size < 0:
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
 
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size
+
     assert args.dump_dir, "Dump dir not set"
 
     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
         and args.distributed.dp_replicate == get_world_size()
     )
 
+    if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len
 
     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
 
     # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
     with torch.device("meta"):
+        if args.train_entropy_model:
+            assert args.entropy_model is not None
+            model = LMTransformer(args.entropy_model)
+            model_args = args.entropy_model
+        else:
+            assert args.model is not None
             model = ByteLatentTransformer(args.model)
+            model_args = args.model
     logger.info("Model is built !")
 
     model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
         world_mesh,
         args.model,
         args.distributed,
-        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+        fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
         with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-            torch.manual_seed(args.model.seed)
+            torch.manual_seed(model_args.seed)
             model.init_weights()
     check_model_value_range(model, range=10.0, std=1.0)
 
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
                 batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,6 +431,9 @@ def train(args: TrainArgs):
                     next(probe_mod.parameters()).grad is None
                 ), "Probe model shouldn't have grads at this point"
 
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
                 pred = model(
                     batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
                 )
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
             # Use xformer's analyze profile trace to get actual measurement
             FLOPS = (
                 get_num_flop_per_token(
-                    model_param_count - args.model.vocab_size * args.model.dim,
-                    args.model.n_layers,
-                    args.model.dim,
+                    model_param_count - model_args.vocab_size * model_args.dim,
+                    model_args.n_layers,
+                    model_args.dim,
                     args.data.seq_len,
                 )
                 * wps

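Taken together, the train.py changes make model selection hinge on train_entropy_model and route all size and FLOPS bookkeeping through model_args. A condensed, hedged sketch of that control flow (illustrative only, assuming args is a TrainArgs as defined above):

if args.train_entropy_model:
    assert args.entropy_model is not None
    model_args = args.entropy_model
    model = LMTransformer(model_args)        # plain byte-level LM used as the entropy model
else:
    assert args.model is not None
    model_args = args.model
    model = ByteLatentTransformer(model_args)

# Downstream code (FSDP grouping plan, RNG seeding, FLOPS accounting) now reads
# model_args instead of args.model, so it works for either branch.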