mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-25 06:27:33 +00:00
Add human-in-the-loop survey helper (GUI + CLI + Agent tool) (#181)
* Add human-in-the-loop survey helper Co-authored-by: nic <nicsins@users.noreply.github.com> * Make survey helper launcher robust and add CLI fallback Co-authored-by: nic <nicsins@users.noreply.github.com> * Detect missing display for tkinter GUI Co-authored-by: nic <nicsins@users.noreply.github.com> * Add prediction dataset + review workflow for uncertain survey answers Co-authored-by: nic <nicsins@users.noreply.github.com> --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: nic <nicsins@users.noreply.github.com>
This commit is contained in:
parent
37e696b104
commit
b7efe4992a
9 changed files with 526 additions and 49 deletions
|
|
@ -76,6 +76,32 @@ def _infer_label(soup: BeautifulSoup, field_tag: Tag) -> str | None:
|
|||
field_tag.get("name"),
|
||||
)
|
||||
|
||||
def _infer_group_label(soup: BeautifulSoup, first_input: Tag) -> str | None:
|
||||
# Prefer <fieldset><legend>Question</legend>...</fieldset>
|
||||
fs = first_input.find_parent("fieldset")
|
||||
if isinstance(fs, Tag):
|
||||
legend = fs.find("legend")
|
||||
if isinstance(legend, Tag):
|
||||
txt = _text(legend)
|
||||
if txt:
|
||||
return txt
|
||||
|
||||
# Try previous meaningful text near the input (common in survey builders)
|
||||
probe: Tag | None = first_input
|
||||
for _ in range(6):
|
||||
if not probe:
|
||||
break
|
||||
prev = probe.find_previous(
|
||||
["h1", "h2", "h3", "h4", "h5", "h6", "p", "div", "span", "label"]
|
||||
)
|
||||
if isinstance(prev, Tag):
|
||||
txt = _text(prev)
|
||||
# Avoid using option labels (very short) as question label
|
||||
if txt and len(txt) >= 4:
|
||||
return txt
|
||||
probe = probe.parent if isinstance(probe.parent, Tag) else None
|
||||
|
||||
return _first_non_empty(first_input.get("name"), first_input.get("aria-label"))
|
||||
|
||||
def _iter_controls(soup: BeautifulSoup) -> Iterable[Tag]:
|
||||
for tag in soup.find_all(["input", "textarea", "select"]):
|
||||
|
|
@ -144,13 +170,14 @@ def extract_form_fields(html: str, *, max_fields: int = 200) -> list[ExtractedFi
|
|||
"value": opt.get("value"),
|
||||
}
|
||||
)
|
||||
group_label = _infer_group_label(soup, c)
|
||||
out.append(
|
||||
ExtractedField(
|
||||
kind="input",
|
||||
input_type=input_type,
|
||||
name=name,
|
||||
id=cid,
|
||||
label=_infer_label(soup, c),
|
||||
label=group_label or _infer_label(soup, c),
|
||||
required=required,
|
||||
options=options,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -64,3 +64,69 @@ def suggest_answers_with_ollama(
|
|||
msg = resp.get("message", {}) if isinstance(resp, dict) else {}
|
||||
return str(msg.get("content", "")).strip()
|
||||
|
||||
def predict_answers_json_with_ollama(
|
||||
*,
|
||||
model: str,
|
||||
url: str,
|
||||
title: str,
|
||||
fields_json: str,
|
||||
profile_json: str,
|
||||
top_k: int = 3,
|
||||
base_url: str = "http://localhost:11434",
|
||||
) -> dict:
|
||||
"""
|
||||
Return structured predictions for each field as JSON.
|
||||
|
||||
Contract:
|
||||
- Always output candidates with confidences in [0,1]
|
||||
- If profile lacks required info, still provide best-guess candidates but set needs_clarification=true
|
||||
- Never claim facts not supported by the profile; label assumptions in rationale
|
||||
"""
|
||||
system = (
|
||||
"You are a survey helper that suggests answers for the user.\n"
|
||||
"Hard rules:\n"
|
||||
"- Do NOT invent personal facts.\n"
|
||||
"- If info is missing, make an educated guess BUT mark it as an assumption and set needs_clarification=true.\n"
|
||||
"- Prefer neutral/privacy-preserving options when uncertain.\n"
|
||||
"- Output MUST be valid JSON only (no markdown).\n"
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"{system}\n"
|
||||
f"PAGE_URL: {url}\n"
|
||||
f"PAGE_TITLE: {title}\n"
|
||||
f"TOP_K: {top_k}\n\n"
|
||||
f"USER_PROFILE_JSON:\n{profile_json}\n\n"
|
||||
f"FIELDS_JSON (array of fields; each has kind/input_type/label/options/etc):\n{fields_json}\n\n"
|
||||
"Return JSON with this shape:\n"
|
||||
"{\n"
|
||||
' "predictions": [\n'
|
||||
" {\n"
|
||||
' "field_index": 1,\n'
|
||||
' "selected": "string",\n'
|
||||
' "confidence": 0.0,\n'
|
||||
' "candidates": [{"value":"string","confidence":0.0}],\n'
|
||||
' "needs_clarification": true,\n'
|
||||
' "rationale": "short explanation, mention when assumption"\n'
|
||||
" }\n"
|
||||
" ]\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
"stream": False,
|
||||
}
|
||||
resp = _post_json(f"{base_url}/api/chat", payload=payload)
|
||||
msg = resp.get("message", {}) if isinstance(resp, dict) else {}
|
||||
content = str(msg.get("content", "")).strip()
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
pass
|
||||
return {"error": "Model did not return valid JSON", "raw": content}
|
||||
|
|
|
|||
129
python/survey_assistant/predictions.py
Normal file
129
python/survey_assistant/predictions.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
DEFAULT_PREDICTIONS_PATH = Path("memory") / "survey_predictions.jsonl"
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _stable_hash(data: Any) -> str:
|
||||
raw = json.dumps(data, sort_keys=True, ensure_ascii=False, separators=(",", ":")).encode(
|
||||
"utf-8"
|
||||
)
|
||||
return hashlib.sha256(raw).hexdigest()[:16]
|
||||
|
||||
|
||||
def build_question_id(*, url: str, field: dict[str, Any]) -> str:
|
||||
key = {
|
||||
"url": url,
|
||||
"kind": field.get("kind"),
|
||||
"input_type": field.get("input_type"),
|
||||
"name": field.get("name"),
|
||||
"id": field.get("id"),
|
||||
"label": field.get("label"),
|
||||
"options": field.get("options") or [],
|
||||
}
|
||||
return f"q_{_stable_hash(key)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Candidate:
|
||||
value: str
|
||||
confidence: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionRecord:
|
||||
"""
|
||||
A single predicted answer for a question/field, stored for later review/clarification.
|
||||
"""
|
||||
|
||||
id: str
|
||||
timestamp: str
|
||||
url: str
|
||||
title: str
|
||||
field_index: int
|
||||
field: dict[str, Any]
|
||||
selected: str
|
||||
confidence: float
|
||||
candidates: list[Candidate]
|
||||
rationale: str
|
||||
needs_clarification: bool
|
||||
source: str # llm|heuristic|profile
|
||||
|
||||
def to_jsonl(self) -> str:
|
||||
data = asdict(self)
|
||||
data["candidates"] = [asdict(c) for c in self.candidates]
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_predictions(
|
||||
records: Iterable[PredictionRecord],
|
||||
*,
|
||||
path: str | Path = DEFAULT_PREDICTIONS_PATH,
|
||||
) -> Path:
|
||||
p = Path(path)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
with p.open("a", encoding="utf-8") as f:
|
||||
for r in records:
|
||||
f.write(r.to_jsonl() + "\n")
|
||||
return p
|
||||
|
||||
|
||||
def load_predictions(path: str | Path = DEFAULT_PREDICTIONS_PATH) -> list[dict[str, Any]]:
|
||||
p = Path(path)
|
||||
if not p.exists():
|
||||
return []
|
||||
out: list[dict[str, Any]] = []
|
||||
for line in p.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
if isinstance(obj, dict):
|
||||
out.append(obj)
|
||||
except Exception:
|
||||
continue
|
||||
return out
|
||||
|
||||
|
||||
def pending_predictions(
|
||||
path: str | Path = DEFAULT_PREDICTIONS_PATH,
|
||||
) -> list[dict[str, Any]]:
|
||||
return [r for r in load_predictions(path) if r.get("needs_clarification") is True]
|
||||
|
||||
|
||||
def write_clarifications(
|
||||
clarifications: dict[str, str],
|
||||
*,
|
||||
path: str | Path = Path("memory") / "survey_clarifications.json",
|
||||
) -> Path:
|
||||
p = Path(path)
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
existing: dict[str, str] = {}
|
||||
if p.exists():
|
||||
try:
|
||||
obj = json.loads(p.read_text(encoding="utf-8"))
|
||||
if isinstance(obj, dict):
|
||||
existing = {str(k): str(v) for k, v in obj.items()}
|
||||
except Exception:
|
||||
existing = {}
|
||||
existing.update({str(k): str(v) for k, v in clarifications.items()})
|
||||
p.write_text(json.dumps(existing, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def utc_now_iso() -> str:
|
||||
# exported helper
|
||||
return _utc_now_iso()
|
||||
|
||||
|
|
@ -4,7 +4,15 @@ from python.helpers.tool import Tool, Response
|
|||
from python.survey_assistant.browser_render import render_url_async
|
||||
from python.survey_assistant.extract import extract_form_fields
|
||||
from python.survey_assistant.profile import SurveyProfile
|
||||
from python.survey_assistant.llm import ollama_available, suggest_answers_with_ollama
|
||||
from python.survey_assistant.llm import ollama_available, predict_answers_json_with_ollama
|
||||
from python.survey_assistant.predictions import (
|
||||
DEFAULT_PREDICTIONS_PATH,
|
||||
PredictionRecord,
|
||||
Candidate,
|
||||
append_predictions,
|
||||
build_question_id,
|
||||
utc_now_iso,
|
||||
)
|
||||
|
||||
|
||||
class SurveyHelper(Tool):
|
||||
|
|
@ -20,6 +28,9 @@ class SurveyHelper(Tool):
|
|||
html: str = "",
|
||||
include_suggestions: bool = False,
|
||||
ollama_model: str = "llama3",
|
||||
top_k: int = 3,
|
||||
record_predictions: bool = False,
|
||||
predictions_path: str = str(DEFAULT_PREDICTIONS_PATH),
|
||||
**kwargs,
|
||||
) -> Response:
|
||||
if not url and not html:
|
||||
|
|
@ -46,31 +57,68 @@ class SurveyHelper(Tool):
|
|||
|
||||
if include_suggestions:
|
||||
profile = SurveyProfile.load()
|
||||
questions_lines = []
|
||||
for i, f in enumerate(fields, start=1):
|
||||
label = f.label or f.name or f.id or "(unlabeled)"
|
||||
t = f.input_type or f.kind
|
||||
req = " (required)" if f.required else ""
|
||||
questions_lines.append(f"{i}. {label} — {t}{req}")
|
||||
if f.options:
|
||||
for opt in f.options[:30]:
|
||||
questions_lines.append(f" - {opt.get('label')}")
|
||||
if len(f.options) > 30:
|
||||
questions_lines.append(" - ...")
|
||||
|
||||
if ollama_available():
|
||||
try:
|
||||
payload["suggestions"] = suggest_answers_with_ollama(
|
||||
pred = predict_answers_json_with_ollama(
|
||||
model=ollama_model,
|
||||
questions_text="\n".join(questions_lines),
|
||||
url=final_url,
|
||||
title=page_title,
|
||||
fields_json=json.dumps([f.to_dict() for f in fields], ensure_ascii=False),
|
||||
profile_json=json.dumps(profile.as_dict(), indent=2, ensure_ascii=False),
|
||||
top_k=max(1, min(8, int(top_k or 3))),
|
||||
)
|
||||
payload["predictions"] = pred.get("predictions", [])
|
||||
if pred.get("error"):
|
||||
payload["predictions_error"] = pred.get("error")
|
||||
payload["predictions_raw"] = pred.get("raw")
|
||||
|
||||
if record_predictions and isinstance(payload.get("predictions"), list):
|
||||
records: list[PredictionRecord] = []
|
||||
for item in payload["predictions"]:
|
||||
try:
|
||||
idx = int(item.get("field_index"))
|
||||
except Exception:
|
||||
continue
|
||||
if idx < 1 or idx > len(fields):
|
||||
continue
|
||||
if not bool(item.get("needs_clarification")):
|
||||
continue
|
||||
field_dict = fields[idx - 1].to_dict()
|
||||
qid = build_question_id(url=final_url, field=field_dict)
|
||||
cand_objs: list[Candidate] = []
|
||||
for c in (item.get("candidates") or [])[: max(1, min(10, top_k))]:
|
||||
try:
|
||||
cand_objs.append(
|
||||
Candidate(
|
||||
value=str(c.get("value", "")),
|
||||
confidence=float(c.get("confidence", 0.0)),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
records.append(
|
||||
PredictionRecord(
|
||||
id=qid,
|
||||
timestamp=utc_now_iso(),
|
||||
url=final_url,
|
||||
title=page_title,
|
||||
field_index=idx,
|
||||
field=field_dict,
|
||||
selected=str(item.get("selected", "")),
|
||||
confidence=float(item.get("confidence", 0.0) or 0.0),
|
||||
candidates=cand_objs,
|
||||
rationale=str(item.get("rationale", "")),
|
||||
needs_clarification=True,
|
||||
source="llm",
|
||||
)
|
||||
)
|
||||
if records:
|
||||
p = append_predictions(records, path=predictions_path)
|
||||
payload["recorded_predictions_path"] = str(p)
|
||||
except Exception as exc:
|
||||
payload["suggestions_error"] = str(exc)
|
||||
payload["predictions_error"] = str(exc)
|
||||
else:
|
||||
payload["suggestions_error"] = (
|
||||
"Ollama not available at http://localhost:11434"
|
||||
)
|
||||
payload["predictions_error"] = "Ollama not available at http://localhost:11434"
|
||||
|
||||
return Response(message=json.dumps(payload, indent=2, ensure_ascii=False), break_loop=False)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue