Add way to call consolidate (#80)
Some checks failed
Lint with Black / lint (push) Has been cancelled
Lint with isort / lint (push) Has been cancelled

* Add way to call consolidate

* black

* isort

---------

Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
Srinivasan Iyer 2025-03-11 16:53:33 -07:00 committed by GitHub
parent a5ceaaa226
commit c110f6be2a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -12,6 +12,7 @@ import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim.optimizer
import typer
from pydantic import BaseModel, ConfigDict
from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
@@ -323,3 +324,20 @@ class CheckpointManager:
        dist.barrier()
        return cls(args)
def main(
    command: str,
    model_checkpoint_dir: str,
):
    """CLI entry point for checkpoint maintenance commands.

    Args:
        command: Name of the subcommand to run; only "consolidate" is
            supported.
        model_checkpoint_dir: Path to the checkpoint directory to operate on.

    Raises:
        ValueError: If *command* is not a recognized subcommand.
    """
    # Guard clause: reject anything but the single supported command up front.
    if command != "consolidate":
        raise ValueError("Invalid command")
    print(
        f"Consolidating {model_checkpoint_dir}. Output will be in the {CONSOLIDATE_FOLDER} folder."
    )
    # Operate on the local filesystem; consolidation writes its output under
    # CONSOLIDATE_FOLDER inside the checkpoint directory (per the message above).
    consolidate_checkpoints(fsspec.filesystem("file"), model_checkpoint_dir)
# Script entry point: typer builds a CLI from main's signature
# (positional args: command, model_checkpoint_dir) and dispatches to it.
if __name__ == "__main__":
    typer.run(main)