Mirror of https://github.com/facebookresearch/blt.git (synced 2025-09-01 01:59:05 +00:00)
Init distributed when loading model (#94)
Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent 19a3f7588d
commit 138c2f3494
2 changed files with 8 additions and 2 deletions
|
@@ -301,7 +301,7 @@ def setup_torch_distributed(dist_args: DistributedArgs):
|
||||||
- global_rank
|
- global_rank
|
||||||
- world_size
|
- world_size
|
||||||
"""
|
"""
|
||||||
mp.set_start_method(dist_args.spawn_method)
|
mp.set_start_method(dist_args.spawn_method, force=True)
|
||||||
with mp.Manager():
|
with mp.Manager():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@@ -25,7 +25,7 @@ from bytelatent.checkpoint import (
|
||||||
)
|
)
|
||||||
from bytelatent.config_parser import parse_args_to_pydantic_model
|
from bytelatent.config_parser import parse_args_to_pydantic_model
|
||||||
from bytelatent.data.file_util import get_fs
|
from bytelatent.data.file_util import get_fs
|
||||||
from bytelatent.distributed import get_global_rank
|
from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
|
||||||
from bytelatent.model.blt import ByteLatentTransformer
|
from bytelatent.model.blt import ByteLatentTransformer
|
||||||
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
|
||||||
from bytelatent.transformer import LMTransformer
|
from bytelatent.transformer import LMTransformer
|
||||||
|
@@ -390,7 +390,13 @@ class PackedCausalTransformerGenerator:
|
||||||
|
|
||||||
def load_consolidated_model_and_tokenizer(
|
def load_consolidated_model_and_tokenizer(
|
||||||
consolidated_path,
|
consolidated_path,
|
||||||
|
init_distributed=False
|
||||||
):
|
):
|
||||||
|
if init_distributed:
|
||||||
|
distributed_args = DistributedArgs()
|
||||||
|
distributed_args.configure_world()
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
setup_torch_distributed(distributed_args)
|
||||||
train_args_path = os.path.join(consolidated_path, "params.json")
|
train_args_path = os.path.join(consolidated_path, "params.json")
|
||||||
fs = get_fs(train_args_path)
|
fs = get_fs(train_args_path)
|
||||||
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
|
train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
|
||||||
|
|
Loading…
Add table
Reference in a new issue