mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-01 10:09:06 +00:00
Add way to call consolidate (#80)
* Add way to call consolidate
* black
* isort
---------
Co-authored-by: Srini Iyer <sviyer@meta.com>
This commit is contained in:
parent
a5ceaaa226
commit
c110f6be2a
1 changed file with 18 additions and 0 deletions
|
@ -12,6 +12,7 @@ import torch.distributed as dist
|
||||||
import torch.distributed.checkpoint as dcp
|
import torch.distributed.checkpoint as dcp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim.optimizer
|
import torch.optim.optimizer
|
||||||
|
import typer
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from torch.distributed._tensor import DeviceMesh
|
from torch.distributed._tensor import DeviceMesh
|
||||||
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
||||||
|
@ -323,3 +324,20 @@ class CheckpointManager:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
return cls(args)
|
return cls(args)
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    command: str,
    model_checkpoint_dir: str,
) -> None:
    """Run a checkpoint maintenance command from the command line.

    Args:
        command: Name of the action to perform. Only ``"consolidate"`` is
            recognized.
        model_checkpoint_dir: Directory of the checkpoint to operate on; it
            is passed unchanged to ``consolidate_checkpoints``.

    Raises:
        ValueError: If ``command`` is not a recognized command.
    """
    # Reject unknown commands up front and say what was received, so a typo
    # on the command line is easy to spot.
    if command != "consolidate":
        raise ValueError(f"Invalid command {command!r}; expected 'consolidate'")

    print(
        f"Consolidating {model_checkpoint_dir}. Output will be in the {CONSOLIDATE_FOLDER} folder."
    )
    # Uses the fsspec "file" filesystem implementation for the consolidation.
    consolidate_checkpoints(fsspec.filesystem("file"), model_checkpoint_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Only when executed as a script: hand `main` to typer, which runs it
    # as the command-line entry point.
    typer.run(main)
|
||||||
|
|
Loading…
Add table
Reference in a new issue