blt/bytelatent/plotting/entropy_figure.py
Pedro Rodriguez d4ddb95322
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run
Add plotting code from paper (#17)
Summary:

Test Plan:
2025-01-09 12:11:50 -08:00

101 lines
2.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
import io
import json
import os
import sys
from pathlib import Path

import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel
class PlotEntropiesConfig(BaseModel):
    """Run-time settings for the entropy plot, built from a config file plus CLI overrides."""

    # Pydantic v2 configuration (the class-based ``Config`` is deprecated in v2).
    # Unknown keys are rejected so a typo in the config file or on the command
    # line fails validation instead of being silently ignored.
    model_config = {"extra": "forbid"}

    # JSON file that ``main`` reads and validates as ``PlotEntropiesData``.
    # The key is required, but its value may be null.
    data_path: str | None
    # Destination file the rendered chart is saved to.
    chart_path: str
    # Optional JSON file whose "score" list replaces the "entropies" column.
    score_override_path: str | None = None
    # Optional threshold used instead of the one stored with the plot data.
    threshold_override: float | None = None
class PlotEntropiesData(BaseModel):
    """Serialized input for the entropy plot, as stored in the data JSON file."""

    # Pydantic v2 configuration (the class-based ``Config`` is deprecated in v2).
    # Unknown keys in the data file are rejected rather than ignored.
    model_config = {"extra": "forbid"}

    # Text associated with the plotted entropies (not read by the plotting code
    # in this file).
    text: str
    # Entropy threshold drawn as the horizontal rule and used to mark patch
    # starts. NOTE(review): the default looks like a value tuned for a specific
    # model/figure -- confirm its origin.
    threshold: float = 1.335442066192627
    # JSON-encoded dataframe; ``main`` reads the "position", "tokens" and
    # "entropies" columns from it.
    dataframe_json: str | None
def _load_plot_config(argv):
    """Build a ``PlotEntropiesConfig`` from ``argv[1]`` (config file) and ``argv[2:]`` (overrides)."""
    if len(argv) < 2:
        prog = argv[0] if argv else "entropy_figure"
        raise SystemExit(f"usage: {prog} CONFIG_PATH [key=value ...]")
    file_config = OmegaConf.load(argv[1])
    # Omit program name and config file name; the rest are CLI overrides.
    cli_conf = OmegaConf.from_cli(argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    return PlotEntropiesConfig(**conf_dict)


def _load_entropy_frame(plot_config):
    """Load the entropy dataframe and resolve the threshold; returns ``(df, threshold)``."""
    if plot_config.data_path is None:
        raise ValueError("data_path must be set to plot entropies")
    with open(plot_config.data_path, encoding="utf-8") as f:
        plot_data = PlotEntropiesData.model_validate_json(f.read())
    if plot_data.dataframe_json is None:
        raise ValueError(f"{plot_config.data_path} does not contain dataframe_json")

    # pandas deprecated passing a literal JSON string to read_json; wrap it in
    # a file-like object instead.
    df = pd.read_json(io.StringIO(plot_data.dataframe_json))
    print("LEN", len(df))

    if plot_config.threshold_override is None:
        threshold = plot_data.threshold
    else:
        threshold = plot_config.threshold_override

    if plot_config.score_override_path is not None:
        with open(plot_config.score_override_path, encoding="utf-8") as f:
            scores = json.load(f)["score"]
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if len(scores) != len(df):
            raise ValueError(
                f"score override has {len(scores)} entries but the dataframe has "
                f"{len(df)} rows"
            )
        df["entropies"] = scores
    return df, threshold


def _build_chart(df, threshold):
    """Add the derived columns to ``df`` (in place) and return the layered Altair chart."""
    # A row is flagged as a patch start when the *previous* row's entropy
    # exceeded the threshold; the first row is always flagged. The final slice
    # only matters for an empty frame, where the leading 1 must be dropped.
    above = (df["entropies"] > threshold).tolist()
    df["start"] = ([1] + above[:-1])[: len(df)]

    # Labels are "<zero-padded position>|<token>". The padding is presumably
    # there so the ordinal axis sorts in position order; the width grows past
    # 3 digits when needed so that stays true for long sequences.
    pad = max([3] + [len(str(position)) for position in df["position"]])
    df["position_with_token"] = [
        f"{str(row.position).zfill(pad)}|{row.tokens}" for row in df.itertuples()
    ]
    print(df)

    # Only the part after the first '|' is shown as the tick label.
    # NOTE: a token that itself contains '|' is cut at that character.
    x_axis = alt.Axis(
        labelExpr="split(datum.label, '|')[1]",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    width = 1200
    height = 150
    base = alt.Chart(df).properties(width=width, height=height)
    # Entropy curve with a point marker per position.
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y(
            "entropies",
            title="Entropy of Next Byte",
        ),
    )
    # Horizontal dashed red line at the threshold.
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    # Vertical dashed grey lines at every flagged patch start.
    patch_rules = (
        alt.Chart(df[df["start"] > 0])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    return chart.configure_axis(labelFontSize=15, titleFontSize=15)


def main():
    """Render the entropy plot described by the config file given on the command line.

    Usage: ``<script> CONFIG_PATH [key=value ...]``. The config file is merged
    with the ``key=value`` overrides and validated as ``PlotEntropiesConfig``;
    the chart is then built from the referenced data file and saved to
    ``chart_path``.

    Raises:
        SystemExit: if no config path is given.
        ValueError: if ``data_path`` or the stored dataframe is missing, or if
            the score override length does not match the dataframe.
    """
    plot_config = _load_plot_config(sys.argv)
    df, threshold = _load_entropy_frame(plot_config)
    chart = _build_chart(df, threshold)

    path = Path(plot_config.chart_path)
    # Create the whole output directory chain, not just the last component.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)
# Run the plotting entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()