Add plotting code from paper (#17)
Some checks are pending
Lint with Black / lint (push) Waiting to run
Lint with isort / lint (push) Waiting to run

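Summary: Adds optional `threshold_override` and `score_override_path` settings to the entropy plotting config and script, and adds `plot_data/scores.json` as a score override file.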

Test Plan:
This commit is contained in:
Pedro Rodriguez 2025-01-09 12:11:50 -08:00 committed by GitHub
parent 2fdc6f3cc9
commit d4ddb95322
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 19 additions and 2 deletions

View file

@@ -1,3 +1,4 @@
data_path: plot_data/entropy_figure.json data_path: plot_data/entropy_figure.json
chart_path: figures/entropy_figure.pdf chart_path: figures/entropy_figure.pdf
# chart_path: figures/entropy_figure.pdf threshold_override: 1.7171002626419067
score_override_path: plot_data/scores.json

View file

@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
@@ -12,6 +13,8 @@ from pydantic import BaseModel
class PlotEntropiesConfig(BaseModel): class PlotEntropiesConfig(BaseModel):
data_path: str | None data_path: str | None
chart_path: str chart_path: str
score_override_path: str | None = None
threshold_override: float | None = None
class Config: class Config:
extra = "forbid" extra = "forbid"
@@ -37,8 +40,20 @@ def main():
plot_config = PlotEntropiesConfig(**conf_dict) plot_config = PlotEntropiesConfig(**conf_dict)
with open(plot_config.data_path) as f: with open(plot_config.data_path) as f:
json_data = f.read() json_data = f.read()
plot_data = PlotEntropiesData.model_validate_json(json_data) plot_data = PlotEntropiesData.model_validate_json(json_data)
df = pd.read_json(plot_data.dataframe_json) df = pd.read_json(plot_data.dataframe_json)
print("LEN", len(df))
if plot_config.threshold_override is None:
threshold = plot_data.threshold
else:
threshold = plot_config.threshold_override
if plot_config.score_override_path is not None:
with open(plot_config.score_override_path) as f:
scores = json.load(f)["score"]
assert len(scores) == len(df)
df["entropies"] = scores
df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1]
x_ticks = [] x_ticks = []
for row in df.itertuples(): for row in df.itertuples():
@@ -65,7 +80,7 @@ def main():
), ),
) )
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
y=alt.datum(plot_data.threshold), y=alt.datum(threshold),
) )
patch_rules = ( patch_rules = (
alt.Chart(df[df["start"] > 0]) alt.Chart(df[df["start"] > 0])

1
plot_data/scores.json Normal file
View file

@@ -0,0 +1 @@
{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]}