WFGY/wfgy_sdk/cli.py
2025-09-27 16:11:32 +08:00

22 lines
743 B
Python

# cli.py
import argparse, wfgy_sdk as w
from wfgy_sdk.evaluator import compare_logits, pretty_print
import numpy as np
def main(argv=None):
    """Run one WFGY engine pass on a prompt and print a logit comparison.

    The prompt is sent to a remote model through ``w.call_remote_model``.
    The returned logits are then passed to ``w.get_engine().run`` together
    with a synthetic pair of probe vectors. Finally, ``compare_logits`` /
    ``pretty_print`` report the original vs. engine-modified logits.

    The probe vectors are random:

    * ``ground_vec`` is a unit-norm Gaussian vector of length 128.
    * ``input_vec`` is ``ground_vec`` plus Gaussian noise with standard
      deviation ``--noise`` (default 0.05).

    Pass ``--seed`` to make the vectors, and hence the run, reproducible.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; a list
            lets callers and tests invoke the CLI programmatically.
    """
    # Length of the synthetic probe vectors. Kept fixed because the engine's
    # expectations about vector size are not visible here.
    # NOTE(review): confirm before exposing this as an option.
    dim = 128

    parser = argparse.ArgumentParser(
        description="Compare a model's logits before and after the WFGY engine."
    )
    parser.add_argument("prompt", help="text prompt to test")
    parser.add_argument("--model", default="gpt2",
                        help="huggingface model id (public)")
    parser.add_argument("--noise", type=float, default=0.05,
                        help="std-dev of the noise added to the ground vector "
                             "(default: 0.05)")
    parser.add_argument("--seed", type=int, default=None,
                        help="RNG seed for reproducible probe vectors "
                             "(default: fresh randomness each run)")
    args = parser.parse_args(argv)

    # Validate before the (potentially slow) remote call; numpy would reject a
    # negative scale anyway, but with a less helpful traceback.
    if args.noise < 0:
        parser.error("--noise must be non-negative")

    logits = w.call_remote_model(args.prompt, model_id=args.model)

    # A local Generator instead of the legacy global np.random state, so that
    # --seed gives reproducible vectors without touching global RNG state.
    rng = np.random.default_rng(args.seed)
    ground_vec = rng.standard_normal(dim)
    ground_vec /= np.linalg.norm(ground_vec)  # normalise to unit length
    input_vec = ground_vec + rng.normal(scale=args.noise, size=dim)

    logits_mod = w.get_engine().run(
        input_vec=input_vec, ground_vec=ground_vec, logits=logits
    )
    pretty_print(compare_logits(logits, logits_mod))
if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, never on import.
    main()