mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 08:27:45 +00:00
parent
2fdc6f3cc9
commit
d4ddb95322
|
@ -1,3 +1,4 @@
|
||||||
data_path: plot_data/entropy_figure.json
|
data_path: plot_data/entropy_figure.json
|
||||||
chart_path: figures/entropy_figure.pdf
|
chart_path: figures/entropy_figure.pdf
|
||||||
# chart_path: figures/entropy_figure.pdf
|
threshold_override: 1.7171002626419067
|
||||||
|
score_override_path: plot_data/scores.json
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -12,6 +13,8 @@ from pydantic import BaseModel
|
||||||
class PlotEntropiesConfig(BaseModel):
|
class PlotEntropiesConfig(BaseModel):
|
||||||
data_path: str | None
|
data_path: str | None
|
||||||
chart_path: str
|
chart_path: str
|
||||||
|
score_override_path: str | None = None
|
||||||
|
threshold_override: float | None = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
@ -37,8 +40,20 @@ def main():
|
||||||
plot_config = PlotEntropiesConfig(**conf_dict)
|
plot_config = PlotEntropiesConfig(**conf_dict)
|
||||||
with open(plot_config.data_path) as f:
|
with open(plot_config.data_path) as f:
|
||||||
json_data = f.read()
|
json_data = f.read()
|
||||||
|
|
||||||
plot_data = PlotEntropiesData.model_validate_json(json_data)
|
plot_data = PlotEntropiesData.model_validate_json(json_data)
|
||||||
df = pd.read_json(plot_data.dataframe_json)
|
df = pd.read_json(plot_data.dataframe_json)
|
||||||
|
print("LEN", len(df))
|
||||||
|
if plot_config.threshold_override is None:
|
||||||
|
threshold = plot_data.threshold
|
||||||
|
else:
|
||||||
|
threshold = plot_config.threshold_override
|
||||||
|
if plot_config.score_override_path is not None:
|
||||||
|
with open(plot_config.score_override_path) as f:
|
||||||
|
scores = json.load(f)["score"]
|
||||||
|
assert len(scores) == len(df)
|
||||||
|
df["entropies"] = scores
|
||||||
|
df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1]
|
||||||
|
|
||||||
x_ticks = []
|
x_ticks = []
|
||||||
for row in df.itertuples():
|
for row in df.itertuples():
|
||||||
|
@ -65,7 +80,7 @@ def main():
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
|
rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
|
||||||
y=alt.datum(plot_data.threshold),
|
y=alt.datum(threshold),
|
||||||
)
|
)
|
||||||
patch_rules = (
|
patch_rules = (
|
||||||
alt.Chart(df[df["start"] > 0])
|
alt.Chart(df[df["start"] > 0])
|
||||||
|
|
1
plot_data/scores.json
Normal file
1
plot_data/scores.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]}
|
Loading…
Reference in a new issue