mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
134 lines
4.3 KiB
Python
134 lines
4.3 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||
|
|
||
|
import contextlib
|
||
|
import logging
|
||
|
import os
|
||
|
from pathlib import Path
|
||
|
|
||
|
import torch.distributed
|
||
|
import wandb
|
||
|
import xformers.profiler
|
||
|
from pydantic import BaseModel
|
||
|
from torch.profiler.profiler import profile
|
||
|
from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler
|
||
|
|
||
|
from bytelatent.distributed import get_is_master
|
||
|
|
||
|
|
||
|
class ProfilerArgs(BaseModel):
    """User-facing settings for the optional xformers profiling schedule.

    ``maybe_run_profiler`` schedules the memory-snapshot profiler from step
    ``mem_warmup`` to ``mem_warmup + mem_steps`` and the PyTorch profiler from
    step ``profile_warmup`` to ``profile_warmup + profile_steps``.
    """

    run: bool = False  # master switch: nothing is profiled unless True
    trace_folder: str = "profiling"  # joined onto the dump dir to form the trace dir

    # Memory-snapshot window.
    mem_warmup: int = 100
    mem_steps: int = 2

    # PyTorch-profiler window. The default start is the step right after the
    # memory window ends (100 + 2 = 102); presumably chosen so the two
    # profilers do not overlap.
    profile_warmup: int = mem_warmup + mem_steps
    profile_steps: int = 2
|
||
|
|
||
|
|
||
|
# Module logger. Passing name=None deliberately selects the *root* logger,
# matching the rest of the project rather than a per-module child logger.
logger: logging.Logger = logging.getLogger(name=None)
|
||
|
|
||
|
|
||
|
def perfetto_to_html(json_file, html_file):
    """Render a Chrome/Perfetto JSON trace as a standalone HTML viewer page.

    The page is built from the trace-viewer templates shipped inside the
    ``viztracer`` package. That package is imported lazily, so it is only
    required when a trace is actually converted.

    Args:
        json_file: Path (str or ``os.PathLike``) to the trace JSON. A file whose
            name ends in ``.gz`` is gunzipped; any other file is read as
            plain text. Both are decoded as UTF-8.
        html_file: Destination path of the generated HTML page (overwritten).
    """
    import gzip
    import string

    import viztracer

    html_dir = os.path.join(os.path.dirname(viztracer.__file__), "html")

    # Load the templates before opening the trace. If a template is missing,
    # this fails early and no trace file handle is left open.
    with open(
        os.path.join(html_dir, "trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(html_dir, "trace_viewer_full.html"), encoding="utf-8") as f:
        trace_viewer_full = f.read()

    # Only a ".gz" *suffix* means gzip. A ".gz" substring elsewhere in the
    # path (e.g. a directory name) does not. Text mode with an explicit
    # UTF-8 encoding gives the same str result for both file kinds.
    json_path = str(json_file)
    opener = gzip.open if json_path.endswith(".gz") else open
    with opener(json_path, "rt", encoding="utf-8") as j:
        content = j.read()

    sub = {
        "trace_viewer_full": trace_viewer_full,
        # Escape closing script tags so the embedded JSON cannot terminate
        # the <script> element it is inlined into.
        "json_data": content.replace("</script>", "<\\/script>"),
    }
    with open(html_file, "w", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))
|
||
|
|
||
|
|
||
|
class PyTorchProfilerWandb(PyTorchProfiler):
    """xformers ``PyTorchProfiler`` that also uploads its trace to Weights & Biases.

    The wrapped ``torch.profiler.profile`` records memory, shapes and FLOPs,
    with stack capture disabled. After the base class handles a finished
    trace, the master rank converts the newest exported trace to HTML and logs
    it to the active wandb run (if any).
    """

    def __init__(self, main_profiler) -> None:
        # NOTE(review): super().__init__() is not called. This is presumably so
        # the torch profiler can be configured here (with_stack=False); confirm
        # against xformers' PyTorchProfiler before changing.
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # With stack gives huge profile traces
            # and bugs out because of some non ascii
            # character somewhere in pytorch
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        """Delegate to the base-class analysis, bracketed by log lines."""
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        """Run the base trace handler, then log the trace to wandb on the master rank."""
        super()._on_trace(prof)
        if not (get_is_master() and wandb.run is not None):
            return

        output_dir = Path(self.main_profiler.output_dir)
        traces = list(output_dir.glob("profile_CPU_CUDA*/*.pt.trace.json*"))
        if not traces:
            # Don't crash training with an IndexError just because nothing
            # matched; skip the upload instead.
            logger.warning(
                "No PyTorch trace found under %s; skipping wandb upload", output_dir
            )
            return

        # glob() order is arbitrary. Take the most recently written trace,
        # presumably the one super()._on_trace() just exported.
        trace_path = max(traces, key=lambda p: p.stat().st_mtime)

        # Derive the HTML name from the file name only. "x.pt.trace.json.gz"
        # becomes "x.pt.trace.html", because the output is not gzipped and
        # directory names must not be rewritten.
        stem = trace_path.name
        if stem.endswith(".gz"):
            stem = stem[: -len(".gz")]
        if stem.endswith(".json"):
            stem = stem[: -len(".json")]
        html_path = str(trace_path.with_name(stem + ".html"))

        perfetto_to_html(trace_path, html_path)
        wandb.log({"profile_trace": wandb.Html(html_path)})
|
||
|
|
||
|
|
||
|
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    """xformers ``MemSnapshotsProfiler`` that also uploads its HTML memory plot to wandb."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the base profiler, then log the memory plot on the master rank."""
        super().__exit__(exc_type, exc_val, exc_tb)
        if not (get_is_master() and wandb.run is not None):
            return

        output_dir = Path(self.main_profiler.output_dir)
        plots = list(output_dir.glob("memory_trace_plot/*.html"))
        if not plots:
            # Skip the upload instead of raising IndexError out of __exit__.
            logger.warning(
                "No memory trace plot found under %s; skipping wandb upload",
                output_dir,
            )
            return

        # glob() order is arbitrary, so upload the most recently written plot.
        plot_path = max(plots, key=lambda p: p.stat().st_mtime)
        # The file is opened in a `with` block so it is closed after logging
        # (the original handle was never closed).
        with open(plot_path, encoding="utf-8") as plot_file:
            wandb.log({"memory_trace": wandb.Html(plot_file, inject=False)})
|
||
|
|
||
|
|
||
|
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    """Context manager that profiles the enclosed steps when ``config.run`` is set.

    Args:
        dump_dir: Run output directory. Traces go to ``dump_dir/config.trace_folder``.
        module: Model handed to ``xformers.profiler.profile``.
        config: Profiling settings. The memory-snapshot and PyTorch-profiler
            windows are derived from its ``*_warmup`` / ``*_steps`` fields.

    Yields:
        The xformers profiler object when profiling is enabled, otherwise ``None``.
    """
    if not config.run:
        yield None
        return

    trace_dir = os.path.join(dump_dir, config.trace_folder)
    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

    # Only the master rank creates the directory. exist_ok avoids the
    # exists()/makedirs() race of checking first.
    if get_is_master():
        os.makedirs(trace_dir, exist_ok=True)
    if torch.distributed.is_initialized():
        # Wait for the directory creation above before any rank starts profiling.
        torch.distributed.barrier()

    with xformers.profiler.profile(
        output_dir=trace_dir,
        module=module,
        schedule=[
            (
                MemSnapshotsProfilerWandb,
                config.mem_warmup,
                config.mem_warmup + config.mem_steps,
            ),
            (
                PyTorchProfilerWandb,
                config.profile_warmup,
                config.profile_warmup + config.profile_steps,
            ),
        ],
    ) as profiler:
        yield profiler
|