mirror of
https://github.com/facebookresearch/blt.git
synced 2025-01-18 16:37:46 +00:00
101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import altair as alt
|
|
import pandas as pd
|
|
from omegaconf import OmegaConf
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class PlotEntropiesConfig(BaseModel):
    """Settings for the entropy plotting script.

    ``main()`` builds this from the config file merged with command-line
    overrides. Unknown keys are rejected.
    """

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"extra": "forbid"}

    # JSON file that main() reads and parses as PlotEntropiesData.
    # The key is required, but it may be given as None explicitly.
    data_path: str | None
    # Destination the rendered chart is saved to.
    chart_path: str
    # Optional JSON file whose "score" list replaces the "entropies" column.
    score_override_path: str | None = None
    # Optional threshold used instead of the one stored with the plot data.
    threshold_override: float | None = None
|
|
|
|
|
|
class PlotEntropiesData(BaseModel):
    """Serialized plot input that ``main()`` parses from the data JSON file.

    Unknown keys are rejected.
    """

    # pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    model_config = {"extra": "forbid"}

    # Text stored alongside the data; main() does not read it.
    text: str
    # Threshold main() uses unless the config sets threshold_override.
    threshold: float = 1.335442066192627
    # JSON-encoded DataFrame; main() uses its "position", "tokens" and
    # "entropies" columns. The key is required, but it may be None.
    dataframe_json: str | None
|
|
|
|
|
|
def main():
    """Render the entropy chart described by a config file.

    Usage: ``<script> CONFIG_PATH [key=value ...]``. The config file is merged
    with the command-line overrides, the plot data is loaded, and the chart is
    written to ``chart_path``.

    Raises:
        SystemExit: if no config path is given on the command line.
        ValueError: if ``data_path`` or ``dataframe_json`` is None, or if the
            score override does not have one score per dataframe row.
    """
    plot_config = _load_plot_config(sys.argv)
    df, stored_threshold = _load_plot_data(plot_config)
    print("LEN", len(df))

    if plot_config.threshold_override is None:
        threshold = stored_threshold
    else:
        threshold = plot_config.threshold_override

    if plot_config.score_override_path is not None:
        with open(plot_config.score_override_path) as f:
            scores = json.load(f)["score"]
        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        if len(scores) != len(df):
            raise ValueError(
                f"score override has {len(scores)} scores "
                f"but the dataframe has {len(df)} rows"
            )
        df["entropies"] = scores

    # Flag the first row and every row that follows an entropy above the
    # threshold. ``shift`` keeps the column boolean and copes with an empty
    # frame, unlike prepending to a sliced list.
    df["start"] = (df["entropies"] > threshold).shift(1, fill_value=True)

    df["position_with_token"] = [
        f"{str(position).zfill(3)}|{token}"
        for position, token in zip(df["position"], df["tokens"])
    ]
    print(df)

    chart = _build_chart(df, threshold)
    path = Path(plot_config.chart_path)
    # ``parents=True`` so a nested output directory is created in one go.
    path.parent.mkdir(parents=True, exist_ok=True)
    chart.save(path)


def _load_plot_config(argv):
    """Build the PlotEntropiesConfig from the config file and CLI overrides."""
    if len(argv) < 2:
        prog = Path(argv[0]).name if argv else "<script>"
        raise SystemExit(f"usage: {prog} CONFIG_PATH [key=value ...]")
    file_config = OmegaConf.load(argv[1])
    # Omit program name and config file name
    cli_conf = OmegaConf.from_cli(argv[2:])
    conf_dict = OmegaConf.to_container(
        OmegaConf.merge(file_config, cli_conf), resolve=True, throw_on_missing=True
    )
    return PlotEntropiesConfig(**conf_dict)


def _load_plot_data(plot_config):
    """Read the plot data file and return ``(dataframe, stored_threshold)``."""
    from io import StringIO

    if plot_config.data_path is None:
        raise ValueError("data_path must point to the plot data JSON file")
    with open(plot_config.data_path) as f:
        plot_data = PlotEntropiesData.model_validate_json(f.read())
    if plot_data.dataframe_json is None:
        raise ValueError(f"{plot_config.data_path} contains no dataframe_json")
    # Passing a literal JSON string to read_json is deprecated in pandas;
    # wrap it in a file-like object instead.
    df = pd.read_json(StringIO(plot_data.dataframe_json))
    return df, plot_data.threshold


def _build_chart(df, threshold):
    """Layer the entropy line, the threshold rule and the start markers."""
    width = 1200
    height = 150
    # Labels are "<zero-padded position>|<token>"; only the token part is
    # displayed. Presumably the position prefix keeps labels distinct and
    # ordered on the ordinal axis - TODO confirm.
    x_axis = alt.Axis(
        labelExpr="split(datum.label, '|')[1]",
        grid=False,
        labelOverlap=False,
        labelAngle=0,
    )
    base = alt.Chart(df).properties(width=width, height=height)
    points = base.mark_line(point=True).encode(
        x=alt.X("position_with_token:O", title=None, axis=x_axis),
        y=alt.Y(
            "entropies",
            title="Entropy of Next Byte",
        ),
    )
    # Horizontal dashed line at the threshold value.
    rule = base.mark_rule(color="red", strokeDash=[4, 4]).encode(
        y=alt.datum(threshold),
    )
    # Vertical dashed line at every row flagged in the "start" column.
    patch_rules = (
        alt.Chart(df[df["start"] > 0])
        .properties(width=width, height=height)
        .mark_rule(color="#474747", strokeDash=[4, 2])
        .encode(x=alt.X("position_with_token:O", axis=x_axis))
    )
    chart = patch_rules + rule + points
    return chart.configure_axis(labelFontSize=15, titleFontSize=15)
|
|
|
|
|
|
# Script entry point: run the plot only when executed directly, not on import.
if __name__ == "__main__":
    main()
|