From 77afd7d0ea5efe73fc2babcc9d15fb565b7189ed Mon Sep 17 00:00:00 2001 From: PSBigBig Date: Fri, 13 Jun 2025 13:10:06 +0800 Subject: [PATCH] Update setup.py --- setup.py | 96 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index d129b7eb..2213f45a 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,74 @@ -from setuptools import setup, find_packages +""" +WFGY · Metrics & Visuals +Pure-NumPy + Matplotlib helpers +""" -setup( - name="wfgy_sdk", - version="1.0.1", - description="WFGY 1.0 • Self-Healing LLM Framework SDK", - author="PSBigBig", - author_email="hello@onestardao.com", - url="https://github.com/onestardao/WFGY", - packages=find_packages(include=["wfgy_sdk", "wfgy_sdk.*"]), - python_requires=">=3.10", - install_requires=[ - "numpy<2.0", - "torch==2.2.2", - "transformers==4.41.2", - "tabulate", - "matplotlib", - ], - entry_points={ - "console_scripts": [ - "wfgy=wfgy_sdk.cli:main", - ] - }, -) +import numpy as np +import matplotlib.pyplot as plt +from tabulate import tabulate + + +# ─────────────────────────────────────────────── +# Stats helpers +# ─────────────────────────────────────────────── + +def _safe_std(x: np.ndarray) -> float: + s = float(np.std(x)) + return s if s > 0 else 1e-12 + + +def softmax(x: np.ndarray) -> np.ndarray: + z = x - x.max() + e = np.exp(z) + return e / e.sum() + + +# ─────────────────────────────────────────────── +# Public API ── must return keys expected by CI +# ─────────────────────────────────────────────── + +def compare_logits(old: np.ndarray, new: np.ndarray) -> dict: + sr = _safe_std(new) / _safe_std(old) # std ratio + var_drop = 1.0 - sr + p, q = softmax(old), softmax(new) + kl = float(np.sum(p * np.log((p + 1e-8) / (q + 1e-8)))) # KL divergence + top1_same = int(old.argmax() == new.argmax()) + + return { + "std_ratio": sr, + "var_drop": var_drop, + "kl_divergence": kl, # <── CI & Space both want this exact key + "top1": top1_same, + } + + +# ─────────────────────────────────────────────── +# Pretty print for CLI demo +# ─────────────────────────────────────────────── + +def pretty_print(m: dict) -> str: + tbl = tabulate( + [[f"{m['std_ratio']:.3f}", + f"{m['var_drop']*100:4.1f} %", + f"{m['kl_divergence']:.3f}", + "✔" if m['top1'] else "✘"]], + headers=["std_ratio", "▼ var", "KL", "Top-1"], + tablefmt="github", + ) + return tbl + + +# ─────────────────────────────────────────────── +# Histogram figure +# ─────────────────────────────────────────────── + +def plot_histogram(old: np.ndarray, new: np.ndarray, bins: int = 50) -> plt.Figure: + fig, ax = plt.subplots(figsize=(6, 3.5), dpi=110) + ax.hist(old, bins=bins, alpha=0.6, label="Raw", log=True) + ax.hist(new, bins=bins, alpha=0.6, label="WFGY", log=True) + ax.set_title("Logit Distribution (log-scale)") + ax.set_xlabel("logit value") + ax.set_ylabel("frequency") + ax.legend() + fig.tight_layout() + return fig