diff --git a/bytelatent/plotting/config_entropy_figure.yaml b/bytelatent/plotting/config_entropy_figure.yaml index 4d7bfd7..296ea07 100644 --- a/bytelatent/plotting/config_entropy_figure.yaml +++ b/bytelatent/plotting/config_entropy_figure.yaml @@ -1,3 +1,4 @@ data_path: plot_data/entropy_figure.json chart_path: figures/entropy_figure.pdf -# chart_path: figures/entropy_figure.pdf +threshold_override: 1.7171002626419067 +score_override_path: plot_data/scores.json diff --git a/bytelatent/plotting/entropy_figure.py b/bytelatent/plotting/entropy_figure.py index c1966a1..f401d7c 100644 --- a/bytelatent/plotting/entropy_figure.py +++ b/bytelatent/plotting/entropy_figure.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +import json import os import sys from pathlib import Path @@ -12,6 +13,8 @@ from pydantic import BaseModel class PlotEntropiesConfig(BaseModel): data_path: str | None chart_path: str + score_override_path: str | None = None + threshold_override: float | None = None class Config: extra = "forbid" @@ -37,8 +40,20 @@ def main(): plot_config = PlotEntropiesConfig(**conf_dict) with open(plot_config.data_path) as f: json_data = f.read() + plot_data = PlotEntropiesData.model_validate_json(json_data) df = pd.read_json(plot_data.dataframe_json) + print("LEN", len(df)) + if plot_config.threshold_override is None: + threshold = plot_data.threshold + else: + threshold = plot_config.threshold_override + if plot_config.score_override_path is not None: + with open(plot_config.score_override_path) as f: + scores = json.load(f)["score"] + assert len(scores) == len(df) + df["entropies"] = scores + df["start"] = [1] + (df["entropies"] > threshold).values.tolist()[:-1] x_ticks = [] for row in df.itertuples(): @@ -65,7 +80,7 @@ def main(): ), ) rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode( - y=alt.datum(plot_data.threshold), + y=alt.datum(threshold), ) patch_rules = ( alt.Chart(df[df["start"] > 0]) diff --git a/plot_data/scores.json b/plot_data/scores.json new file mode 100644 index 0000000..202cafc --- /dev/null +++ b/plot_data/scores.json @@ -0,0 +1 @@ +{"score": [3.3949153423309326, 2.1647746562957764, 2.3216569423675537, 2.8114914894104004, 1.505232334136963, 0.04055612534284592, 0.09150367230176926, 0.06008715182542801, 0.3453567624092102, 1.0483067035675049, 0.1967127025127411, 0.12737397849559784, 0.05923430994153023, 0.001597292022779584, 0.004362526815384626, 0.005547997076064348, 0.0011689786333590746, 0.0010273229563608766, 1.0228447914123535, 3.6863417625427246, 0.46605175733566284, 0.048645928502082825, 2.2544963359832764, 0.37329360842704773, 1.001160979270935, 2.9116122722625732, 1.8948925733566284, 1.4017235040664673, 0.3879640996456146, 0.2652309536933899, 1.780383825302124, 0.013964788988232613, 0.005456871818751097, 0.5426468253135681, 0.20666983723640442, 0.0051853349432349205, 0.0005802579107694328, 0.0007443525246344507, 0.0004390323010738939, 0.005452247802168131, 1.1932975053787231, 0.023798620328307152, 3.1230878829956055, 1.3915895223617554, 3.0489213466644287, 1.7018193006515503, 1.873910903930664, 1.4662408828735352, 0.004920408595353365, 0.02599342167377472, 0.6620859503746033, 0.31743818521499634, 2.8409600257873535, 1.1354060173034668, 0.0520976223051548, 0.3519965708255768, 0.40707266330718994, 2.5438783168792725, 1.3343133926391602, 0.023993035778403282, 3.445943832397461, 1.8542104959487915, 0.7849258780479431, 0.6848396062850952, 0.06938046962022781, 0.20923230051994324, 0.10084306448698044, 0.18334199488162994, 0.4126923978328705, 0.5505472421646118, 0.1042013093829155, 0.019447727128863335, 0.0014866517158225179, 0.0009848219342529774, 0.00021391961490735412, 0.007746236398816109, 0.00038792978739365935, 0.0007933690212666988, 1.2369810342788696, 0.4436197578907013, 4.6366687456611544e-05]} \ No newline at end of file