mirror of
https://github.com/facebookresearch/blt.git
synced 2025-09-10 06:14:35 +00:00
Initial commit
This commit is contained in:
commit
bcc039bb75
86 changed files with 12203 additions and 0 deletions
108
bytelatent/plotting/scaling_figures.py
Normal file
108
bytelatent/plotting/scaling_figures.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import altair as alt
|
||||
import pandas as pd
|
||||
import pydantic
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
class ScalingPlotsConfig(pydantic.BaseModel):
|
||||
df_dir: str
|
||||
output_chart_dir: str
|
||||
frame_files: list[str]
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
def determine_family(key: str):
|
||||
if key.startswith("Megabyte++"):
|
||||
return "Megabyte++"
|
||||
elif key.startswith("BLT"):
|
||||
return "BLT"
|
||||
elif key.startswith("LLaMA"):
|
||||
return "LLaMA"
|
||||
elif key.startswith("Space"):
|
||||
return "Space"
|
||||
|
||||
|
||||
file_to_vars = {}
|
||||
|
||||
|
||||
def create_chart(df: pd.DataFrame, output_file: str):
|
||||
df["metric"] = df["bpb/not_heldout.jsonl"]
|
||||
df["family"] = df["key"].map(determine_family)
|
||||
model_domain = [
|
||||
"BLT Space ps=6",
|
||||
"BLT Space w/o cross-attn",
|
||||
"SpaceByte",
|
||||
"LLaMA 3 BPE",
|
||||
"Megabyte++ ps=4",
|
||||
"Megabyte++ ps=6",
|
||||
]
|
||||
color_range = ["#1f77b4", "#1f77b4", "#1f77b4", "#ff7f0e", "#2ca02c", "#2ca02c"]
|
||||
shape_range = [
|
||||
"circle",
|
||||
"square",
|
||||
"cross",
|
||||
"diamond",
|
||||
"triangle-up",
|
||||
"triangle-down",
|
||||
]
|
||||
color_scale = alt.Scale(domain=model_domain, range=color_range)
|
||||
shape_scale = alt.Scale(
|
||||
domain=model_domain,
|
||||
range=shape_range,
|
||||
)
|
||||
base_chart = alt.Chart(df).encode(
|
||||
x=alt.X("flops", title="Training FLOPS")
|
||||
.scale(type="log", domain=[2e20, 1.25e22])
|
||||
.axis(values=[2e20, 4e20, 8e20, 1e21, 2e21, 4e21, 8e21, 1e22]),
|
||||
y=alt.Y("metric", title="Bits per Byte (BPB)").scale(zero=False),
|
||||
)
|
||||
lines = base_chart.encode(
|
||||
color=alt.Color("key", title="Model Color", scale=color_scale, legend=None),
|
||||
strokeDash=alt.StrokeDash("family", title="Model Family", legend=None),
|
||||
).mark_line()
|
||||
points = base_chart.encode(
|
||||
color=alt.Color("key", title="Model", scale=color_scale),
|
||||
shape=alt.Shape("key", title="", scale=shape_scale),
|
||||
).mark_point(size=70)
|
||||
chart = (
|
||||
(lines + points)
|
||||
.resolve_scale(
|
||||
color="independent",
|
||||
shape="independent",
|
||||
# strokeDash="independent",
|
||||
)
|
||||
.configure_legend(orient="right")
|
||||
.properties(height=300, width=400)
|
||||
)
|
||||
print("Saving", output_file)
|
||||
chart.save(output_file)
|
||||
|
||||
|
||||
def main():
|
||||
config_path = sys.argv[1]
|
||||
file_config = OmegaConf.load(config_path)
|
||||
# Omit program name and config file name
|
||||
cli_conf = OmegaConf.from_cli(sys.argv[2:])
|
||||
conf_dict = OmegaConf.to_container(
|
||||
OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
|
||||
)
|
||||
plot_config = ScalingPlotsConfig(**conf_dict)
|
||||
df_dir = Path(plot_config.df_dir)
|
||||
chart_dir = Path(plot_config.output_chart_dir)
|
||||
chart_dir.mkdir(exist_ok=True, parents=True)
|
||||
for ff in plot_config.frame_files:
|
||||
path = df_dir / ff
|
||||
df = pd.read_json(path)
|
||||
print(df)
|
||||
print(df.columns)
|
||||
create_chart(df, chart_dir / f"{path.name}.pdf")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Add table
Add a link
Reference in a new issue