blt/bytelatent/plotting/entropy_figure.py

101 lines
2.8 KiB
Python
Raw Permalink Normal View History

2024-12-12 23:32:30 +00:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
2024-12-12 23:32:30 +00:00
import io
import os
import sys
from pathlib import Path

import altair as alt
import pandas as pd
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict
class PlotEntropiesConfig(BaseModel):
    """Configuration for ``main``, built from a YAML file plus CLI overrides.

    Unknown keys are rejected (``extra="forbid"``), so a typo in the config
    file or in a ``key=value`` override fails validation instead of being
    silently ignored.
    """

    # Pydantic v2 configuration; the class-based ``Config`` is deprecated in v2.
    model_config = ConfigDict(extra="forbid")

    # JSON file parsed as PlotEntropiesData. It has no default, so the key must
    # be present. main() opens it, so in practice it must not be null.
    data_path: str | None
    # Output file handed to altair's Chart.save.
    chart_path: str
    # Optional JSON file whose "score" list replaces the "entropies" column.
    score_override_path: str | None = None
    # When set, used instead of PlotEntropiesData.threshold.
    threshold_override: float | None = None
class PlotEntropiesData(BaseModel):
    """Input document read from ``PlotEntropiesConfig.data_path``.

    Unknown keys are rejected (``extra="forbid"``).
    """

    # Pydantic v2 configuration; the class-based ``Config`` is deprecated in v2.
    model_config = ConfigDict(extra="forbid")

    # The sample text; main() in this file does not read it.
    text: str
    # Entropy threshold: drawn as the horizontal rule and compared against the
    # "entropies" column to mark patch starts.
    # NOTE(review): the default looks like a tuned global value — confirm its
    # provenance before changing it.
    threshold: float = 1.335442066192627
    # JSON table loaded with pandas.read_json. main() reads its "position",
    # "tokens" and "entropies" columns, and it must be non-null to plot.
    dataframe_json: str | None
def main():
    """Render the per-byte entropy chart described by a config file.

    Command line::

        <script> CONFIG_PATH [key=value ...]

    ``CONFIG_PATH`` is loaded with OmegaConf, merged with any dot-list
    overrides that follow it, and validated as a ``PlotEntropiesConfig``.
    ``data_path`` must name a JSON document matching ``PlotEntropiesData``.
    Its ``dataframe_json`` is read into a DataFrame whose ``position``,
    ``tokens`` and ``entropies`` columns are plotted as:

    * a line with points: ``entropies`` per position;
    * a dashed red horizontal rule at the threshold;
    * a dashed grey vertical rule at every position flagged as a patch start.

    The chart is saved to ``chart_path`` via altair's ``Chart.save``.

    Raises:
        SystemExit: if no config path is given on the command line.
        ValueError: if ``data_path`` or ``dataframe_json`` is missing, or if
            the score override length does not match the number of rows.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CONFIG_PATH [key=value ...]")
    config_path = sys.argv[1]
    file_config = OmegaConf.load(config_path)
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(sys.argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    plot_config = PlotEntropiesConfig(**conf_dict)
    if plot_config.data_path is None:
        raise ValueError("data_path must point to a PlotEntropiesData JSON file")

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(plot_config.data_path, encoding="utf-8") as f:
        json_data = f.read()
    plot_data = PlotEntropiesData.model_validate_json(json_data)
    if plot_data.dataframe_json is None:
        raise ValueError(f"{plot_config.data_path} has no dataframe_json to plot")
    # Wrap in StringIO: passing literal JSON text to read_json is deprecated
    # since pandas 2.1.
    df = pd.read_json(io.StringIO(plot_data.dataframe_json))
    print("LEN", len(df))

    threshold = (
        plot_data.threshold
        if plot_config.threshold_override is None
        else plot_config.threshold_override
    )

    if plot_config.score_override_path is not None:
        with open(plot_config.score_override_path, encoding="utf-8") as f:
            scores = json.load(f)["score"]
        # Explicit check rather than `assert`, which `python -O` strips.
        if len(scores) != len(df):
            raise ValueError(
                f"score override has {len(scores)} values "
                f"but the data has {len(df)} rows"
            )
        df["entropies"] = scores

    # Position 0 always opens a patch. Position i opens one when the entropy at
    # position i - 1 exceeds the threshold. The final slice keeps the column
    # length equal to len(df), even for an empty frame.
    exceeds = (df["entropies"] > threshold).tolist()
    df["start"] = ([1] + exceeds[:-1])[: len(df)]

    positions = df["position"].tolist()
    tokens = df["tokens"].tolist()
    # The ordinal x axis sorts these labels as strings, so zero-pad positions to
    # a common width (at least 3) to keep numeric order for long texts too.
    pad = max([3] + [len(str(p)) for p in positions])
    df["position_with_token"] = [
        f"{str(p).zfill(pad)}|{t}" for p, t in zip(positions, tokens)
    ]
    print(df)

    # Show only the token part of "0042|tok". Cut at the first "|" so that
    # tokens which themselves contain "|" are displayed intact.
    x_axis = alt.Axis(
        labelExpr="substring(datum.label, indexof(datum.label, '|') + 1)",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    width = 1200
    height = 150
    base = alt.Chart(df).properties(width=width, height=height)
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y(
            "entropies",
            title="Entropy of Next Byte",
        ),
    )
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    patch_rules = (
        alt.Chart(df[df["start"] > 0])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    chart = chart.configure_axis(labelFontSize=15, titleFontSize=15)

    path = Path(plot_config.chart_path)
    # parents=True so nested output directories that don't exist yet work.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)
if __name__ == "__main__":
    # Only plot when executed as a script; importing this module has no effect.
    main()