Update setup.py

This commit is contained in:
PSBigBig 2025-06-13 13:10:06 +08:00 committed by GitHub
parent d4c895cb64
commit 77afd7d0ea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,24 +1,74 @@
from setuptools import setup, find_packages
"""
WFGY · Metrics & Visuals
Pure-NumPy + Matplotlib helpers
"""
setup(
name="wfgy_sdk",
version="1.0.1",
description="WFGY 1.0 • Self-Healing LLM Framework SDK",
author="PSBigBig",
author_email="hello@onestardao.com",
url="https://github.com/onestardao/WFGY",
packages=find_packages(include=["wfgy_sdk", "wfgy_sdk.*"]),
python_requires=">=3.10",
install_requires=[
"numpy<2.0",
"torch==2.2.2",
"transformers==4.41.2",
"tabulate",
"matplotlib",
],
entry_points={
"console_scripts": [
"wfgy=wfgy_sdk.cli:main",
]
},
)
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate
# ───────────────────────────────────────────────
# Stats helpers
# ───────────────────────────────────────────────
def _safe_std(x: np.ndarray) -> float:
s = float(np.std(x))
return s if s > 0 else 1e-12
def softmax(x: np.ndarray) -> np.ndarray:
z = x - x.max()
e = np.exp(z)
return e / e.sum()
# ───────────────────────────────────────────────
# Public API ── must return keys expected by CI
# ───────────────────────────────────────────────
def compare_logits(old: np.ndarray, new: np.ndarray) -> dict:
sr = _safe_std(new) / _safe_std(old) # std ratio
var_drop = 1.0 - sr
p, q = softmax(old), softmax(new)
kl = float(np.sum(p * np.log((p + 1e-8) / (q + 1e-8)))) # KL divergence
top1_same = int(old.argmax() == new.argmax())
return {
"std_ratio": sr,
"var_drop": var_drop,
"kl_divergence": kl, # <── CI & Space both want this exact key
"top1": top1_same,
}
# ───────────────────────────────────────────────
# Pretty print for CLI demo
# ───────────────────────────────────────────────
def pretty_print(m: dict) -> str:
tbl = tabulate(
[[f"{m['std_ratio']:.3f}",
f"{m['var_drop']*100:4.1f} %",
f"{m['kl_divergence']:.3f}",
"" if m['top1'] else ""]],
headers=["std_ratio", "▼ var", "KL", "Top-1"],
tablefmt="github",
)
return tbl
# ───────────────────────────────────────────────
# Histogram figure
# ───────────────────────────────────────────────
def plot_histogram(old: np.ndarray, new: np.ndarray, bins: int = 50) -> plt.Figure:
fig, ax = plt.subplots(figsize=(6, 3.5), dpi=110)
ax.hist(old, bins=bins, alpha=0.6, label="Raw", log=True)
ax.hist(new, bins=bins, alpha=0.6, label="WFGY", log=True)
ax.set_title("Logit Distribution (log-scale)")
ax.set_xlabel("logit value")
ax.set_ylabel("frequency")
ax.legend()
fig.tight_layout()
return fig