Init distributed when loading model

Srini Iyer 2025-04-08 18:18:40 +00:00
parent 19a3f7588d
commit a90d950d70
2 changed files with 8 additions and 2 deletions

@@ -301,7 +301,7 @@ def setup_torch_distributed(dist_args: DistributedArgs):
         - global_rank
         - world_size
     """
-    mp.set_start_method(dist_args.spawn_method)
+    mp.set_start_method(dist_args.spawn_method, force=True)
     with mp.Manager():
         pass
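
The force=True argument matters because set_start_method raises a RuntimeError once a start method has already been chosen in the current process, which can now happen when a caller initializes distributed while loading a model. A minimal sketch of that behavior, assuming mp is Python's standard multiprocessing module (torch.multiprocessing exposes the same interface), not code from this repository:

# Sketch: a second call to set_start_method without force raises RuntimeError;
# passing force=True overrides the previously chosen start method.
import multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("spawn")
    try:
        mp.set_start_method("spawn")  # already set -> RuntimeError
    except RuntimeError:
        mp.set_start_method("spawn", force=True)  # succeeds, overrides the method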

@@ -25,7 +25,7 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import get_global_rank
+from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -390,7 +390,13 @@ class PackedCausalTransformerGenerator:
 def load_consolidated_model_and_tokenizer(
     consolidated_path,
+    init_distributed=False
 ):
+    if init_distributed:
+        distributed_args = DistributedArgs()
+        distributed_args.configure_world()
+        if not torch.distributed.is_initialized():
+            setup_torch_distributed(distributed_args)
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
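
For illustration, a caller could opt into the new flag as below. This is a hypothetical sketch: the module path, checkpoint directory, and returned values are assumptions based on the surrounding code, not part of the commit.

# Hypothetical usage of the new init_distributed flag; everything outside the
# diff (import path, checkpoint directory, returned tuple) is an assumption.
from bytelatent.generate import load_consolidated_model_and_tokenizer

model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
    "checkpoints/consolidated",  # directory expected to contain params.json
    init_distributed=True,       # sets up torch.distributed if not already initialized
)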