Add way to call consolidate ()

* Add way to call consolidate

* black

* isort

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
Srinivasan Iyer 2025-03-11 16:53:33 -07:00 committed by GitHub
parent a5ceaaa226
commit c110f6be2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -12,6 +12,7 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
import typer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
@ -323,3 +324,20 @@ class CheckpointManager:
dist.barrier()
return cls(args)
def main(
command: str,
model_checkpoint_dir: str,
):
if command == "consolidate":
print(
f"Consolidating {model_checkpoint_dir}. Output will be in the {CONSOLIDATE_FOLDER} folder."
)
consolidate_checkpoints(fsspec.filesystem("file"), model_checkpoint_dir)
else:
raise ValueError("Invalid command")
if __name__ == "__main__":
typer.run(main)