mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
109 lines
3 KiB
Python
109 lines
3 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import altair as alt
|
|
import pandas as pd
|
|
import pydantic
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
class ScalingPlotsConfig(pydantic.BaseModel):
    """Settings for the scaling-plot script, validated by pydantic.

    ``main`` builds this from a config file merged with command-line overrides.
    """

    # Pydantic v2 style configuration; replaces the deprecated inner
    # ``class Config``. ``extra="forbid"`` rejects unknown keys so typos in the
    # config fail loudly instead of being silently ignored.
    model_config = pydantic.ConfigDict(extra="forbid")

    # Directory holding the input data-frame files.
    df_dir: str
    # Directory the rendered charts are written to.
    output_chart_dir: str
    # File names, relative to ``df_dir``; each one is rendered to its own chart.
    frame_files: list[str]
|
|
|
|
|
|
def determine_family(key: str):
    """Return the model-family name for a model label.

    The family is the first known prefix the label starts with, checked in
    the order Megabyte++, BLT, LLaMA, Space. A label matching none of them
    yields ``None``.
    """
    known_prefixes = ("Megabyte++", "BLT", "LLaMA", "Space")
    for prefix in known_prefixes:
        if key.startswith(prefix):
            # The family name is exactly the matched prefix.
            return prefix
    return None
|
|
|
|
|
|
# NOTE(review): this module-level dict is not read or written anywhere else in
# this script; it looks like a leftover. Confirm no other module imports it
# before removing.
file_to_vars = {}
|
|
|
|
|
|
def create_chart(df: pd.DataFrame, output_file: "str | Path"):
    """Render a bits-per-byte vs. training-FLOPs scaling chart and save it.

    Each model (``key`` column) is drawn as a line plus point markers. Colour
    and marker shape are fixed per model label. Line dash style is chosen by
    model family (see ``determine_family``). The x axis is log-scaled.

    Args:
        df: Scaling results. Must contain the columns ``key`` (model label),
            ``flops`` (x axis) and ``bpb/not_heldout.jsonl`` (y axis). The
            caller's frame is not modified.
        output_file: Destination passed to ``chart.save``; ``main`` passes a
            ``Path`` ending in ``.pdf``.
    """
    # Work on a new frame instead of adding columns to the caller's DataFrame.
    df = df.assign(
        metric=df["bpb/not_heldout.jsonl"],
        family=df["key"].map(determine_family),
    )

    # One row per model: (label, colour, marker shape). Keeping them together
    # guarantees the domain and both ranges stay index-aligned.
    model_styles = [
        ("BLT Space ps=6", "#1f77b4", "circle"),
        ("BLT Space w/o cross-attn", "#1f77b4", "square"),
        ("SpaceByte", "#1f77b4", "cross"),
        ("LLaMA 3 BPE", "#ff7f0e", "diamond"),
        ("Megabyte++ ps=4", "#2ca02c", "triangle-up"),
        ("Megabyte++ ps=6", "#2ca02c", "triangle-down"),
    ]
    model_domain = [name for name, _, _ in model_styles]
    color_range = [color for _, color, _ in model_styles]
    shape_range = [shape for _, _, shape in model_styles]

    color_scale = alt.Scale(domain=model_domain, range=color_range)
    shape_scale = alt.Scale(domain=model_domain, range=shape_range)

    base_chart = alt.Chart(df).encode(
        x=alt.X("flops", title="Training FLOPS")
        .scale(type="log", domain=[2e20, 1.25e22])
        .axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
        y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
    )

    # Lines carry no legend; the point layer's legend describes each model.
    lines = base_chart.encode(
        color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
        strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
    ).mark_line()

    points = base_chart.encode(
        color=alt.Color("key", title="Model", scale=color_scale),
        shape=alt.Shape("key", title="", scale=shape_scale),
    ).mark_point(size=70)

    chart = (
        (lines + points)
        .resolve_scale(
            color="independent",
            shape="independent",
        )
        .configure_legend(orient="right")
        .properties(height=300, width=400)
    )
    print("Saving", output_file)
    chart.save(output_file)
|
|
|
|
|
|
def main():
    """Render one scaling chart per configured data-frame file.

    Usage: ``<script> CONFIG_PATH [key=value ...]``. The config file at
    CONFIG_PATH is merged with the remaining command-line overrides,
    interpolations are resolved, and the result is validated as a
    ``ScalingPlotsConfig``. Each entry of ``frame_files`` is read from
    ``df_dir`` with ``pd.read_json`` and charted to
    ``output_chart_dir/<file name>.pdf``. The output directory is created if
    needed.
    """
    if len(sys.argv) < 2:
        # Exit with a usage message rather than an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH [key=value ...]")
    config_path = sys.argv[1]
    file_config = OmegaConf.load(config_path)
    # Omit program name and config file name; the rest are CLI overrides.
    cli_conf = OmegaConf.from_cli(sys.argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    plot_config = ScalingPlotsConfig(**conf_dict)

    df_dir = Path(plot_config.df_dir)
    chart_dir = Path(plot_config.output_chart_dir)
    chart_dir.mkdir(exist_ok=True, parents=True)

    for ff in plot_config.frame_files:
        path = df_dir / ff
        df = pd.read_json(path)
        print(df)
        print(df.columns)
        # The chart name keeps the input's full file name, e.g. "x.json" -> "x.json.pdf".
        create_chart(df, chart_dir / f"{path.name}.pdf")
|
|
|
|
|
|
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|