blt/bytelatent/plotting/scaling_figures.py
2024-12-12 15:32:30 -08:00

109 lines
3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import sys
from pathlib import Path
import altair as alt
import pandas as pd
import pydantic
from omegaconf import OmegaConf
class ScalingPlotsConfig(pydantic.BaseModel):
    """Validated settings for the scaling-figure script.

    Instances are built in ``main`` from a config file merged with
    command-line overrides.
    """

    # Reject any key that is not declared below, so a typo in the config
    # fails validation instead of being silently ignored.
    class Config:
        extra = "forbid"

    # Directory holding the results frames listed in ``frame_files``.
    df_dir: str
    # Directory the rendered charts are written to (``main`` creates it).
    output_chart_dir: str
    # File names, relative to ``df_dir``, of the JSON frames to plot.
    frame_files: list[str]
def determine_family(key: str):
    """Map a model display name to its family label.

    Prefixes are tried in a fixed order and the first one that ``key``
    starts with is returned as the family. For example, "BLT Space ps=6"
    is a "BLT" model, while "SpaceByte" falls under "Space". Returns
    ``None`` when no known prefix matches.
    """
    for prefix in ("Megabyte++", "BLT", "LLaMA", "Space"):
        if key.startswith(prefix):
            return prefix
    return None
# NOTE(review): not referenced anywhere else in this file; possibly a leftover.
# Kept in case another module imports it — confirm before removing.
file_to_vars = {}
def create_chart(df: pd.DataFrame, output_file: "str | Path"):
    """Plot training FLOPs against bits-per-byte for each model and save it.

    Every row of ``df`` becomes a point at (``flops``, BPB). Rows sharing a
    ``key`` are joined into a line. Colour and marker shape are assigned per
    model from the fixed legend below, and the line dash pattern is chosen
    per model family (see ``determine_family``).

    Args:
        df: Results frame with at least the columns ``key`` (model display
            name), ``flops`` and ``bpb/not_heldout.jsonl``. It is not
            modified; the derived plotting columns are added to a copy.
        output_file: Destination of the rendered chart. Altair's
            ``Chart.save`` picks the output format from the file extension.
    """
    # Derive the plotted columns on a copy so the caller's frame is not
    # mutated as a side effect of plotting.
    df = df.assign(
        metric=df["bpb/not_heldout.jsonl"],
        family=df["key"].map(determine_family),
    )

    # Legend order, colours and marker shapes are aligned index-by-index:
    # the BLT and SpaceByte variants share blue, LLaMA is orange, the
    # Megabyte++ variants share green, and every model gets its own shape.
    model_domain = [
        "BLT Space ps=6",
        "BLT Space w/o cross-attn",
        "SpaceByte",
        "LLaMA 3 BPE",
        "Megabyte++ ps=4",
        "Megabyte++ ps=6",
    ]
    color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
    shape_range = [
        "circle",
        "square",
        "cross",
        "diamond",
        "triangle-up",
        "triangle-down",
    ]
    color_scale = alt.Scale(domain=model_domain, range=color_range)
    shape_scale = alt.Scale(domain=model_domain, range=shape_range)

    # Shared axes: log-scaled FLOPs with a fixed domain and explicit ticks,
    # and a BPB axis that is not forced to start at zero.
    base_chart = alt.Chart(df).encode(
        x=alt.X("flops", title="Training FLOPS")
        .scale(type="log", domain=[2e20, 1.25e22])
        .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
        y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
    )
    # Line layer: its legends are hidden; the visible legend comes from the
    # point layer below.
    lines = base_chart.encode(
        color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
        strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
    ).mark_line()
    points = base_chart.encode(
        color=alt.Color("key", title="Model", scale=color_scale),
        shape=alt.Shape("key", title="", scale=shape_scale),
    ).mark_point(size=70)
    chart = (
        (lines + points)
        .resolve_scale(
            color="independent",
            shape="independent",
        )
        .configure_legend(orient="right")
        .properties(height=300, width=400)
    )
    print("Saving", output_file)
    chart.save(output_file)
def main():
    """Render one scaling chart per configured results frame.

    Usage: ``scaling_figures.py CONFIG_PATH [key=value ...]``.

    The config file is loaded with OmegaConf, any ``key=value`` overrides
    from the remaining arguments are merged on top, and the result is
    validated as a ``ScalingPlotsConfig``. Each ``frame_files`` entry is
    read as JSON from ``df_dir`` and plotted to
    ``<output_chart_dir>/<file name>.pdf``. The source file's own extension
    is kept, e.g. ``runs.json`` -> ``runs.json.pdf``.
    """
    if len(sys.argv) < 2:
        # Exit with a usage message rather than an IndexError traceback.
        sys.exit("usage: scaling_figures.py CONFIG_PATH [key=value ...]")
    config_path = sys.argv[1]
    file_config = OmegaConf.load(config_path)
    # Omit program name and config file name; the rest are CLI overrides.
    cli_conf = OmegaConf.from_cli(sys.argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    plot_config = ScalingPlotsConfig(**conf_dict)
    df_dir = Path(plot_config.df_dir)
    chart_dir = Path(plot_config.output_chart_dir)
    chart_dir.mkdir(exist_ok=True, parents=True)
    for ff in plot_config.frame_files:
        path = df_dir / ff
        df = pd.read_json(path)
        print(df)
        print(df.columns)
        create_chart(df, chart_dir / f"{path.name}.pdf")


if __name__ == "__main__":
    main()