From 138c2f3494ffcced2e096d90827eccdd5c692c98 Mon Sep 17 00:00:00 2001
From: Srinivasan Iyer
Date: Tue, 8 Apr 2025 13:57:28 -0700
Subject: [PATCH] Init distributed when loading model (#94)

Co-authored-by: Srini Iyer
---
 bytelatent/distributed.py | 2 +-
 bytelatent/generate.py    | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/bytelatent/distributed.py b/bytelatent/distributed.py
index d6dc5a5..5e76a55 100644
--- a/bytelatent/distributed.py
+++ b/bytelatent/distributed.py
@@ -301,7 +301,7 @@ def setup_torch_distributed(dist_args: DistributedArgs):
         - global_rank
         - world_size
     """
-    mp.set_start_method(dist_args.spawn_method)
+    mp.set_start_method(dist_args.spawn_method, force=True)
     with mp.Manager():
         pass
 
diff --git a/bytelatent/generate.py b/bytelatent/generate.py
index c76360e..b0280dd 100644
--- a/bytelatent/generate.py
+++ b/bytelatent/generate.py
@@ -25,7 +25,7 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import get_global_rank
+from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -390,7 +390,13 @@ class PackedCausalTransformerGenerator:
 
 def load_consolidated_model_and_tokenizer(
     consolidated_path,
+    init_distributed=False
 ):
+    if init_distributed:
+        distributed_args = DistributedArgs()
+        distributed_args.configure_world()
+        if not torch.distributed.is_initialized():
+            setup_torch_distributed(distributed_args)
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
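
A minimal usage sketch of the new flag follows; the checkpoint path and the unpacked return values are assumptions for illustration, not part of the patch:

    from bytelatent.generate import load_consolidated_model_and_tokenizer

    # init_distributed=True builds a default DistributedArgs, calls
    # configure_world(), and runs setup_torch_distributed() before loading,
    # so no prior torch.distributed initialization is required by the caller.
    model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
        "path/to/consolidated",  # hypothetical consolidated checkpoint directory
        init_distributed=True,
    )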