mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-24 22:04:03 +00:00
Agent persona profiles (#7)
* Add survey profile DB, parser, and background refiner Co-authored-by: nic <nicsins@users.noreply.github.com> * Block external survey domains unless allowlisted Co-authored-by: nic <nicsins@users.noreply.github.com> * Add persona creation tool for survey personas Co-authored-by: nic <nicsins@users.noreply.github.com> * Add local survey demo and documentation Co-authored-by: nic <nicsins@users.noreply.github.com> * Add profile update tool and email env fallback Co-authored-by: nic <nicsins@users.noreply.github.com> --------- Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: nic <nicsins@users.noreply.github.com>
This commit is contained in:
parent
6a248a57cc
commit
37e696b104
20 changed files with 1618 additions and 0 deletions
|
|
@ -0,0 +1,8 @@
|
|||
from python.helpers.extension import Extension
|
||||
from python.surveys.profile_refiner import ensure_profile_refiner_running
|
||||
|
||||
|
||||
class StartProfileRefiner(Extension):
    """Extension that makes sure the survey profile refiner is running."""

    async def execute(self, **kwargs):
        # Delegates to the surveys package, passing this extension's agent.
        # The value returned by the call is not used here.
        ensure_profile_refiner_running(self.agent)
|
||||
|
||||
|
|
@ -275,6 +275,10 @@ class Browser:
|
|||
clean_dom = self.strip_html_dom(full_dom)
|
||||
return self.process_html_with_selectors(clean_dom)
|
||||
|
||||
async def get_url(self) -> str:
    """Return the URL of the current page, after ``_check_page()`` has run."""
    await self._check_page()
    current_page = self.page
    return current_page.url
|
||||
|
||||
async def click(self, selector: str):
|
||||
await self._check_page()
|
||||
ctx, selector = self._parse_selector(selector)
|
||||
|
|
@ -310,6 +314,23 @@ class Browser:
|
|||
await ctx.fill(selector, text, force=True, timeout=Browser.interact_timeout)
|
||||
await self.wait_tick()
|
||||
|
||||
async def select(self, selector: str, value_or_label: str):
    """Select an option in a ``<select>`` element by label text or by value.

    Args:
        selector: Selector of the ``<select>`` element, in the same form
            that ``click()`` accepts.
        value_or_label: Tried first as the option's visible label; if that
            selection fails it is retried as the option's ``value``.

    Raises:
        Exception: Whatever ``select_option`` raises when the option cannot
            be selected by value either.
    """
    await self._check_page()
    ctx, parsed_selector = self._parse_selector(selector)
    self.last_selector = parsed_selector

    # Best-effort focus click. click() runs _parse_selector() on its argument
    # itself, so it must receive the caller's original selector, not the
    # already-parsed one.
    # NOTE(review): assumes _parse_selector() is not idempotent for selectors
    # that carry a context prefix -- confirm against its implementation.
    try:
        await self.click(selector)
    except Exception:
        # A failed focus click must not prevent the actual selection.
        pass

    # Prefer selecting by label; fall back to value.
    try:
        await ctx.select_option(parsed_selector, label=value_or_label, timeout=Browser.interact_timeout)  # type: ignore[arg-type]
    except Exception:
        await ctx.select_option(parsed_selector, value=value_or_label, timeout=Browser.interact_timeout)  # type: ignore[arg-type]
    await self.wait_tick()
|
||||
|
||||
async def execute(self, js_code: str):
|
||||
await self._check_page()
|
||||
result = await self.page.evaluate(js_code)
|
||||
|
|
|
|||
7
python/surveys/__init__.py
Normal file
7
python/surveys/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Survey automation and profile/persona refinement utilities.
|
||||
|
||||
This package is intentionally self-contained so it can be used both:
|
||||
- from Agent Zero tools (python/tools/*.py)
|
||||
- from standalone scripts/tests
|
||||
"""
|
||||
|
||||
220
python/surveys/answerer.py
Normal file
220
python/surveys/answerer.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from agent import Agent
|
||||
from python.helpers import dotenv
|
||||
|
||||
from .schemas import AnswerAction, FieldKind, Persona, SurveyField, SurveyPage, UserProfile
|
||||
|
||||
|
||||
_WS = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _norm(s: str | None) -> str:
|
||||
return _WS.sub(" ", (s or "").strip()).lower()
|
||||
|
||||
|
||||
def _profile_get(profile: UserProfile | None, *keys: str) -> str | None:
|
||||
if not profile:
|
||||
return None
|
||||
data: Any = profile.data
|
||||
for k in keys:
|
||||
if not isinstance(data, dict) or k not in data:
|
||||
return None
|
||||
data = data[k]
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, (str, int, float, bool)):
|
||||
return str(data)
|
||||
return None
|
||||
|
||||
|
||||
def _best_option(options: list[str], want: str | None) -> str | None:
    """Pick the option that best matches the wanted value.

    Matching is done on normalized text (see ``_norm``) in two steps:

    1. the first option whose text contains the wanted text;
    2. otherwise the option sharing the most whitespace-separated tokens
       with the wanted text.

    Ties in step 2 -- including the case where no option shares any token --
    are resolved in favour of the earliest option, so the result follows the
    order of *options* rather than the alphabetical order of their labels.

    Args:
        options: Candidate option labels, in page order.
        want: Desired value, or ``None``/empty when nothing is known.

    Returns:
        ``None`` when *options* is empty, ``options[0]`` when *want* is
        empty, otherwise the best matching option.
    """
    if not options:
        return None
    if not want:
        return options[0]

    wanted = _norm(want)
    normalized = [_norm(opt) for opt in options]

    # 1) substring match
    if wanted:
        for opt, norm_opt in zip(options, normalized):
            if wanted in norm_opt:
                return opt

    # 2) fuzzy: shared tokens. max() returns the first maximal item, which
    # keeps the earliest option on ties.
    wanted_tokens = set(wanted.split())
    overlaps = [len(wanted_tokens & set(norm_opt.split())) for norm_opt in normalized]
    best_index = max(range(len(options)), key=overlaps.__getitem__)
    return options[best_index]
|
||||
|
||||
|
||||
# (keywords, (primary profile path, fallback profile path)); the first rule
# with a keyword present in the field text wins, so specific rules
# ("first name") must stay ahead of generic ones ("name").
_PROFILE_VALUE_RULES = (
    (("first name", "firstname", "given name"), (("name", "first"), ("first_name",))),
    (("last name", "lastname", "surname", "family name"), (("name", "last"), ("last_name",))),
    (("name", "fullname"), (("name", "full"), ("full_name",))),
    (("age", "how old"), (("demographics", "age"), ("age",))),
    (("gender", "sex"), (("demographics", "gender"), ("gender",))),
    (("country", "nation", "nationality"), (("demographics", "country"), ("country",))),
    (("city", "town", "hometown"), (("demographics", "city"), ("city",))),
    (("zip", "zipcode", "postal", "postcode"), (("demographics", "postal_code"), ("postal_code",))),
)


def _has_word(hay: str, keywords: tuple[str, ...]) -> bool:
    """Return True if any keyword occurs in *hay* as a whole word or phrase."""
    return any(re.search(rf"\b{re.escape(keyword)}\b", hay) for keyword in keywords)


def _infer_value(field: SurveyField, profile: UserProfile | None) -> str | None:
    """Guess the profile value a field is asking for.

    The field's label, placeholder and name are normalized and joined into
    one text. ``_`` and ``-`` are treated as word separators, so identifiers
    such as ``first_name`` or ``e-mail`` match like their spelled-out forms.
    Keywords must appear as whole words: "age" does not fire on "message" or
    "language", "city" not on "ethnicity", "name" not on "username".

    Args:
        field: The survey field to answer.
        profile: Profile to read values from; may be ``None``.

    Returns:
        The profile value as a string, or ``None`` when the field is not
        recognized or the profile has no value for it. For e-mail fields the
        ``A0_PROFILE_EMAIL`` dotenv value is used as a last resort.
    """
    parts = (_norm(field.label), _norm(field.placeholder), _norm(field.name))
    hay = re.sub(r"[\s_\-]+", " ", " ".join(parts)).strip()

    # "email" cannot realistically occur inside an unrelated word, so a plain
    # substring test keeps matching compound names such as "useremail".
    if "email" in hay or re.search(r"\be mail\b", hay):
        return (
            _profile_get(profile, "contact", "email")
            or _profile_get(profile, "email")
            or str(dotenv.get_dotenv_value("A0_PROFILE_EMAIL", "") or "").strip()
            or None
        )

    for keywords, (primary, fallback) in _PROFILE_VALUE_RULES:
        if _has_word(hay, keywords):
            return _profile_get(profile, *primary) or _profile_get(profile, *fallback)

    return None
|
||||
|
||||
|
||||
# Keyword groups and the score a button gets when its label contains a word
# starting with one of them; checked in this order.
_BUTTON_TIERS = (
    (("submit", "finish", "complete"), 100),
    (("next", "continue", "proceed"), 90),
    (("start", "begin"), 80),
    (("ok", "done"), 70),
    (("back", "previous", "cancel"), -10),
)


def _button_score(field: SurveyField) -> int:
    """Score how likely a button is the one that advances the survey.

    Keywords are matched at the start of a word in the normalized label, so
    inflections still count ("Finished", "Okay") while embedded occurrences
    do not ("ok" in "Cookie settings" or "Book", "back" in "Feedback").

    Returns:
        ``-999`` for a button without label text, the score of the first
        matching tier otherwise, and ``0`` when no tier matches.
    """
    text = _norm(field.label)
    if not text:
        return -999
    for keywords, score in _BUTTON_TIERS:
        if re.search(r"\b(?:" + "|".join(keywords) + ")", text):
            return score
    return 0
|
||||
|
||||
|
||||
def _runnable(actions: list[AnswerAction]) -> list[AnswerAction]:
    """Keep only actions that can be executed: key presses, or actions with a selector."""
    return [a for a in actions if a.action == "press" or a.selector]


def _heuristic_actions(
    page: SurveyPage, profile: UserProfile | None, rng: random.Random
) -> list[AnswerAction]:
    """Build the baseline plan for one page from profile lookups and label heuristics."""
    text_kinds = {FieldKind.TEXT, FieldKind.TEXTAREA, FieldKind.EMAIL, FieldKind.NUMBER, FieldKind.DATE}
    choice_kinds = {FieldKind.RADIO, FieldKind.CHECKBOX}

    actions: list[AnswerAction] = []
    buttons: list[SurveyField] = []

    for f in page.fields:
        if f.kind == FieldKind.BUTTON:
            buttons.append(f)
        elif f.kind in text_kinds:
            v = _infer_value(f, profile)
            if v:
                actions.append(AnswerAction(action="fill", selector=f.selector, text=v, meta={"label": f.label}))
        elif f.kind == FieldKind.SELECT:
            opt = _best_option(f.options, _infer_value(f, profile))
            if opt and f.selector:
                actions.append(AnswerAction(action="select", selector=f.selector, text=opt, meta={"label": f.label}))
        elif f.kind in choice_kinds:
            opt = _best_option(f.options, _infer_value(f, profile))
            if opt and f.option_selectors.get(opt):
                actions.append(AnswerAction(action="click", selector=f.option_selectors[opt], meta={"label": f.label, "option": opt}))
            elif f.options:
                # NOTE: _best_option returns an option whenever options exist,
                # so this branch is only reached when the chosen option has no
                # selector; pick another option with the seeded RNG.
                pick = rng.choice(f.options)
                sel = f.option_selectors.get(pick)
                if sel:
                    actions.append(AnswerAction(action="click", selector=sel, meta={"label": f.label, "option": pick}))

    # navigation: click best button last
    best_btn = None
    best_score = -999
    for b in buttons:
        s = _button_score(b)
        if s > best_score and b.selector:
            best_score = s
            best_btn = b
    if best_btn:
        actions.append(AnswerAction(action="click", selector=best_btn.selector, meta={"button": best_btn.label}))

    return actions


def _parse_refined_actions(raw: str) -> list[AnswerAction]:
    """Turn the model's reply into actions, dropping anything malformed.

    Raises ``json.JSONDecodeError`` when no JSON array can be read from *raw*.
    Returns an empty list when the JSON is not a list or holds no valid action.
    """
    text = raw.strip()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Models often wrap JSON in a Markdown fence or a sentence; retry on
        # the outermost [...] span before giving up.
        start, end = text.find("["), text.rfind("]")
        if start == -1 or end <= start:
            raise
        data = json.loads(text[start : end + 1])

    if not isinstance(data, list):
        return []

    out: list[AnswerAction] = []
    for item in data:
        if not isinstance(item, dict):
            continue
        act = item.get("action")
        if act not in {"fill", "select", "click", "press"}:
            continue
        selector = item.get("selector")
        if selector is not None and not isinstance(selector, str):
            continue
        text_value = item.get("text")
        if text_value is not None and not isinstance(text_value, str):
            # e.g. a numeric age returned as a JSON number
            text_value = str(text_value)
        meta = item.get("meta")
        out.append(
            AnswerAction(
                action=act,
                selector=selector,
                text=text_value,
                key=item.get("key"),
                meta=meta if isinstance(meta, dict) else {},
            )
        )
    return out


async def answer_page(
    agent: Agent,
    page: SurveyPage,
    profile: UserProfile | None,
    persona: Persona | None,
    *,
    use_llm: bool = True,
    seed: int | None = None,
) -> list[AnswerAction]:
    """Return browser actions to answer one survey page.

    LLM is used as a refinement layer; heuristics provide the baseline plan.

    Args:
        agent: Agent whose utility model is asked to refine the plan.
        page: Parsed survey page.
        profile: Profile to take answers from, if any.
        persona: Persona whose name, description and constraints are added
            to the prompt, if any.
        use_llm: When False, only the heuristic plan is returned.
        seed: Seed for the random generator used by the heuristics.

    Returns:
        The refined plan when the model returns at least one valid action,
        otherwise the heuristic plan. A failed call, unparsable reply,
        non-list reply or empty result never discards the heuristic plan.
        Actions without a selector are dropped unless they are key presses.
    """
    rng = random.Random(seed)
    actions = _heuristic_actions(page, profile, rng)
    baseline = _runnable(actions)

    if not use_llm:
        return baseline

    # LLM refinement (optional): allow the utility model to adjust actions if it can.
    # This is intentionally constrained to a JSON list of actions.
    persona_txt = ""
    if persona:
        persona_txt = f"\nPersona name: {persona.name}\nPersona description: {persona.description}\nPersona constraints (JSON): {json.dumps(persona.constraints, ensure_ascii=False)}\n"

    profile_txt = json.dumps(profile.data, ensure_ascii=False) if profile else "{}"
    system = (
        "You are a form-filling planner for online surveys. "
        "Return a JSON array of actions. "
        "Each action must be one of: fill, select, click, press. "
        "Use only selectors that exist in the provided DOM or option_selectors. "
        "Do not invent personal data; prefer values from profile_json. "
        "If unsure, keep the heuristic plan."
    )
    message = (
        f"URL: {page.url}\n"
        f"DOM:\n{page.raw_dom}\n\n"
        f"profile_json: {profile_txt}\n"
        f"{persona_txt}\n"
        f"heuristic_actions_json: {json.dumps([a.__dict__ for a in actions], ensure_ascii=False)}\n\n"
        "Return refined_actions_json only."
    )
    try:
        reply = await agent.call_utility_model(system=system, message=message, background=True)
        refined = _runnable(_parse_refined_actions(reply))
    except Exception:
        # Refinement is best-effort: any failure falls back to the heuristics.
        refined = []

    return refined or baseline
|
||||
|
||||
297
python/surveys/db.py
Normal file
297
python/surveys/db.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Iterable
|
||||
|
||||
from python.helpers import memory as memory_helper
|
||||
from agent import Agent
|
||||
|
||||
from .schemas import Persona, UserProfile, SurveyField
|
||||
|
||||
|
||||
def _now_ts() -> int:
    """Return the current Unix time in whole seconds."""
    return int(time.time())


class SurveyDB:
    """Structured persistence for personas, profiles, and survey answers.

    Wraps one SQLite connection. Every write runs inside ``with self._conn``,
    so it is committed on success and rolled back if the statement fails;
    a failed write therefore never leaves a transaction open.
    """

    def __init__(self, path: str):
        """Open (creating if needed) the database at *path* and ensure the schema."""
        self.path = path
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        self._conn = sqlite3.connect(self.path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL;")
        self._conn.execute("PRAGMA foreign_keys=ON;")
        self._init_schema()

    @staticmethod
    def for_agent(agent: Agent) -> "SurveyDB":
        """Open ``survey_profiles.db`` inside the agent's memory directory."""
        base = memory_helper.get_memory_subdir_abs(agent)
        path = os.path.join(base, "survey_profiles.db")
        return SurveyDB(path)

    def close(self) -> None:
        """Close the connection, ignoring SQLite errors raised while closing."""
        try:
            self._conn.close()
        except sqlite3.Error:
            pass

    def _init_schema(self) -> None:
        """Create the tables and index if they do not exist yet."""
        cur = self._conn.cursor()
        cur.executescript(
            """
            CREATE TABLE IF NOT EXISTS personas (
                id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                description TEXT NOT NULL,
                constraints_json TEXT NOT NULL,
                created_at INTEGER NOT NULL
            );

            CREATE TABLE IF NOT EXISTS profiles (
                id TEXT PRIMARY KEY,
                persona_id TEXT NULL REFERENCES personas(id) ON DELETE SET NULL,
                data_json TEXT NOT NULL,
                updated_at INTEGER NOT NULL
            );

            CREATE TABLE IF NOT EXISTS survey_sessions (
                id TEXT PRIMARY KEY,
                url TEXT NOT NULL,
                persona_id TEXT NULL REFERENCES personas(id) ON DELETE SET NULL,
                profile_id TEXT NULL REFERENCES profiles(id) ON DELETE SET NULL,
                started_at INTEGER NOT NULL,
                completed_at INTEGER NULL,
                status TEXT NOT NULL,
                notes TEXT NULL
            );

            CREATE TABLE IF NOT EXISTS survey_answers (
                id TEXT PRIMARY KEY,
                session_id TEXT NOT NULL REFERENCES survey_sessions(id) ON DELETE CASCADE,
                question_text TEXT NULL,
                field_kind TEXT NOT NULL,
                selector TEXT NULL,
                answer_text TEXT NOT NULL,
                field_json TEXT NOT NULL,
                raw_json TEXT NULL,
                processed INTEGER NOT NULL DEFAULT 0,
                created_at INTEGER NOT NULL
            );

            CREATE INDEX IF NOT EXISTS idx_survey_answers_processed
            ON survey_answers(processed, created_at);
            """
        )
        self._conn.commit()

    # ---- Persona/profile API -------------------------------------------------

    def upsert_persona(self, persona: Persona) -> None:
        """Insert the persona, or update name/description/constraints if its id exists."""
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO personas(id, name, description, constraints_json, created_at)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    name=excluded.name,
                    description=excluded.description,
                    constraints_json=excluded.constraints_json;
                """,
                (
                    persona.id,
                    persona.name,
                    persona.description,
                    json.dumps(persona.constraints or {}, ensure_ascii=False),
                    _now_ts(),
                ),
            )

    def get_persona(self, persona_id: str) -> Persona | None:
        """Return the persona with this id, or ``None`` if there is none."""
        row = self._conn.execute(
            "SELECT id, name, description, constraints_json FROM personas WHERE id=?",
            (persona_id,),
        ).fetchone()
        if not row:
            return None
        return Persona(
            id=row[0],
            name=row[1],
            description=row[2],
            constraints=json.loads(row[3] or "{}"),
        )

    def upsert_profile(self, profile: UserProfile) -> None:
        """Insert the profile, or replace its persona, data and timestamp if its id exists."""
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO profiles(id, persona_id, data_json, updated_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(id) DO UPDATE SET
                    persona_id=excluded.persona_id,
                    data_json=excluded.data_json,
                    updated_at=excluded.updated_at;
                """,
                (
                    profile.id,
                    profile.persona_id,
                    json.dumps(profile.data or {}, ensure_ascii=False),
                    _now_ts(),
                ),
            )

    def get_profile(self, profile_id: str) -> UserProfile | None:
        """Return the profile with this id, or ``None`` if there is none."""
        row = self._conn.execute(
            "SELECT id, persona_id, data_json FROM profiles WHERE id=?", (profile_id,)
        ).fetchone()
        if not row:
            return None
        return UserProfile(
            id=row[0],
            persona_id=row[1],
            data=json.loads(row[2] or "{}"),
        )

    # ---- Survey session/answers API -----------------------------------------

    def create_session(
        self,
        session_id: str,
        url: str,
        persona_id: str | None,
        profile_id: str | None,
        status: str = "running",
        notes: str | None = None,
    ) -> None:
        """Record the start of a survey session; ``started_at`` is set to now."""
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO survey_sessions(id, url, persona_id, profile_id, started_at, status, notes)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (session_id, url, persona_id, profile_id, _now_ts(), status, notes),
            )

    def complete_session(self, session_id: str, status: str = "completed") -> None:
        """Set the session's status and stamp ``completed_at`` with the current time."""
        with self._conn:
            self._conn.execute(
                """
                UPDATE survey_sessions
                SET completed_at=?, status=?
                WHERE id=?
                """,
                (_now_ts(), status, session_id),
            )

    def insert_answer(
        self,
        answer_id: str,
        session_id: str,
        question_text: str | None,
        field: SurveyField,
        answer_text: str,
        raw: dict[str, Any] | None = None,
    ) -> None:
        """Store one answer as unprocessed.

        The field is stored as JSON (``dataclasses.asdict``); *raw* is stored
        as JSON only when it is non-empty.

        Raises:
            sqlite3.IntegrityError: e.g. when *session_id* does not exist;
                the failed insert is rolled back.
        """
        with self._conn:
            self._conn.execute(
                """
                INSERT INTO survey_answers(
                    id, session_id, question_text, field_kind, selector, answer_text,
                    field_json, raw_json, processed, created_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, 0, ?)
                """,
                (
                    answer_id,
                    session_id,
                    question_text,
                    field.kind.value,
                    field.selector,
                    answer_text,
                    json.dumps(asdict(field), ensure_ascii=False),
                    json.dumps(raw, ensure_ascii=False) if raw else None,
                    _now_ts(),
                ),
            )

    @staticmethod
    def _answer_row_to_dict(row: tuple) -> dict[str, Any]:
        """Decode the nine leading answer columns shared by both fetch queries."""
        return {
            "id": row[0],
            "session_id": row[1],
            "question_text": row[2],
            "field_kind": row[3],
            "selector": row[4],
            "answer_text": row[5],
            "field": json.loads(row[6] or "{}"),
            "raw": json.loads(row[7]) if row[7] else None,
            "created_at": row[8],
        }

    def fetch_unprocessed_answers(
        self, limit: int = 200
    ) -> list[dict[str, Any]]:
        """Return up to *limit* unprocessed answers, oldest first.

        ``created_at`` has one-second resolution, so ``rowid`` is the
        tie-breaker that keeps answers in insertion order.
        """
        rows = self._conn.execute(
            """
            SELECT id, session_id, question_text, field_kind, selector, answer_text, field_json, raw_json, created_at
            FROM survey_answers
            WHERE processed=0
            ORDER BY created_at ASC, rowid ASC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
        return [self._answer_row_to_dict(row) for row in rows]

    def fetch_unprocessed_answer_events(self, limit: int = 200) -> list[dict[str, Any]]:
        """Fetch unprocessed answers with session context (url/profile/persona)."""
        rows = self._conn.execute(
            """
            SELECT
                a.id, a.session_id, a.question_text, a.field_kind, a.selector, a.answer_text,
                a.field_json, a.raw_json, a.created_at,
                s.url, s.persona_id, s.profile_id
            FROM survey_answers a
            JOIN survey_sessions s ON s.id = a.session_id
            WHERE a.processed=0
            ORDER BY a.created_at ASC, a.rowid ASC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
        results: list[dict[str, Any]] = []
        for row in rows:
            event = self._answer_row_to_dict(row)
            event["url"] = row[9]
            event["persona_id"] = row[10]
            event["profile_id"] = row[11]
            results.append(event)
        return results

    def mark_answers_processed(self, answer_ids: Iterable[str]) -> None:
        """Flag the given answers as processed; an empty iterable is a no-op."""
        ids = list(answer_ids)
        if not ids:
            return
        with self._conn:
            self._conn.executemany(
                "UPDATE survey_answers SET processed=1 WHERE id=?",
                [(i,) for i in ids],
            )
|
||||
|
||||
204
python/surveys/parser.py
Normal file
204
python/surveys/parser.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Iterable
|
||||
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
|
||||
from .schemas import FieldKind, SurveyField, SurveyPage
|
||||
|
||||
|
||||
_SPACE_RE = re.compile(r"\s+")
|
||||
|
||||
|
||||
def _norm(s: str | None) -> str:
|
||||
return _SPACE_RE.sub(" ", (s or "").strip())
|
||||
|
||||
|
||||
def _el_text_near(el: Tag, max_len: int = 160) -> str:
    """Best-effort label extraction for cleaned DOM (no <label for=>, many tags unwrapped).

    Sources are tried in this order and the first non-empty one wins,
    truncated to *max_len* characters:

    1. the element's ``label``, ``aria-label``, ``placeholder`` or ``name``
       attribute;
    2. the text of up to three siblings directly preceding the element in
       its parent (``script``/``style`` tags skipped);
    3. the full text of the parent or one of the next two ancestors.

    Returns ``""`` when nothing is found.
    """
    # 1) explicit label/placeholder attributes
    for attr_name in ("label", "aria-label", "placeholder", "name"):
        attr_text = _norm(el.get(attr_name))  # type: ignore[arg-type]
        if attr_text:
            return attr_text[:max_len]

    parent = el.parent if isinstance(el.parent, Tag) else None

    # 2) text that precedes the element inside its parent
    if parent:
        preceding: list[str] = []
        for sibling in parent.children:
            if sibling is el:
                break
            if isinstance(sibling, str):
                piece = _norm(sibling)
            elif isinstance(sibling, Tag) and sibling.name not in {"script", "style"}:
                piece = _norm(sibling.get_text(" ", strip=True))
            else:
                piece = ""
            if piece:
                preceding.append(piece)
        if preceding:
            joined = _norm(" ".join(preceding[-3:]))
            if joined:
                return joined[:max_len]

    # 3) fallback: nearest ancestor text, at most three levels up
    ancestor, hops = parent, 0
    while ancestor and hops < 3:
        ancestor_text = _norm(ancestor.get_text(" ", strip=True))
        if ancestor_text:
            return ancestor_text[:max_len]
        ancestor = ancestor.parent if isinstance(ancestor.parent, Tag) else None
        hops += 1

    return ""
|
||||
|
||||
|
||||
def _field_kind(el: Tag) -> FieldKind:
    """Classify an element by tag name and, for ``<input>``, by its ``type``.

    An ``<input>`` without a type counts as text; ``tel`` is grouped with
    number; ``submit`` and ``button`` inputs are buttons. Anything not
    listed is ``FieldKind.UNKNOWN``.
    """
    tag_name = (el.name or "").lower()

    kind_by_tag = {
        "textarea": FieldKind.TEXTAREA,
        "select": FieldKind.SELECT,
        "button": FieldKind.BUTTON,
    }
    if tag_name in kind_by_tag:
        return kind_by_tag[tag_name]
    if tag_name != "input":
        return FieldKind.UNKNOWN

    kind_by_input_type = {
        "": FieldKind.TEXT,
        "text": FieldKind.TEXT,
        "email": FieldKind.EMAIL,
        "number": FieldKind.NUMBER,
        "tel": FieldKind.NUMBER,
        "date": FieldKind.DATE,
        "radio": FieldKind.RADIO,
        "checkbox": FieldKind.CHECKBOX,
        "submit": FieldKind.BUTTON,
        "button": FieldKind.BUTTON,
    }
    input_type = (el.get("type") or "").lower()
    return kind_by_input_type.get(input_type, FieldKind.UNKNOWN)
|
||||
|
||||
|
||||
def _iter_interactive(soup: BeautifulSoup) -> Iterable[Tag]:
    """Yield input/textarea/select/button tags that carry a ``selector`` attribute.

    Tags whose ``selector`` attribute is missing or blank are skipped.
    """
    candidates = soup.find_all(["input", "textarea", "select", "button"])
    for node in candidates:
        # the "selector" attribute comes from the browser helper
        if isinstance(node, Tag) and _norm(node.get("selector")):
            yield node
|
||||
|
||||
|
||||
def parse_survey_page(clean_dom: str, url: str = "") -> SurveyPage:
    """Parse cleaned DOM text into a ``SurveyPage``.

    The title is the text of the first ``<h1>``. Every interactive element
    with a ``selector`` attribute becomes a ``SurveyField``; radio and
    checkbox inputs that have a ``name`` are merged into one field per name,
    with an option-label -> selector mapping. Fields are returned sorted,
    buttons last.
    """
    dom_text = clean_dom or ""
    soup = BeautifulSoup(dom_text, "html.parser")

    heading = soup.find("h1")
    title = None
    if isinstance(heading, Tag):
        title = _norm(heading.get_text(" ", strip=True)) or None

    choice_kinds = {FieldKind.RADIO, FieldKind.CHECKBOX}

    def label_for(el: Tag, kind: FieldKind) -> str | None:
        label = _norm(el.get("label")) or None
        if kind == FieldKind.BUTTON:
            label = label or _norm(el.get("value")) or _norm(el.get_text(" ", strip=True)) or None
        if not label and kind in choice_kinds:
            container = el.parent if isinstance(el.parent, Tag) else None
            if container:
                # For choice inputs, the immediate container text is often the option label.
                label = _norm(container.get_text(" ", strip=True)) or None
        return label or _el_text_near(el) or None

    def select_options(el: Tag) -> list[str]:
        texts = (
            _norm(opt.get_text(" ", strip=True))
            for opt in el.find_all("option")
            if isinstance(opt, Tag)
        )
        return [text for text in texts if text]

    def collapse(groups: dict[str, list[SurveyField]], kind: FieldKind) -> list[SurveyField]:
        merged: list[SurveyField] = []
        for group_name, members in groups.items():
            # Option labels come from each member's label; fallback to its selector.
            labels: list[str] = []
            selectors_by_label: dict[str, str] = {}
            group_label = None
            for member in members:
                option_label = _norm(member.label) or member.selector
                if option_label and option_label not in selectors_by_label:
                    labels.append(option_label)
                    selectors_by_label[option_label] = member.selector
                group_label = group_label or member.label

            if not labels:
                # keep members as-is if we failed to collapse
                merged.extend(members)
                continue

            merged.append(
                SurveyField(
                    selector="",  # group field has no single selector; choose via option_selectors
                    kind=kind,
                    name=group_name,
                    label=group_label,
                    placeholder=None,
                    options=labels,
                    option_selectors=selectors_by_label,
                    required=False,
                )
            )
        return merged

    # Collect one field per element, routing named radio/checkbox inputs
    # into per-name groups as they are found.
    singles: list[SurveyField] = []
    radio_groups: dict[str, list[SurveyField]] = defaultdict(list)
    checkbox_groups: dict[str, list[SurveyField]] = defaultdict(list)

    for el in _iter_interactive(soup):
        kind = _field_kind(el)
        entry = SurveyField(
            selector=_norm(el.get("selector")),  # type: ignore[arg-type]
            kind=kind,
            name=_norm(el.get("name")) or None,
            label=label_for(el, kind),
            placeholder=_norm(el.get("placeholder")) or None,
            # selector is the <select> itself; options are chosen via fill/select later
            options=select_options(el) if kind == FieldKind.SELECT else [],
            option_selectors={},
            required=False,
        )
        if kind == FieldKind.RADIO and entry.name:
            radio_groups[entry.name].append(entry)
        elif kind == FieldKind.CHECKBOX and entry.name:
            checkbox_groups[entry.name].append(entry)
        else:
            singles.append(entry)

    fields = singles + collapse(radio_groups, FieldKind.RADIO) + collapse(checkbox_groups, FieldKind.CHECKBOX)

    def sort_key(f: SurveyField):
        # Stable ordering: keep buttons last (helps answerer focus on fields first).
        return (f.kind == FieldKind.BUTTON, f.kind.value, f.label or "", f.name or "", f.selector)

    fields.sort(key=sort_key)

    return SurveyPage(url=url, title=title, fields=fields, raw_dom=dom_text)
|
||||
|
||||
127
python/surveys/profile_refiner.py
Normal file
127
python/surveys/profile_refiner.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from agent import Agent
|
||||
from python.helpers.defer import DeferredTask
|
||||
|
||||
from .db import SurveyDB
|
||||
from .schemas import UserProfile
|
||||
|
||||
|
||||
def _deep_merge(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
|
||||
for k, v in src.items():
|
||||
if isinstance(v, dict) and isinstance(dst.get(k), dict):
|
||||
dst[k] = _deep_merge(dst[k], v) # type: ignore[arg-type]
|
||||
else:
|
||||
dst[k] = v
|
||||
return dst
|
||||
|
||||
|
||||
class ProfileRefinerService:
    """Background worker that turns survey answers into structured profile updates.

    The worker polls the survey database for unprocessed answers, groups
    them by profile, asks the agent's utility model for a profile patch and
    deep-merges that patch into the stored profile.
    """

    DATA_KEY = "_survey_profile_refiner"

    # Seconds to sleep between polls of the database.
    POLL_INTERVAL_SECONDS = 10
    # Maximum number of unprocessed answers fetched per poll.
    FETCH_LIMIT = 200
    # Maximum number of answers of one profile sent to the model per poll.
    MAX_EVENTS_PER_CALL = 60

    def __init__(self, agent: Agent):
        self.agent = agent
        self.task: DeferredTask | None = None

    def start(self) -> None:
        """Start the worker task unless one is already alive."""
        if self.task and self.task.is_alive():
            return
        self.task = DeferredTask(thread_name=f"ProfileRefiner-{self.agent.context.id}")
        if self.agent.context.task:
            # Register as a child of the context's task when there is one.
            self.agent.context.task.add_child_task(self.task, terminate_thread=True)
        self.task.start_task(self._run_loop)

    def stop(self) -> None:
        """Kill the worker task, if any."""
        if self.task:
            self.task.kill(terminate_thread=True)
            self.task = None

    @staticmethod
    def _format_evidence(events: list[dict[str, Any]]) -> list[str]:
        """Render answer events as ``- Q: ... A: ...`` lines for the prompt.

        Events without answer text are skipped. A missing question falls
        back to the selector, then the field kind, then ``"question"``.
        """
        lines: list[str] = []
        for e in events:
            q = (e.get("question_text") or "").strip()
            a = (e.get("answer_text") or "").strip()
            if not q:
                q = e.get("selector") or e.get("field_kind") or "question"
            if a:
                lines.append(f"- Q: {q}\n A: {a}")
        return lines

    async def _refine_profile(
        self, db: SurveyDB, profile_id: str, batch: list[dict[str, Any]]
    ) -> bool:
        """Apply one batch of answers to one profile.

        Returns:
            True when the batch may be marked processed, False when the
            model call, the JSON parsing or the profile write failed and
            the batch should be retried on a later poll.
        """
        lines = self._format_evidence(batch)
        if not lines:
            # Nothing to learn from these answers; no model call needed.
            return True

        profile = db.get_profile(profile_id) or UserProfile(
            id=profile_id, persona_id=None, data={}
        )

        system = (
            "You refine a user profile from survey answers.\n"
            "Output ONLY valid JSON.\n"
            "Return an object with keys:\n"
            "- profile_patch: object (deep-merge patch)\n"
            "- extracted_facts: array of short strings\n"
            "Rules:\n"
            "- Prefer stable fields: demographics, contact, preferences, traits.\n"
            "- If unsure, add to notes instead of guessing.\n"
        )
        message = (
            f"current_profile_json: {json.dumps(profile.data or {}, ensure_ascii=False)}\n\n"
            f"survey_answers:\n{chr(10).join(lines)}\n"
        )

        try:
            out = await self.agent.call_utility_model(
                system=system, message=message, background=True
            )
            data = json.loads(out)
            patch = data.get("profile_patch") if isinstance(data, dict) else None
            if isinstance(patch, dict):
                profile.data = _deep_merge(profile.data or {}, patch)
                db.upsert_profile(profile)
        except Exception:
            # If the call or parsing fails, do not mark processed.
            return False
        return True

    async def _run_loop(self) -> None:
        """Poll for unprocessed answers until the task is cancelled or killed."""
        db = SurveyDB.for_agent(self.agent)
        try:
            while True:
                await asyncio.sleep(self.POLL_INTERVAL_SECONDS)

                # Do not interfere with active survey filling.
                if self.agent.get_data("_survey_active"):
                    continue

                events = db.fetch_unprocessed_answer_events(limit=self.FETCH_LIMIT)
                if not events:
                    continue

                # Group by profile_id (fallback to "default").
                grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
                for e in events:
                    grouped[e.get("profile_id") or "default"].append(e)

                processed_ids: list[str] = []
                for profile_id, evs in grouped.items():
                    # Only the answers actually shown to the model are marked
                    # processed; the remainder is picked up by the next poll.
                    batch = evs[: self.MAX_EVENTS_PER_CALL]
                    if await self._refine_profile(db, profile_id, batch):
                        processed_ids.extend(e["id"] for e in batch)

                if processed_ids:
                    db.mark_answers_processed(processed_ids)
        finally:
            db.close()
|
||||
|
||||
|
||||
def ensure_profile_refiner_running(agent: Agent) -> ProfileRefinerService:
    """Return the agent's profile refiner service, creating it on first use.

    The service is kept in the agent's data under
    ``ProfileRefinerService.DATA_KEY``; ``start()`` is called on every
    invocation.
    """
    service = agent.get_data(ProfileRefinerService.DATA_KEY)
    if not isinstance(service, ProfileRefinerService):
        service = ProfileRefinerService(agent)
        agent.set_data(ProfileRefinerService.DATA_KEY, service)
    service.start()
    return service
|
||||
|
||||
71
python/surveys/schemas.py
Normal file
71
python/surveys/schemas.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
class FieldKind(str, Enum):
    """Kinds of survey form fields; members are also plain ``str`` values."""

    TEXT = "text"
    TEXTAREA = "textarea"
    EMAIL = "email"
    NUMBER = "number"
    DATE = "date"
    SELECT = "select"
    RADIO = "radio"
    CHECKBOX = "checkbox"
    BUTTON = "button"
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Persona:
    """Immutable persona record used when answering surveys."""

    # Unique persona identifier.
    id: str
    # Display name of the persona.
    name: str
    # Free-text description of the persona.
    description: str
    # Arbitrary structured constraints; every instance gets its own empty dict by default.
    constraints: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class UserProfile:
    """Mutable profile record holding structured data about a survey respondent."""

    # Profile key.
    id: str
    # Identifier of the associated persona, or None when there is none.
    persona_id: str | None
    # Structured profile content; every instance gets its own empty dict by default.
    data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SurveyField:
    """Immutable description of one form field found on a survey page."""

    # Selector used to address the field in the browser.
    selector: str
    # Classified control type.
    kind: FieldKind
    # Optional descriptive attributes of the field.
    name: str | None = None
    label: str | None = None
    placeholder: str | None = None
    # Option texts, for select/radio/checkbox fields.
    options: list[str] = field(default_factory=list)
    # Mapping between options and selectors (string to string).
    option_selectors: dict[str, str] = field(default_factory=dict)
    # Whether the field is marked as required.
    required: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SurveyPage:
    """Immutable snapshot of one parsed survey page."""

    # Address the page was parsed for.
    url: str
    # Page title, or None (no default: it must be passed explicitly).
    title: str | None
    # Fields extracted from the page.
    fields: list[SurveyField]
    # DOM text the page was parsed from.
    raw_dom: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AnswerAction:
    """Immutable description of one browser action to perform while answering."""

    # Which browser operation to perform.
    action: Literal["fill", "click", "press", "select"]
    # Target selector, when the action addresses an element.
    selector: str | None = None
    # Text payload, when the action carries one.
    text: str | None = None
    # Key name, when the action is a key press.
    key: str | None = None
    # Free-form extra information; every instance gets its own empty dict by default.
    meta: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SurveyAnswer:
    """Immutable answer for a single survey field."""

    # The field being answered.
    field: SurveyField
    # The answer as text.
    answer_text: str
    # Confidence value attached to the answer; defaults to 0.5.
    confidence: float = 0.5
    # Optional explanation for the answer.
    rationale: str | None = None
|
||||
|
||||
1
python/surveys/tests/__init__.py
Normal file
1
python/surveys/tests/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
72
python/surveys/tests/test_parser_db.py
Normal file
72
python/surveys/tests/test_parser_db.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import unittest
|
||||
|
||||
from python.surveys.db import SurveyDB
|
||||
from python.surveys.parser import parse_survey_page
|
||||
from python.surveys.schemas import FieldKind, Persona, SurveyField, UserProfile
|
||||
from python.surveys.profile_refiner import _deep_merge
|
||||
|
||||
|
||||
class TestSurveyParser(unittest.TestCase):
    """Checks how the survey parser groups choice inputs."""

    def test_groups_radio_checkbox(self):
        """Inputs sharing a name are reported as one radio field and one checkbox field."""
        dom = """
        <h1>Demo</h1>
        <div>Preferred mobile platform</div>
        <div><input type="radio" name="mobile" selector="1a" /> Android</div>
        <div><input type="radio" name="mobile" selector="2a" /> iOS</div>
        <div>Topics</div>
        <div><input type="checkbox" name="topics" selector="3a" /> Music</div>
        <div><input type="checkbox" name="topics" selector="4a" /> Sports</div>
        """
        page = parse_survey_page(dom, url="file://demo")

        def fields_of(kind):
            return [fld for fld in page.fields if fld.kind == kind]

        radio_fields = fields_of(FieldKind.RADIO)
        checkbox_fields = fields_of(FieldKind.CHECKBOX)

        # Exactly one grouped field per input name.
        self.assertEqual(len(radio_fields), 1)
        self.assertEqual(len(checkbox_fields), 1)

        radio_group = radio_fields[0]
        joined_options = " ".join(radio_group.options).lower()
        self.assertIn("android", joined_options)
        self.assertTrue(radio_group.option_selectors)

        checkbox_selectors = set(checkbox_fields[0].option_selectors.values())
        self.assertEqual(checkbox_selectors, {"3a", "4a"})
|
||||
|
||||
|
||||
class TestSurveyDB(unittest.TestCase):
    """Round-trip checks against an in-memory SurveyDB."""

    def test_db_roundtrip_and_events(self):
        """Store a persona, profile, session and answer, then read back and mark the answer processed."""
        store = SurveyDB(":memory:")
        # Close the database however the test ends.
        self.addCleanup(store.close)

        store.upsert_persona(
            Persona(id="p1", name="Test", description="x", constraints={"a": 1})
        )
        loaded_persona = store.get_persona("p1")
        self.assertIsNotNone(loaded_persona)
        self.assertEqual(loaded_persona.name, "Test")

        store.upsert_profile(
            UserProfile(
                id="default",
                persona_id="p1",
                data={"demographics": {"country": "DE"}},
            )
        )

        store.create_session("s1", url="file://demo", persona_id="p1", profile_id="default")
        email_field = SurveyField(selector="1a", kind=FieldKind.TEXT, label="Email")
        store.insert_answer("a1", "s1", "Email", email_field, "test@example.com")

        pending = store.fetch_unprocessed_answer_events()
        self.assertEqual(len(pending), 1)
        event = pending[0]
        expected_columns = (
            ("profile_id", "default"),
            ("persona_id", "p1"),
            ("answer_text", "test@example.com"),
        )
        for column, expected in expected_columns:
            self.assertEqual(event[column], expected)

        # Once marked, the answer must no longer be reported as unprocessed.
        store.mark_answers_processed(["a1"])
        self.assertEqual(store.fetch_unprocessed_answer_events(), [])
|
||||
|
||||
|
||||
class TestDeepMerge(unittest.TestCase):
    """Checks for the `_deep_merge` helper."""

    def test_deep_merge(self):
        """Nested dicts are combined key by key; scalar values from the source win."""
        base = {"a": {"b": 1}, "x": 1}
        patch = {"a": {"c": 2}, "x": 2, "y": 3}

        merged = _deep_merge(base, patch)

        nested = merged["a"]
        self.assertEqual(nested["b"], 1)  # kept from the destination
        self.assertEqual(nested["c"], 2)  # added from the source
        self.assertEqual(merged["x"], 2)  # overwritten by the source
        self.assertEqual(merged["y"], 3)  # new top-level key
|
||||
|
||||
|
||||
# Run the test cases above when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
65
python/tools/persona_create.py
Normal file
65
python/tools/persona_create.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import json
|
||||
import uuid
|
||||
|
||||
from python.helpers.tool import Tool, Response
|
||||
from python.surveys.db import SurveyDB
|
||||
from python.surveys.schemas import Persona
|
||||
|
||||
|
||||
class PersonaCreate(Tool):
    """Tool that builds a survey persona and stores it in the survey database."""

    async def execute(
        self,
        name: str = "",
        description: str = "",
        constraints_json: str = "",
        generate: bool = False,
        seed: str = "",
        **kwargs,
    ):
        """
        Create (and store) a persona used for survey answering.

        If generate=true, the utility model will draft name/description/constraints using 'seed'.

        - name: persona name; falls back to "Persona" when empty
        - description: free-text persona description
        - constraints_json: optional JSON object with persona constraints
        - generate: ask the utility model to draft the persona
        - seed: hint text passed to the utility model when generate=true

        Returns a Response whose message is a JSON object with `persona_id` and
        `name`, or a plain-text error when `constraints_json` or the model
        output is not valid JSON.
        """
        # Validate caller input before touching the database, and report bad
        # input as a normal tool response instead of raising.
        constraints = {}
        if constraints_json:
            try:
                constraints = json.loads(constraints_json)
            except (TypeError, ValueError) as e:
                return Response(message=f"Invalid constraints_json: {e}", break_loop=False)
            if not isinstance(constraints, dict):
                return Response(message="constraints_json must be a JSON object.", break_loop=False)

        if generate:
            system = (
                "You design a survey-answering persona for testing.\n"
                "Output ONLY valid JSON with keys: name, description, constraints.\n"
                "Constraints must be a JSON object with stable fields (demographics, preferences, traits).\n"
            )
            msg = f"seed: {seed}\nexisting_name: {name}\nexisting_description: {description}\n"
            out = await self.agent.call_utility_model(system=system, message=msg, background=False)
            try:
                data = json.loads(out)
            except (TypeError, ValueError) as e:
                # The model is only asked for JSON; it is not guaranteed to comply.
                return Response(
                    message=f"Persona generation returned invalid JSON: {e}",
                    break_loop=False,
                )
            if isinstance(data, dict):
                # Generated values win; caller-supplied values are the fallback.
                name = str(data.get("name") or name or "Persona")
                description = str(data.get("description") or description or "")
                if isinstance(data.get("constraints"), dict):
                    constraints = data["constraints"]

        if not name:
            name = "Persona"

        persona = Persona(
            id=str(uuid.uuid4()),
            name=name,
            description=description or "",
            constraints=constraints or {},
        )

        db = SurveyDB.for_agent(self.agent)
        try:
            db.upsert_persona(persona)
        finally:
            db.close()

        return Response(
            message=json.dumps(
                {"persona_id": persona.id, "name": persona.name},
                ensure_ascii=False,
                indent=2,
            ),
            break_loop=False,
        )
|
||||
|
||||
57
python/tools/profile_update.py
Normal file
57
python/tools/profile_update.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import json
|
||||
|
||||
from python.helpers.tool import Tool, Response
|
||||
from python.surveys.db import SurveyDB
|
||||
from python.surveys.schemas import UserProfile
|
||||
from python.surveys.profile_refiner import _deep_merge
|
||||
|
||||
|
||||
class ProfileUpdate(Tool):
    """Tool that deep-merges a JSON patch into a stored survey profile."""

    async def execute(
        self,
        profile_id: str = "default",
        patch_json: str = "",
        persona_id: str = "",
        **kwargs,
    ):
        """
        Update (deep-merge) the structured survey profile stored in SQLite.

        - profile_id: profile key (default "default")
        - patch_json: JSON object to deep-merge into profile.data
        - persona_id: optional persona association

        Returns a Response whose message is a JSON object with `profile_id` and
        the sorted top-level `updated_keys`, or a plain-text error for a
        missing, malformed or non-object patch.
        """

        def reply(text):
            return Response(message=text, break_loop=False)

        if not patch_json:
            return reply("patch_json is required and must be a JSON object.")

        try:
            patch = json.loads(patch_json)
        except Exception as e:
            return reply(f"Invalid patch_json: {e}")

        if not isinstance(patch, dict):
            return reply("patch_json must be a JSON object.")

        db = SurveyDB.for_agent(self.agent)
        try:
            profile = db.get_profile(profile_id)
            if not profile:
                # No stored profile yet: start from an empty one.
                profile = UserProfile(
                    id=profile_id,
                    persona_id=(persona_id or None),
                    data={},
                )
            if persona_id:
                profile.persona_id = persona_id

            current_data = profile.data or {}
            profile.data = _deep_merge(current_data, patch)
            db.upsert_profile(profile)

            summary = {"profile_id": profile.id, "updated_keys": sorted(patch)}
            return reply(json.dumps(summary, ensure_ascii=False, indent=2))
        finally:
            db.close()
|
||||
|
||||
243
python/tools/survey_fill.py
Normal file
243
python/tools/survey_fill.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
import json
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from python.tools.browser import Browser
|
||||
from python.helpers.tool import Response
|
||||
from python.helpers import dotenv
|
||||
|
||||
from python.surveys.answerer import answer_page
|
||||
from python.surveys.db import SurveyDB
|
||||
from python.surveys.parser import parse_survey_page
|
||||
from python.surveys.schemas import Persona, UserProfile
|
||||
|
||||
|
||||
def _is_local_url(url: str) -> bool:
|
||||
if not url:
|
||||
return True
|
||||
if url.startswith("file://"):
|
||||
return True
|
||||
try:
|
||||
host = (urlparse(url).hostname or "").lower()
|
||||
return host in {"localhost", "127.0.0.1"}
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _host(url: str) -> str:
|
||||
try:
|
||||
return (urlparse(url).hostname or "").lower()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _allowed_external_host(host: str) -> bool:
    """Return True when *host* is listed in the `A0_SURVEY_ALLOWED_DOMAINS` setting.

    The setting is a comma-separated allowlist, e.g.
    "surveys.mycompany.com,forms.example.org". Entries are trimmed and
    lower-cased, and blank entries are dropped. An empty or missing setting
    allows nothing.
    """
    configured = str(dotenv.get_dotenv_value("A0_SURVEY_ALLOWED_DOMAINS", "") or "")
    entries = (piece.strip().lower() for piece in configured.split(","))
    allowlist = {entry for entry in entries if entry}
    return bool(allowlist) and host in allowlist
|
||||
|
||||
|
||||
class SurveyFill(Browser):
    """Browser tool that walks a survey page by page, answers it and records the answers."""

    async def execute(
        self,
        url: str = "",
        profile_id: str = "default",
        persona_id: str = "",
        allow_persona: bool = False,
        allow_external: bool = False,
        max_pages: int = 12,
        use_llm: bool = True,
        **kwargs,
    ):
        """
        Fill out an online survey in the built-in browser.

        Safety: persona usage is restricted by default; set allow_persona=true to enable.

        - url: survey address to open; when empty, the page already open is used
        - profile_id: key of the stored profile to answer from (default "default")
        - persona_id: optional persona to answer as
        - allow_persona: permit persona mode on non-local URLs
        - allow_external: permit non-local URLs (the host must also be allowlisted)
        - max_pages: upper bound on the number of pages processed
        - use_llm: forwarded to the answer planner

        Returns a Response whose message is a JSON summary of the session
        (plus `error` and `dom_excerpt` when the run failed), or a plain-text
        refusal when a safety check blocks the request.
        """

        await self.prepare_state()

        # Safety checks run before the database is opened, so a refusal leaves
        # no open handle behind.
        is_external = bool(url) and not _is_local_url(url)

        if is_external:
            host = _host(url)
            if not allow_external or not _allowed_external_host(host):
                msg = (
                    "External survey domains are blocked by default.\n"
                    "To enable for authorized/owned surveys, set env `A0_SURVEY_ALLOWED_DOMAINS` "
                    "to a comma-separated allowlist and pass allow_external=true.\n"
                    f"Blocked host: {host or '(unknown)'}"
                )
                self.log.update(error=msg)
                return Response(message=msg, break_loop=False)

        if persona_id and not allow_persona and is_external:
            msg = (
                "Persona mode is disabled for non-local URLs by default. "
                "Re-run with allow_persona=true if you have explicit permission to answer as a persona."
            )
            self.log.update(error=msg)
            return Response(message=msg, break_loop=False)

        db = SurveyDB.for_agent(self.agent)
        try:
            persona: Persona | None = db.get_persona(persona_id) if persona_id else None
            profile = db.get_profile(profile_id) or UserProfile(
                id=profile_id, persona_id=(persona.id if persona else None), data={}
            )
            if persona and profile.persona_id != persona.id:
                profile.persona_id = persona.id
                db.upsert_profile(profile)

            session_id = str(uuid.uuid4())
            if url:
                self.update_progress("Opening survey...")
                await self.state.browser.open(url)

            current_url = await self.state.browser.get_url()
            db.create_session(
                session_id=session_id,
                url=current_url,
                persona_id=persona.id if persona else None,
                profile_id=profile.id,
                status="running",
            )

            def record_answer(act, answer_text):
                # Persist one performed action as an answer row of this session.
                meta = act.meta if isinstance(act.meta, dict) else {}
                db.insert_answer(
                    answer_id=str(uuid.uuid4()),
                    session_id=session_id,
                    question_text=meta.get("label"),
                    field=_field_stub_for_action(act),
                    answer_text=answer_text,
                    raw={"action": act.__dict__},
                )

            self.agent.set_data("_survey_active", True)
            actions_log = []
            error_text = None
            try:
                for i in range(max_pages):
                    self.update_progress(f"Parsing page {i+1}/{max_pages}...")
                    await self.state.browser.wait_for_action()
                    dom = await self.state.browser.get_clean_dom()
                    current_url = await self.state.browser.get_url()
                    page = parse_survey_page(dom, url=current_url)

                    # A completion message on the page means the survey is done.
                    low = (dom or "").lower()
                    if any(k in low for k in ("thank you", "thanks for completing", "response recorded")):
                        break

                    self.update_progress("Planning answers...")
                    plan = await answer_page(
                        self.agent,
                        page,
                        profile,
                        persona,
                        use_llm=use_llm,
                        seed=i,
                    )

                    if not plan:
                        break

                    self.update_progress("Answering...")
                    for act in plan:
                        if act.action == "fill" and act.selector and act.text is not None:
                            await self.state.browser.fill(act.selector, act.text)
                            record_answer(act, act.text)
                            actions_log.append({"action": "fill", "selector": act.selector, "text": act.text})
                        elif act.action == "select" and act.selector and act.text is not None:
                            await self.state.browser.select(act.selector, act.text)
                            record_answer(act, act.text)
                            actions_log.append({"action": "select", "selector": act.selector, "text": act.text})
                        elif act.action == "click" and act.selector:
                            await self.state.browser.click(act.selector)
                            val = act.meta.get("option") if isinstance(act.meta, dict) else None
                            record_answer(act, str(val) if val else "clicked")
                            actions_log.append({"action": "click", "selector": act.selector})
                        elif act.action == "press" and act.key:
                            await self.state.browser.press(act.key)
                            actions_log.append({"action": "press", "key": act.key})

                        await self.state.browser.wait(0.25)

                    # If we clicked "next"/"submit", navigation may occur; give it a moment.
                    await self.state.browser.wait_for_action()

                db.complete_session(session_id, status="completed")
            except Exception as e:
                # Best effort: the failure is reported in the response payload below.
                error_text = str(e)
                db.complete_session(session_id, status="error")
            finally:
                self.agent.set_data("_survey_active", False)
        finally:
            # Runs for every exit after the DB was opened, including setup failures.
            db.close()

        self.update_progress("Taking screenshot...")
        screenshot = await self.save_screenshot()
        self.log.update(screenshot=screenshot)

        if error_text:
            try:
                dom = await self.state.browser.get_clean_dom()
            except Exception:
                dom = ""
            payload = {
                "session_id": session_id,
                "url": current_url,
                "profile_id": profile.id,
                "persona_id": persona.id if persona else None,
                "error": error_text,
                "actions": actions_log[-40:],
                "dom_excerpt": (dom or "")[:4000],
                "screenshot": screenshot,
            }
            self.log.update(error=error_text)
        else:
            payload = {
                "session_id": session_id,
                "url": current_url,
                "profile_id": profile.id,
                "persona_id": persona.id if persona else None,
                "actions": actions_log[-40:],
                "screenshot": screenshot,
            }

        self.cleanup_history()
        self.update_progress("Done")
        return Response(message=json.dumps(payload, ensure_ascii=False, indent=2), break_loop=False)
|
||||
|
||||
|
||||
def _field_stub_for_action(act):
    """Build a minimal SurveyField describing the target of *act* for DB storage.

    The full parsed field is not always resolvable after refinement, so only
    the selector, a kind derived from the action type, and the label from the
    action's meta are filled in.
    """
    from python.surveys.schemas import SurveyField, FieldKind

    # Any action other than fill/select is stored as UNKNOWN.
    kind_by_action = {
        "fill": FieldKind.TEXT,
        "select": FieldKind.SELECT,
    }
    meta = act.meta
    return SurveyField(
        selector=act.selector or "",
        kind=kind_by_action.get(act.action, FieldKind.UNKNOWN),
        label=meta.get("label") if isinstance(meta, dict) else None,
        options=[],
        option_selectors={},
        required=False,
    )
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue