mirror of
https://github.com/onestardao/WFGY.git
synced 2026-05-22 03:02:03 +00:00
Update setup.py
This commit is contained in:
parent
d4c895cb64
commit
77afd7d0ea
1 changed files with 73 additions and 23 deletions
96
setup.py
96
setup.py
|
|
@ -1,24 +1,74 @@
|
|||
from setuptools import setup, find_packages
|
||||
"""
|
||||
WFGY · Metrics & Visuals
|
||||
Pure-NumPy + Matplotlib helpers
|
||||
"""
|
||||
|
||||
setup(
|
||||
name="wfgy_sdk",
|
||||
version="1.0.1",
|
||||
description="WFGY 1.0 • Self-Healing LLM Framework SDK",
|
||||
author="PSBigBig",
|
||||
author_email="hello@onestardao.com",
|
||||
url="https://github.com/onestardao/WFGY",
|
||||
packages=find_packages(include=["wfgy_sdk", "wfgy_sdk.*"]),
|
||||
python_requires=">=3.10",
|
||||
install_requires=[
|
||||
"numpy<2.0",
|
||||
"torch==2.2.2",
|
||||
"transformers==4.41.2",
|
||||
"tabulate",
|
||||
"matplotlib",
|
||||
],
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"wfgy=wfgy_sdk.cli:main",
|
||||
]
|
||||
},
|
||||
)
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────
|
||||
# Stats helpers
|
||||
# ───────────────────────────────────────────────
|
||||
|
||||
def _safe_std(x: np.ndarray) -> float:
|
||||
s = float(np.std(x))
|
||||
return s if s > 0 else 1e-12
|
||||
|
||||
|
||||
def softmax(x: np.ndarray) -> np.ndarray:
|
||||
z = x - x.max()
|
||||
e = np.exp(z)
|
||||
return e / e.sum()
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────
|
||||
# Public API ── must return keys expected by CI
|
||||
# ───────────────────────────────────────────────
|
||||
|
||||
def compare_logits(old: np.ndarray, new: np.ndarray) -> dict:
|
||||
sr = _safe_std(new) / _safe_std(old) # std ratio
|
||||
var_drop = 1.0 - sr
|
||||
p, q = softmax(old), softmax(new)
|
||||
kl = float(np.sum(p * np.log((p + 1e-8) / (q + 1e-8)))) # KL divergence
|
||||
top1_same = int(old.argmax() == new.argmax())
|
||||
|
||||
return {
|
||||
"std_ratio": sr,
|
||||
"var_drop": var_drop,
|
||||
"kl_divergence": kl, # <── CI & Space both want this exact key
|
||||
"top1": top1_same,
|
||||
}
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────
|
||||
# Pretty print for CLI demo
|
||||
# ───────────────────────────────────────────────
|
||||
|
||||
def pretty_print(m: dict) -> str:
|
||||
tbl = tabulate(
|
||||
[[f"{m['std_ratio']:.3f}",
|
||||
f"{m['var_drop']*100:4.1f} %",
|
||||
f"{m['kl_divergence']:.3f}",
|
||||
"✔" if m['top1'] else "✘"]],
|
||||
headers=["std_ratio", "▼ var", "KL", "Top-1"],
|
||||
tablefmt="github",
|
||||
)
|
||||
return tbl
|
||||
|
||||
|
||||
# ───────────────────────────────────────────────
|
||||
# Histogram figure
|
||||
# ───────────────────────────────────────────────
|
||||
|
||||
def plot_histogram(old: np.ndarray, new: np.ndarray, bins: int = 50) -> plt.Figure:
|
||||
fig, ax = plt.subplots(figsize=(6, 3.5), dpi=110)
|
||||
ax.hist(old, bins=bins, alpha=0.6, label="Raw", log=True)
|
||||
ax.hist(new, bins=bins, alpha=0.6, label="WFGY", log=True)
|
||||
ax.set_title("Logit Distribution (log-scale)")
|
||||
ax.set_xlabel("logit value")
|
||||
ax.set_ylabel("frequency")
|
||||
ax.legend()
|
||||
fig.tight_layout()
|
||||
return fig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue