mirror of
https://github.com/ruvnet/RuVector.git
synced 2026-05-23 04:27:11 +00:00
246 lines
9.2 KiB
Python
246 lines
9.2 KiB
Python
#!/usr/bin/env python3
|
|
"""LoRA training epochs and drift monitoring for pi.ruv.io brain.
|
|
|
|
Submits 3 LoRA delta submissions from 3 different contributor keys to
|
|
trigger Byzantine-tolerant federated aggregation (min_submissions=3),
|
|
checks drift after each, and reports final state.
|
|
|
|
The server expects LoraSubmission with:
|
|
- down_proj: Vec<f32> of size hidden_dim * rank = 128 * 2 = 256
|
|
- up_proj: Vec<f32> of size rank * hidden_dim = 2 * 128 = 256
|
|
- rank: 2
|
|
- hidden_dim: 128
|
|
- evidence_count: u64 (>= 5)
|
|
"""
|
|
import json
|
|
import random
|
|
import urllib.request
|
|
import urllib.error
|
|
import time
|
|
import hashlib
|
|
import sys
|
|
|
|
BASE = "https://ruvbrain-875130704813.us-central1.run.app"
|
|
|
|
# Server defaults: Rank-2, 128-dim
|
|
RANK = 2
|
|
HIDDEN_DIM = 128
|
|
PROJ_SIZE = HIDDEN_DIM * RANK # 256 floats each for down_proj and up_proj
|
|
|
|
# 3 distinct contributor keys for Byzantine aggregation testing
|
|
CONTRIBUTOR_KEYS = [
|
|
"lora-trainer-key-alpha-" + hashlib.sha256(b"contributor-0").hexdigest()[:16],
|
|
"lora-trainer-key-beta-" + hashlib.sha256(b"contributor-1").hexdigest()[:16],
|
|
"lora-trainer-key-gamma-" + hashlib.sha256(b"contributor-2").hexdigest()[:16],
|
|
]
|
|
|
|
|
|
def make_headers(key):
|
|
return {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {key}",
|
|
}
|
|
|
|
|
|
def api_get(path, key=None):
|
|
"""GET request, returns (data_dict, status_code)."""
|
|
headers = make_headers(key or CONTRIBUTOR_KEYS[0])
|
|
req = urllib.request.Request(f"{BASE}{path}", headers=headers, method="GET")
|
|
try:
|
|
resp = urllib.request.urlopen(req, timeout=30)
|
|
body = resp.read().decode()
|
|
return json.loads(body) if body else {}, resp.status
|
|
except urllib.error.HTTPError as e:
|
|
body = e.read().decode()[:500]
|
|
return {"error": body, "code": e.code}, e.code
|
|
except Exception as e:
|
|
return {"error": str(e)}, 0
|
|
|
|
|
|
def api_post(path, data, key=None):
|
|
"""POST request, returns (data_dict, status_code)."""
|
|
headers = make_headers(key or CONTRIBUTOR_KEYS[0])
|
|
payload = json.dumps(data).encode()
|
|
req = urllib.request.Request(f"{BASE}{path}", data=payload, headers=headers, method="POST")
|
|
try:
|
|
resp = urllib.request.urlopen(req, timeout=30)
|
|
body = resp.read().decode()
|
|
return json.loads(body) if body else {}, resp.status
|
|
except urllib.error.HTTPError as e:
|
|
body = e.read().decode()[:500]
|
|
return {"error": body, "code": e.code}, e.code
|
|
except Exception as e:
|
|
return {"error": str(e)}, 0
|
|
|
|
|
|
def generate_proj_weights(size=PROJ_SIZE, std=0.01):
|
|
"""Generate small random weights centered around 0 with given std deviation."""
|
|
return [round(random.gauss(0.0, std), 8) for _ in range(size)]
|
|
|
|
|
|
def weight_stats(weights):
|
|
"""Compute min/max/mean/norm for a weight vector."""
|
|
mn = min(weights)
|
|
mx = max(weights)
|
|
mean = sum(weights) / len(weights)
|
|
norm = sum(w * w for w in weights) ** 0.5
|
|
return {"min": round(mn, 6), "max": round(mx, 6), "mean": round(mean, 6), "norm": round(norm, 4)}
|
|
|
|
|
|
def print_separator():
|
|
print("=" * 60)
|
|
|
|
|
|
def main():
|
|
print_separator()
|
|
print("LoRA Training & Drift Monitoring for pi.ruv.io")
|
|
print(f" Base URL: {BASE}")
|
|
print(f" Contributors: {len(CONTRIBUTOR_KEYS)}")
|
|
print(f" Rank: {RANK}, Hidden dim: {HIDDEN_DIM}")
|
|
print(f" Projection size (each): {PROJ_SIZE}")
|
|
print_separator()
|
|
|
|
# --- Step 1: Check current LoRA state ---
|
|
print("\n[1] Checking current LoRA state: GET /v1/lora/latest")
|
|
lora_state, status = api_get("/v1/lora/latest")
|
|
print(f" Status: {status}")
|
|
print(f" Response: {json.dumps(lora_state, indent=2)[:500]}")
|
|
|
|
current_epoch = 0
|
|
if status == 200 and "epoch" in lora_state:
|
|
current_epoch = lora_state["epoch"]
|
|
print(f" Current epoch: {current_epoch}")
|
|
else:
|
|
print(" No existing LoRA state found or error, starting from epoch 0")
|
|
|
|
# --- Step 2: Check initial drift ---
|
|
print("\n[2] Checking initial drift: GET /v1/drift")
|
|
drift_initial, drift_status = api_get("/v1/drift")
|
|
print(f" Status: {drift_status}")
|
|
print(f" Response: {json.dumps(drift_initial, indent=2)[:500]}")
|
|
|
|
# --- Step 3: Submit 3 LoRA submissions from 3 different contributors ---
|
|
# Server requires min_submissions=3 before aggregation triggers a new epoch
|
|
print("\n[3] Submitting 3 LoRA deltas from 3 contributors...")
|
|
print(" (Server aggregates after 3 submissions via Byzantine-tolerant federation)")
|
|
errors = []
|
|
|
|
for i in range(3):
|
|
contributor_key = CONTRIBUTOR_KEYS[i]
|
|
contributor_label = ["alpha", "beta", "gamma"][i]
|
|
|
|
print(f"\n --- Submission {i+1}/3 (contributor: {contributor_label}) ---")
|
|
|
|
# Generate random LoRA delta weights for down_proj and up_proj
|
|
down_proj = generate_proj_weights()
|
|
up_proj = generate_proj_weights()
|
|
|
|
print(f" down_proj stats: {weight_stats(down_proj)}")
|
|
print(f" up_proj stats: {weight_stats(up_proj)}")
|
|
|
|
# Build LoraSubmission matching the server's expected schema
|
|
payload = {
|
|
"down_proj": down_proj,
|
|
"up_proj": up_proj,
|
|
"rank": RANK,
|
|
"hidden_dim": HIDDEN_DIM,
|
|
"evidence_count": 10 + i * 5, # >= 5 required
|
|
}
|
|
|
|
print(f" POST /v1/lora/submit (rank={RANK}, hidden_dim={HIDDEN_DIM}, evidence={payload['evidence_count']})")
|
|
result, submit_status = api_post("/v1/lora/submit", payload, key=contributor_key)
|
|
print(f" Submit status: {submit_status}")
|
|
print(f" Submit response: {json.dumps(result, indent=2)[:400]}")
|
|
|
|
if submit_status not in (200, 201):
|
|
errors.append(f"Submission {i+1}: failed with status {submit_status}: {result}")
|
|
|
|
# Brief pause between submissions
|
|
time.sleep(0.5)
|
|
|
|
# Check drift after this submission
|
|
print(f"\n Checking drift after submission {i+1}: GET /v1/drift")
|
|
drift_data, drift_s = api_get("/v1/drift")
|
|
print(f" Drift status: {drift_s}")
|
|
print(f" Drift response: {json.dumps(drift_data, indent=2)[:400]}")
|
|
|
|
# --- Step 4: Check final LoRA state (should show new epoch after 3 submissions) ---
|
|
print("\n" + "=" * 60)
|
|
print("[4] Final LoRA state: GET /v1/lora/latest")
|
|
final_lora, final_status = api_get("/v1/lora/latest")
|
|
print(f" Status: {final_status}")
|
|
# Print truncated weights info
|
|
if final_status == 200 and final_lora.get("weights"):
|
|
w = final_lora["weights"]
|
|
summary = {
|
|
"epoch": final_lora.get("epoch"),
|
|
"rank": w.get("rank"),
|
|
"hidden_dim": w.get("hidden_dim"),
|
|
"contributor_count": w.get("contributor_count"),
|
|
"total_evidence": w.get("total_evidence"),
|
|
"down_proj_len": len(w.get("down_proj", [])),
|
|
"up_proj_len": len(w.get("up_proj", [])),
|
|
}
|
|
if w.get("down_proj"):
|
|
summary["down_proj_sample"] = [round(x, 6) for x in w["down_proj"][:5]]
|
|
if w.get("up_proj"):
|
|
summary["up_proj_sample"] = [round(x, 6) for x in w["up_proj"][:5]]
|
|
print(f" Consensus summary: {json.dumps(summary, indent=2)}")
|
|
else:
|
|
print(f" Response: {json.dumps(final_lora, indent=2)[:600]}")
|
|
|
|
# --- Step 5: Final drift report ---
|
|
print("\n[5] Final drift report: GET /v1/drift")
|
|
final_drift, final_drift_status = api_get("/v1/drift")
|
|
print(f" Status: {final_drift_status}")
|
|
print(f" Response: {json.dumps(final_drift, indent=2)[:600]}")
|
|
|
|
# --- Step 6: Check status for LoRA info ---
|
|
print("\n[6] Server status: GET /v1/status")
|
|
srv_status, srv_code = api_get("/v1/status")
|
|
if srv_code == 200:
|
|
lora_info = {
|
|
"lora_epoch": srv_status.get("lora_epoch"),
|
|
"lora_pending_submissions": srv_status.get("lora_pending_submissions"),
|
|
"total_memories": srv_status.get("total_memories"),
|
|
"total_contributors": srv_status.get("total_contributors"),
|
|
}
|
|
print(f" {json.dumps(lora_info, indent=2)}")
|
|
else:
|
|
print(f" Status: {srv_code} - {json.dumps(srv_status)[:200]}")
|
|
|
|
# --- Summary ---
|
|
print("\n" + "=" * 60)
|
|
print("SUMMARY")
|
|
print("=" * 60)
|
|
|
|
final_epoch = final_lora.get("epoch", "?")
|
|
print(f" Starting epoch: {current_epoch}")
|
|
print(f" Final epoch: {final_epoch}")
|
|
|
|
if final_lora.get("weights"):
|
|
w = final_lora["weights"]
|
|
print(f" Contributors: {w.get('contributor_count', '?')}")
|
|
print(f" Total evidence: {w.get('total_evidence', '?')}")
|
|
|
|
if final_drift_status == 200:
|
|
print(f" Drift detected: {final_drift.get('is_drifting', 'unknown')}")
|
|
print(f" Drift trend: {final_drift.get('trend', 'unknown')}")
|
|
print(f" Drift CoV: {final_drift.get('coefficient_of_variation', 'unknown')}")
|
|
print(f" Delta sparsity: {final_drift.get('delta_sparsity', 'unknown')}")
|
|
print(f" Window size: {final_drift.get('window_size', 'unknown')}")
|
|
print(f" Suggested: {final_drift.get('suggested_action', 'unknown')}")
|
|
|
|
if errors:
|
|
print(f"\n ERRORS ({len(errors)}):")
|
|
for err in errors:
|
|
print(f" - {err[:200]}")
|
|
else:
|
|
print("\n Errors: None")
|
|
|
|
print("=" * 60)
|
|
return 0 if not errors else 1
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|