mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-19 16:31:30 +00:00
add built-in A0 CLI Connector plugin
Introduce the builtin `_a0_connector` plugin that lets the host-side A0 CLI connect to Agent Zero over authenticated HTTP and `/ws`. This adds connector capability discovery, chat/context lifecycle endpoints, log streaming, and the remote text editing, code execution, and file tree bridge used by the CLI workflow.
This commit is contained in:
parent
85654c6cc7
commit
8c5cf1f69f
36 changed files with 2702 additions and 0 deletions
0
plugins/_a0_connector/api/__init__.py
Normal file
0
plugins/_a0_connector/api/__init__.py
Normal file
0
plugins/_a0_connector/api/v1/__init__.py
Normal file
0
plugins/_a0_connector/api/v1/__init__.py
Normal file
15
plugins/_a0_connector/api/v1/agents_list.py
Normal file
15
plugins/_a0_connector/api/v1/agents_list.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/agents_list."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class AgentsList(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import subagents
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"data": subagents.get_all_agents_list(),
|
||||
}
|
||||
29
plugins/_a0_connector/api/v1/base.py
Normal file
29
plugins/_a0_connector/api/v1/base.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from helpers.api import ApiHandler
|
||||
|
||||
|
||||
class PublicConnectorApiHandler(ApiHandler):
|
||||
@classmethod
|
||||
def requires_auth(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def requires_csrf(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ProtectedConnectorApiHandler(ApiHandler):
|
||||
@classmethod
|
||||
def requires_auth(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def requires_csrf(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return False
|
||||
86
plugins/_a0_connector/api/v1/capabilities.py
Normal file
86
plugins/_a0_connector/api/v1/capabilities.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/capabilities."""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
_BASE_FEATURES = [
|
||||
"chat_create",
|
||||
"chats_list",
|
||||
"chat_get",
|
||||
"chat_reset",
|
||||
"chat_delete",
|
||||
"pause",
|
||||
"nudge",
|
||||
"message_send",
|
||||
"log_tail",
|
||||
"projects",
|
||||
"text_editor_remote",
|
||||
"code_execution_remote",
|
||||
"remote_file_tree",
|
||||
"token_status",
|
||||
]
|
||||
|
||||
_OPTIONAL_FEATURES: dict[str, tuple[str, ...]] = {
|
||||
"settings_get": ("helpers.settings", "helpers.subagents"),
|
||||
"settings_set": ("helpers.settings", "helpers.subagents"),
|
||||
"agents_list": ("helpers.subagents",),
|
||||
"skills_list": ("helpers.skills", "helpers.files", "helpers.projects", "helpers.runtime"),
|
||||
"skills_delete": ("helpers.skills", "helpers.files", "helpers.projects", "helpers.runtime"),
|
||||
"model_presets": ("plugins._model_config.helpers.model_config",),
|
||||
"model_switcher": ("plugins._model_config.helpers.model_config",),
|
||||
"compact_chat": (
|
||||
"plugins._chat_compaction.helpers.compactor",
|
||||
"plugins._model_config.helpers.model_config",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _module_available(module_name: str) -> bool:
|
||||
if module_name in sys.modules:
|
||||
return True
|
||||
|
||||
try:
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
except (AttributeError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def _feature_available(feature: str) -> bool:
|
||||
required = _OPTIONAL_FEATURES.get(feature, ())
|
||||
return all(_module_available(module_name) for module_name in required)
|
||||
|
||||
|
||||
def _feature_list() -> list[str]:
|
||||
features = list(_BASE_FEATURES)
|
||||
for feature in _OPTIONAL_FEATURES:
|
||||
if _feature_available(feature):
|
||||
features.append(feature)
|
||||
return features
|
||||
|
||||
|
||||
class Capabilities(connector_base.PublicConnectorApiHandler):
|
||||
"""Return the connector discovery contract for current Agent Zero."""
|
||||
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import login
|
||||
|
||||
return {
|
||||
"protocol": "a0-connector.v1",
|
||||
"version": "0.1.0",
|
||||
"auth": ["session"],
|
||||
"auth_required": bool(login.is_login_required()),
|
||||
"transports": ["http", "websocket"],
|
||||
"streaming": True,
|
||||
"websocket_namespace": "/ws",
|
||||
"websocket_handlers": ["plugins/_a0_connector/ws_connector"],
|
||||
"attachments": {
|
||||
"mode": "base64",
|
||||
"max_files": 20,
|
||||
},
|
||||
"features": _feature_list(),
|
||||
}
|
||||
40
plugins/_a0_connector/api/v1/chat_create.py
Normal file
40
plugins/_a0_connector/api/v1/chat_create.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/chat_create."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ChatCreate(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import projects
|
||||
from plugins._a0_connector.helpers.chat_context import create_context
|
||||
|
||||
current_context_id = (
|
||||
str(input.get("current_context", input.get("current_context_id", ""))).strip()
|
||||
or None
|
||||
)
|
||||
project_name = str(input.get("project_name", "")).strip() or None
|
||||
agent_profile = str(input.get("agent_profile", "")).strip() or None
|
||||
|
||||
try:
|
||||
context = create_context(
|
||||
lock=self.thread_lock,
|
||||
current_context_id=current_context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
except Exception as exc:
|
||||
return Response(
|
||||
response=f'{{"error": "Failed to activate project: {str(exc)}"}}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context_data = context.output()
|
||||
return {
|
||||
"context_id": context.id,
|
||||
"created_at": context_data.get("created_at"),
|
||||
"agent_profile": agent_profile or getattr(context.agent0.config, "profile", "default"),
|
||||
"project_name": context.get_data(projects.CONTEXT_DATA_KEY_PROJECT),
|
||||
}
|
||||
39
plugins/_a0_connector/api/v1/chat_delete.py
Normal file
39
plugins/_a0_connector/api/v1/chat_delete.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/chat_delete."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ChatDelete(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
from api.chat_remove import RemoveChat
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error": "Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
try:
|
||||
handler = RemoveChat(self.app, self.thread_lock)
|
||||
await handler.process({"context": context_id}, request)
|
||||
except Exception as exc:
|
||||
return Response(
|
||||
response=f'{{"error": "{str(exc)}"}}',
|
||||
status=500,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
return {"context_id": context_id, "status": "deleted"}
|
||||
46
plugins/_a0_connector/api/v1/chat_get.py
Normal file
46
plugins/_a0_connector/api/v1/chat_get.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/chat_get."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ChatGet(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
from plugins._a0_connector.helpers.event_bridge import (
|
||||
get_context_log_entries,
|
||||
)
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error": "Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context_data = context.output()
|
||||
events, last_sequence = get_context_log_entries(context_id)
|
||||
|
||||
return {
|
||||
"context_id": context.id,
|
||||
"id": context.id,
|
||||
"name": context_data.get("name") or context.id,
|
||||
"created_at": context_data.get("created_at"),
|
||||
"last_message": context_data.get("last_message"),
|
||||
"running": context_data.get("running", False),
|
||||
"agent_profile": getattr(context.agent0.config, "profile", "default")
|
||||
if context.agent0
|
||||
else "default",
|
||||
"log_entries": len(events),
|
||||
"last_sequence": last_sequence,
|
||||
}
|
||||
31
plugins/_a0_connector/api/v1/chat_reset.py
Normal file
31
plugins/_a0_connector/api/v1/chat_reset.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/chat_reset."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ChatReset(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
from api.chat_reset import Reset
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error": "Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
handler = Reset(self.app, self.thread_lock)
|
||||
await handler.process({"context": context_id}, request)
|
||||
return {"context_id": context_id, "status": "reset"}
|
||||
31
plugins/_a0_connector/api/v1/chats_list.py
Normal file
31
plugins/_a0_connector/api/v1/chats_list.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/chats_list."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ChatsList(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
|
||||
contexts: list[dict[str, object]] = []
|
||||
for context in AgentContext.all():
|
||||
data = context.output()
|
||||
contexts.append(
|
||||
{
|
||||
"id": context.id,
|
||||
"name": data.get("name") or context.name or context.id,
|
||||
"created_at": data.get("created_at"),
|
||||
"last_message": data.get("last_message"),
|
||||
"running": data.get("running", False),
|
||||
"agent_profile": getattr(context.agent0.config, "profile", "default")
|
||||
if context.agent0
|
||||
else "default",
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"contexts": contexts,
|
||||
"chats": contexts,
|
||||
}
|
||||
75
plugins/_a0_connector/api/v1/compact_chat.py
Normal file
75
plugins/_a0_connector/api/v1/compact_chat.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/compact_chat."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
def _coerce_bool(value: object, default: bool = False) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
return default
|
||||
return normalized in {"1", "true", "yes", "on"}
|
||||
if value is None:
|
||||
return default
|
||||
return bool(value)
|
||||
|
||||
|
||||
async def _run_compaction_task(context, use_chat_model: bool, preset_name: str | None) -> None:
|
||||
from helpers.state_monitor_integration import mark_dirty_all
|
||||
from plugins._chat_compaction.helpers.compactor import run_compaction
|
||||
|
||||
try:
|
||||
await run_compaction(context, use_chat_model, preset_name)
|
||||
except Exception as exc:
|
||||
context.log.log(
|
||||
type="error",
|
||||
heading="Compaction Failed",
|
||||
content=str(exc),
|
||||
)
|
||||
mark_dirty_all(reason="plugins._a0_connector.compact_chat_error")
|
||||
|
||||
|
||||
class CompactChat(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
from plugins._chat_compaction.helpers.compactor import (
|
||||
MIN_COMPACTION_TOKENS,
|
||||
get_compaction_stats,
|
||||
)
|
||||
|
||||
action = str(input.get("action", "compact")).strip() or "compact"
|
||||
context_id = str(
|
||||
input.get("context", input.get("context_id", input.get("ctxid", "")))
|
||||
).strip()
|
||||
|
||||
if not context_id:
|
||||
return Response("Missing context id", 400)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if not context:
|
||||
return Response("Context not found", 404)
|
||||
|
||||
if context.is_running():
|
||||
return Response("Cannot compact while agent is running", 409)
|
||||
|
||||
stats = await get_compaction_stats(context)
|
||||
if stats["token_count"] < MIN_COMPACTION_TOKENS:
|
||||
return {
|
||||
"ok": False,
|
||||
"message": f"Not enough content to compact (minimum {MIN_COMPACTION_TOKENS:,} tokens)",
|
||||
}
|
||||
|
||||
if action == "stats":
|
||||
return {"ok": True, "stats": stats}
|
||||
|
||||
if action == "compact":
|
||||
use_chat_model = _coerce_bool(input.get("use_chat_model", True), default=True)
|
||||
preset_name = str(input.get("preset_name", "")).strip() or None
|
||||
context.run_task(_run_compaction_task, context, use_chat_model, preset_name)
|
||||
return {"ok": True, "message": "Compaction started"}
|
||||
|
||||
return Response(f"Unknown action: {action}", 400)
|
||||
32
plugins/_a0_connector/api/v1/log_tail.py
Normal file
32
plugins/_a0_connector/api/v1/log_tail.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/log_tail."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class LogTail(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from plugins._a0_connector.helpers.event_bridge import (
|
||||
get_context_log_entries,
|
||||
)
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
after = int(input.get("after", 0) or 0)
|
||||
limit = min(int(input.get("limit", 50) or 50), 250)
|
||||
|
||||
events, last_sequence = get_context_log_entries(context_id, after=after)
|
||||
limited_events = events[:limit]
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"events": limited_events,
|
||||
"last_sequence": last_sequence,
|
||||
"has_more": len(events) > len(limited_events),
|
||||
}
|
||||
120
plugins/_a0_connector/api/v1/message_send.py
Normal file
120
plugins/_a0_connector/api/v1/message_send.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/message_send."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from helpers.api import Request, Response
|
||||
from helpers.print_style import PrintStyle
|
||||
from helpers.security import safe_filename
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class MessageSend(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import UserMessage
|
||||
from helpers import files
|
||||
from plugins._a0_connector.helpers.chat_context import (
|
||||
ConnectorContextError,
|
||||
create_context,
|
||||
get_existing_context,
|
||||
)
|
||||
|
||||
message = str(input.get("message", "")).strip()
|
||||
if not message:
|
||||
return Response(
|
||||
response='{"error": "message is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip() or None
|
||||
current_context_id = (
|
||||
str(input.get("current_context", input.get("current_context_id", ""))).strip()
|
||||
or None
|
||||
)
|
||||
project_name = str(input.get("project_name", "")).strip() or None
|
||||
agent_profile = str(input.get("agent_profile", "")).strip() or None
|
||||
attachments_data = input.get("attachments", [])
|
||||
|
||||
attachment_paths: list[str] = []
|
||||
if isinstance(attachments_data, list) and attachments_data:
|
||||
upload_folder_ext = files.get_abs_path("usr/uploads")
|
||||
upload_folder_int = "/a0/usr/uploads"
|
||||
os.makedirs(upload_folder_ext, exist_ok=True)
|
||||
|
||||
for attachment in attachments_data:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
filename = str(attachment.get("filename", "")).strip()
|
||||
b64_content = str(attachment.get("base64", "")).strip()
|
||||
if not filename or not b64_content:
|
||||
continue
|
||||
|
||||
try:
|
||||
safe_name = safe_filename(filename)
|
||||
if not safe_name:
|
||||
continue
|
||||
save_path = os.path.join(upload_folder_ext, safe_name)
|
||||
with open(save_path, "wb") as handle:
|
||||
handle.write(base64.b64decode(b64_content))
|
||||
attachment_paths.append(os.path.join(upload_folder_int, safe_name))
|
||||
except Exception as exc:
|
||||
PrintStyle.error(f"[a0-connector] attachment error: {exc}")
|
||||
|
||||
try:
|
||||
if context_id:
|
||||
context = get_existing_context(
|
||||
context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
else:
|
||||
context = create_context(
|
||||
lock=self.thread_lock,
|
||||
current_context_id=current_context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
context_id = context.id
|
||||
except ConnectorContextError as exc:
|
||||
return Response(
|
||||
response=f'{{"error": "{str(exc)}"}}',
|
||||
status=exc.status_code,
|
||||
mimetype="application/json",
|
||||
)
|
||||
except Exception as exc:
|
||||
return Response(
|
||||
response=f'{{"error": "Failed to activate project: {str(exc)}"}}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
attachment_names = [os.path.basename(path) for path in attachment_paths]
|
||||
message_id = str(uuid.uuid4())
|
||||
context.log.log(
|
||||
type="user",
|
||||
heading="",
|
||||
content=message,
|
||||
kvps={"attachments": attachment_names},
|
||||
id=message_id,
|
||||
)
|
||||
|
||||
try:
|
||||
task = context.communicate(
|
||||
UserMessage(message=message, attachments=attachment_paths, id=message_id)
|
||||
)
|
||||
result = await task.result()
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"status": "completed",
|
||||
"response": result,
|
||||
}
|
||||
except Exception as exc:
|
||||
PrintStyle.error(f"[a0-connector] message_send error: {exc}")
|
||||
return Response(
|
||||
response=f'{{"error": "{str(exc)}"}}',
|
||||
status=500,
|
||||
mimetype="application/json",
|
||||
)
|
||||
29
plugins/_a0_connector/api/v1/model_presets.py
Normal file
29
plugins/_a0_connector/api/v1/model_presets.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/model_presets."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class ModelPresets(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from plugins._model_config.helpers import model_config
|
||||
|
||||
action = str(input.get("action", "get")).strip() or "get"
|
||||
|
||||
if action == "get":
|
||||
presets = model_config.get_presets()
|
||||
return {"ok": True, "presets": presets}
|
||||
|
||||
if action == "save":
|
||||
presets = input.get("presets")
|
||||
if not isinstance(presets, list):
|
||||
return Response(status=400, response="presets must be an array")
|
||||
model_config.save_presets(presets)
|
||||
return {"ok": True, "presets": presets}
|
||||
|
||||
if action == "reset":
|
||||
presets = model_config.reset_presets()
|
||||
return {"ok": True, "presets": presets}
|
||||
|
||||
return Response(status=400, response=f"Unknown action: {action}")
|
||||
167
plugins/_a0_connector/api/v1/model_switcher.py
Normal file
167
plugins/_a0_connector/api/v1/model_switcher.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/model_switcher."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
def _model_payload(config: dict | None, *, has_api_key: bool = False) -> dict[str, object]:
|
||||
config = config or {}
|
||||
provider = str(config.get("provider") or "").strip()
|
||||
name = str(config.get("name") or "").strip()
|
||||
return {
|
||||
"provider": provider,
|
||||
"name": name,
|
||||
"label": f"{provider}/{name}" if provider and name else (name or provider or "—"),
|
||||
"has_api_key": bool(has_api_key),
|
||||
}
|
||||
|
||||
|
||||
def _coerce_override_model(value: object) -> dict[str, str]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
|
||||
payload: dict[str, str] = {}
|
||||
provider = str(value.get("provider") or "").strip()
|
||||
name = str(value.get("name") or "").strip()
|
||||
api_key = str(value.get("api_key") or "").strip()
|
||||
api_base = str(value.get("api_base") or value.get("base_url") or "").strip()
|
||||
|
||||
if provider:
|
||||
payload["provider"] = provider
|
||||
if name:
|
||||
payload["name"] = name
|
||||
if api_key:
|
||||
payload["api_key"] = api_key
|
||||
if api_base:
|
||||
payload["api_base"] = api_base
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _provider_payload(
|
||||
value: object,
|
||||
*,
|
||||
has_api_key_lookup: Callable[[str], bool] | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
||||
options: list[dict[str, object]] = []
|
||||
seen: set[str] = set()
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
provider = str(item.get("value") or item.get("id") or "").strip().lower()
|
||||
label = str(item.get("label") or item.get("name") or provider).strip()
|
||||
else:
|
||||
provider = str(item or "").strip().lower()
|
||||
label = provider.replace("_", " ").title()
|
||||
|
||||
if not provider or provider in seen:
|
||||
continue
|
||||
seen.add(provider)
|
||||
has_api_key = False
|
||||
if callable(has_api_key_lookup):
|
||||
try:
|
||||
has_api_key = bool(has_api_key_lookup(provider))
|
||||
except Exception:
|
||||
has_api_key = False
|
||||
elif isinstance(item, dict):
|
||||
has_api_key = bool(item.get("has_api_key"))
|
||||
|
||||
options.append({"value": provider, "label": label or provider, "has_api_key": has_api_key})
|
||||
|
||||
return options
|
||||
|
||||
|
||||
class ModelSwitcher(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
from helpers.persist_chat import save_tmp_chat
|
||||
from plugins._model_config.helpers import model_config
|
||||
|
||||
action = str(input.get("action", "get")).strip() or "get"
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
context = AgentContext.get(context_id) if context_id else None
|
||||
agent = getattr(context, "agent0", None) if context is not None else None
|
||||
|
||||
def build_state() -> dict[str, object]:
|
||||
override = context.get_data("chat_model_override") if context is not None else None
|
||||
try:
|
||||
chat_providers = _provider_payload(
|
||||
model_config.get_chat_providers(),
|
||||
has_api_key_lookup=lambda provider: model_config.has_provider_api_key(provider, ""),
|
||||
)
|
||||
except Exception:
|
||||
chat_providers = []
|
||||
chat_model = model_config.get_chat_model_config(agent)
|
||||
utility_model = model_config.get_utility_model_config(agent)
|
||||
|
||||
def _has_api_key(config: object) -> bool:
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
provider = str(config.get("provider") or "").strip().lower()
|
||||
api_key = str(config.get("api_key") or "").strip()
|
||||
if not provider:
|
||||
return bool(api_key)
|
||||
try:
|
||||
return bool(model_config.has_provider_api_key(provider, api_key))
|
||||
except Exception:
|
||||
return bool(api_key)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"allowed": bool(model_config.is_chat_override_allowed(agent)),
|
||||
"override": override,
|
||||
"presets": model_config.get_presets(),
|
||||
"chat_providers": chat_providers,
|
||||
"main_model": _model_payload(chat_model, has_api_key=_has_api_key(chat_model)),
|
||||
"utility_model": _model_payload(utility_model, has_api_key=_has_api_key(utility_model)),
|
||||
}
|
||||
|
||||
if action == "get":
|
||||
return build_state()
|
||||
|
||||
if not context_id:
|
||||
return Response(status=400, response="Missing context_id")
|
||||
|
||||
if context is None:
|
||||
return Response(status=404, response="Context not found")
|
||||
|
||||
if not model_config.is_chat_override_allowed(agent):
|
||||
return Response(status=403, response="Per-chat override is disabled")
|
||||
|
||||
if action == "set_preset":
|
||||
preset_name = str(input.get("preset_name", "")).strip()
|
||||
if not preset_name:
|
||||
return Response(status=400, response="Missing preset_name")
|
||||
preset = model_config.get_preset_by_name(preset_name)
|
||||
if not preset:
|
||||
return Response(status=404, response=f"Preset '{preset_name}' not found")
|
||||
context.set_data("chat_model_override", {"preset_name": preset_name})
|
||||
save_tmp_chat(context)
|
||||
return build_state()
|
||||
|
||||
if action == "clear":
|
||||
context.set_data("chat_model_override", None)
|
||||
save_tmp_chat(context)
|
||||
return build_state()
|
||||
|
||||
if action == "set_override":
|
||||
main_model = _coerce_override_model(input.get("main_model"))
|
||||
utility_model = _coerce_override_model(input.get("utility_model"))
|
||||
if not main_model and not utility_model:
|
||||
return Response(status=400, response="Missing model override payload")
|
||||
|
||||
override: dict[str, dict[str, str]] = {}
|
||||
if main_model:
|
||||
override["chat"] = main_model
|
||||
if utility_model:
|
||||
override["utility"] = utility_model
|
||||
context.set_data("chat_model_override", override)
|
||||
save_tmp_chat(context)
|
||||
return build_state()
|
||||
|
||||
return Response(status=400, response=f"Unknown action: {action}")
|
||||
36
plugins/_a0_connector/api/v1/nudge.py
Normal file
36
plugins/_a0_connector/api/v1/nudge.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/nudge."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class Nudge(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error": "Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context.nudge()
|
||||
message = "Process reset, agent nudged."
|
||||
context.log.log(type="info", content=message)
|
||||
return {
|
||||
"ok": True,
|
||||
"context_id": context_id,
|
||||
"status": "nudged",
|
||||
"message": message,
|
||||
}
|
||||
48
plugins/_a0_connector/api/v1/pause.py
Normal file
48
plugins/_a0_connector/api/v1/pause.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/pause."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class Pause(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import AgentContext
|
||||
|
||||
context_id = str(input.get("context_id", "")).strip()
|
||||
raw_paused = input.get("paused", True)
|
||||
if isinstance(raw_paused, str):
|
||||
paused = raw_paused.strip().lower() not in {"", "0", "false", "no", "off"}
|
||||
else:
|
||||
paused = bool(raw_paused)
|
||||
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error": "context_id is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error": "Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
if paused and not context.is_running():
|
||||
return Response(
|
||||
response='{"error": "Context is not currently running"}',
|
||||
status=409,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context.paused = paused
|
||||
return {
|
||||
"ok": True,
|
||||
"context_id": context_id,
|
||||
"paused": paused,
|
||||
"status": "paused" if paused else "running",
|
||||
"message": "Agent paused." if paused else "Agent unpaused.",
|
||||
}
|
||||
108
plugins/_a0_connector/api/v1/projects.py
Normal file
108
plugins/_a0_connector/api/v1/projects.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/projects."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
def _string(value: object) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
def _normalize_project_summary(value: object) -> dict[str, str] | None:
|
||||
if not isinstance(value, Mapping):
|
||||
return None
|
||||
|
||||
name = _string(value.get("name"))
|
||||
if not name:
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"title": _string(value.get("title")),
|
||||
"description": _string(value.get("description")),
|
||||
"color": _string(value.get("color")),
|
||||
}
|
||||
|
||||
|
||||
class Projects(connector_base.ProtectedConnectorApiHandler):
|
||||
"""Thin connector proxy around the core `api.projects.Projects` surface."""
|
||||
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
action = _string(input.get("action")).lower() or "list"
|
||||
if action not in {"list", "load", "update", "activate", "deactivate"}:
|
||||
return {"ok": False, "error": f"Unsupported action: {action or '<missing>'}"}
|
||||
|
||||
core_response = await self._call_core(
|
||||
{
|
||||
"action": action,
|
||||
"context_id": _string(input.get("context_id")),
|
||||
"name": _string(input.get("name")),
|
||||
"project": input.get("project"),
|
||||
},
|
||||
request,
|
||||
)
|
||||
if isinstance(core_response, Response):
|
||||
return core_response
|
||||
if not isinstance(core_response, Mapping):
|
||||
return {"ok": False, "error": "Invalid response from core projects handler"}
|
||||
if not core_response.get("ok"):
|
||||
return {"ok": False, "error": _string(core_response.get("error")) or "Project request failed"}
|
||||
|
||||
if action in {"activate", "deactivate", "list"}:
|
||||
return await self._normalized_list_state(_string(input.get("context_id")), request)
|
||||
|
||||
project = core_response.get("data")
|
||||
return {
|
||||
"ok": True,
|
||||
"project": dict(project) if isinstance(project, Mapping) else {},
|
||||
}
|
||||
|
||||
async def _normalized_list_state(self, context_id: str, request: Request) -> dict[str, Any] | Response:
|
||||
core_response = await self._call_core(
|
||||
{
|
||||
"action": "list",
|
||||
"context_id": context_id,
|
||||
},
|
||||
request,
|
||||
)
|
||||
if isinstance(core_response, Response):
|
||||
return core_response
|
||||
if not isinstance(core_response, Mapping):
|
||||
return {"ok": False, "error": "Invalid response from core projects handler"}
|
||||
if not core_response.get("ok"):
|
||||
return {"ok": False, "error": _string(core_response.get("error")) or "Project request failed"}
|
||||
|
||||
projects: list[dict[str, str]] = []
|
||||
for item in core_response.get("data") or []:
|
||||
normalized = _normalize_project_summary(item)
|
||||
if normalized is not None:
|
||||
projects.append(normalized)
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"projects": projects,
|
||||
"current_project": self._load_current_project(context_id),
|
||||
}
|
||||
|
||||
async def _call_core(self, payload: dict[str, Any], request: Request) -> dict | Response:
|
||||
from api.projects import Projects as CoreProjects
|
||||
|
||||
handler = CoreProjects(self.app, self.thread_lock)
|
||||
return await handler.process(payload, request)
|
||||
|
||||
def _load_current_project(self, context_id: str) -> dict[str, str] | None:
|
||||
if not context_id:
|
||||
return None
|
||||
|
||||
from agent import AgentContext
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
return _normalize_project_summary(context.get_output_data("project"))
|
||||
12
plugins/_a0_connector/api/v1/settings_get.py
Normal file
12
plugins/_a0_connector/api/v1/settings_get.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/settings_get."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class SettingsGet(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import settings
|
||||
|
||||
return dict(settings.convert_out(settings.get_settings()))
|
||||
22
plugins/_a0_connector/api/v1/settings_set.py
Normal file
22
plugins/_a0_connector/api/v1/settings_set.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/settings_set."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class SettingsSet(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import settings
|
||||
|
||||
payload = input.get("settings", input)
|
||||
if not isinstance(payload, dict):
|
||||
return Response(
|
||||
response='{"error":"settings must be an object"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
backend = settings.convert_in(settings.Settings(**payload))
|
||||
backend = settings.set_settings(backend)
|
||||
return dict(settings.convert_out(backend))
|
||||
26
plugins/_a0_connector/api/v1/skills_delete.py
Normal file
26
plugins/_a0_connector/api/v1/skills_delete.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/skills_delete."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class SkillsDelete(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import skills
|
||||
|
||||
skill_path = str(input.get("skill_path") or "").strip()
|
||||
if not skill_path:
|
||||
return Response(
|
||||
response='{"error":"skill_path is required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
skills.delete_skill(skill_path)
|
||||
return {
|
||||
"ok": True,
|
||||
"data": {
|
||||
"skill_path": skill_path,
|
||||
},
|
||||
}
|
||||
55
plugins/_a0_connector/api/v1/skills_list.py
Normal file
55
plugins/_a0_connector/api/v1/skills_list.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/skills_list."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class SkillsList(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from helpers import files, projects, runtime, skills
|
||||
|
||||
skill_list = skills.list_skills()
|
||||
project_name = str(input.get("project_name", "")).strip() or None
|
||||
|
||||
if project_name:
|
||||
project_folder = projects.get_project_folder(project_name)
|
||||
if runtime.is_development():
|
||||
project_folder = files.normalize_a0_path(project_folder)
|
||||
skill_list = [
|
||||
item
|
||||
for item in skill_list
|
||||
if files.is_in_dir(str(item.path), project_folder)
|
||||
]
|
||||
|
||||
agent_profile = str(input.get("agent_profile", "")).strip() or None
|
||||
if agent_profile:
|
||||
roots: list[str] = [
|
||||
files.get_abs_path("agents", agent_profile, "skills"),
|
||||
files.get_abs_path("usr", "agents", agent_profile, "skills"),
|
||||
]
|
||||
if project_name:
|
||||
roots.append(
|
||||
projects.get_project_meta(project_name, "agents", agent_profile, "skills")
|
||||
)
|
||||
|
||||
skill_list = [
|
||||
item
|
||||
for item in skill_list
|
||||
if any(files.is_in_dir(str(item.path), root) for root in roots)
|
||||
]
|
||||
|
||||
result = [
|
||||
{
|
||||
"name": skill.name,
|
||||
"description": skill.description,
|
||||
"path": str(skill.path),
|
||||
}
|
||||
for skill in skill_list
|
||||
]
|
||||
result.sort(key=lambda item: (item["name"], item["path"]))
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"data": result,
|
||||
}
|
||||
57
plugins/_a0_connector/api/v1/token_status.py
Normal file
57
plugins/_a0_connector/api/v1/token_status.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""POST /api/plugins/_a0_connector/v1/token_status."""
|
||||
from __future__ import annotations
|
||||
|
||||
from helpers.api import Request, Response
|
||||
import plugins._a0_connector.api.v1.base as connector_base
|
||||
|
||||
|
||||
class TokenStatus(connector_base.ProtectedConnectorApiHandler):
|
||||
async def process(self, input: dict, request: Request) -> dict | Response:
|
||||
from agent import Agent, AgentContext
|
||||
|
||||
context_id = str(
|
||||
input.get("context", input.get("context_id", input.get("ctxid", "")))
|
||||
).strip()
|
||||
if not context_id:
|
||||
return Response(
|
||||
response='{"error":"context_id required"}',
|
||||
status=400,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return Response(
|
||||
response='{"error":"Context not found"}',
|
||||
status=404,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
agent = context.streaming_agent or context.agent0
|
||||
window = agent.get_data(Agent.DATA_NAME_CTX_WINDOW) if agent is not None else None
|
||||
token_count: int | None = None
|
||||
if isinstance(window, dict):
|
||||
raw_tokens = window.get("tokens")
|
||||
try:
|
||||
token_count = int(raw_tokens)
|
||||
except (TypeError, ValueError):
|
||||
token_count = None
|
||||
|
||||
context_window: int | None = None
|
||||
try:
|
||||
from plugins._model_config.helpers.model_config import get_chat_model_config
|
||||
|
||||
chat_config = get_chat_model_config(agent)
|
||||
if isinstance(chat_config, dict):
|
||||
raw_context_window = int(chat_config.get("ctx_length", 0))
|
||||
if raw_context_window > 0:
|
||||
context_window = raw_context_window
|
||||
except Exception:
|
||||
context_window = None
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"context_id": context_id,
|
||||
"token_count": token_count,
|
||||
"context_window": context_window,
|
||||
}
|
||||
490
plugins/_a0_connector/api/ws_connector.py
Normal file
490
plugins/_a0_connector/api/ws_connector.py
Normal file
|
|
@ -0,0 +1,490 @@
|
|||
"""Connector WebSocket handler for the shared `/ws` namespace."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from helpers.print_style import PrintStyle
|
||||
from helpers.ws import WsHandler
|
||||
from helpers.ws_manager import WsResult
|
||||
|
||||
from plugins._a0_connector.helpers.event_bridge import get_context_log_entries
|
||||
from plugins._a0_connector.helpers.ws_runtime import (
|
||||
clear_remote_tree_snapshot,
|
||||
fail_pending_file_ops_for_sid,
|
||||
fail_pending_exec_ops_for_sid,
|
||||
register_sid,
|
||||
resolve_pending_file_op,
|
||||
resolve_pending_exec_op,
|
||||
store_remote_tree_snapshot,
|
||||
subscribe_sid_to_context,
|
||||
subscribed_contexts_for_sid,
|
||||
subscribed_sids_for_context,
|
||||
unsubscribe_sid_from_context,
|
||||
unregister_sid,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent import AgentContext, AgentContextType, UserMessage
|
||||
|
||||
|
||||
PROTOCOL_VERSION = "a0-connector.v1"
|
||||
WS_FEATURES = [
|
||||
"connector_subscribe_context",
|
||||
"connector_send_message",
|
||||
"text_editor_remote",
|
||||
"remote_file_tree",
|
||||
"code_execution_remote",
|
||||
]
|
||||
|
||||
|
||||
class WsConnector(WsHandler):
|
||||
_streaming_tasks: ClassVar[dict[tuple[str, str], asyncio.Task[None]]] = {}
|
||||
|
||||
@classmethod
|
||||
def requires_auth(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def requires_csrf(cls) -> bool:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def requires_api_key(cls) -> bool:
|
||||
return False
|
||||
|
||||
async def on_connect(self, sid: str) -> None:
|
||||
register_sid(sid)
|
||||
PrintStyle.debug(f"[a0-connector] /ws connected: {sid}")
|
||||
|
||||
async def on_disconnect(self, sid: str) -> None:
|
||||
contexts = unregister_sid(sid)
|
||||
for context_id in contexts:
|
||||
self._cancel_streaming(sid, context_id)
|
||||
clear_remote_tree_snapshot(sid)
|
||||
fail_pending_file_ops_for_sid(
|
||||
sid,
|
||||
error="CLI disconnected before completing the requested file operation",
|
||||
)
|
||||
fail_pending_exec_ops_for_sid(
|
||||
sid,
|
||||
error="CLI disconnected before completing the requested remote execution",
|
||||
)
|
||||
PrintStyle.debug(f"[a0-connector] /ws disconnected: {sid}")
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: str,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult | None:
|
||||
if event == "connector_hello":
|
||||
return {
|
||||
"protocol": PROTOCOL_VERSION,
|
||||
"features": WS_FEATURES,
|
||||
}
|
||||
|
||||
if event == "connector_subscribe_context":
|
||||
return await self._handle_subscribe_context(data, sid)
|
||||
|
||||
if event == "connector_unsubscribe_context":
|
||||
return self._handle_unsubscribe_context(data, sid)
|
||||
|
||||
if event == "connector_send_message":
|
||||
return await self._handle_send_message(data, sid)
|
||||
|
||||
if event == "connector_file_op_result":
|
||||
return self._handle_file_op_result(data, sid)
|
||||
|
||||
if event == "connector_remote_tree_update":
|
||||
return self._handle_remote_tree_update(data, sid)
|
||||
|
||||
if event == "connector_exec_op_result":
|
||||
return self._handle_exec_op_result(data, sid)
|
||||
|
||||
if event.startswith("connector_"):
|
||||
return WsResult.error(
|
||||
code="UNKNOWN_EVENT",
|
||||
message=f"Unknown connector event: {event}",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_subscribe_context(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
from agent import AgentContext
|
||||
|
||||
context_id = str(data.get("context_id", "")).strip()
|
||||
from_sequence = int(data.get("from", 0) or 0)
|
||||
|
||||
if not context_id:
|
||||
return WsResult.error(
|
||||
code="MISSING_CONTEXT_ID",
|
||||
message="context_id is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return WsResult.error(
|
||||
code="CONTEXT_NOT_FOUND",
|
||||
message=f"Context '{context_id}' not found",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
subscribe_sid_to_context(sid, context_id)
|
||||
events, last_sequence = get_context_log_entries(context_id, after=from_sequence)
|
||||
await self.emit_to(
|
||||
sid,
|
||||
"connector_context_snapshot",
|
||||
{
|
||||
"context_id": context_id,
|
||||
"events": events,
|
||||
"last_sequence": last_sequence,
|
||||
},
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
self._start_streaming(sid, context_id, from_sequence=last_sequence)
|
||||
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"subscribed": True,
|
||||
"last_sequence": last_sequence,
|
||||
}
|
||||
|
||||
def _handle_unsubscribe_context(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
context_id = str(data.get("context_id", "")).strip()
|
||||
if not context_id:
|
||||
return WsResult.error(
|
||||
code="MISSING_CONTEXT_ID",
|
||||
message="context_id is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
self._cancel_streaming(sid, context_id)
|
||||
unsubscribe_sid_from_context(sid, context_id)
|
||||
return {"context_id": context_id, "unsubscribed": True}
|
||||
|
||||
async def _handle_send_message(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
from plugins._a0_connector.helpers.chat_context import ConnectorContextError
|
||||
|
||||
message = str(data.get("message", "")).strip()
|
||||
if not message:
|
||||
return WsResult.error(
|
||||
code="MISSING_MESSAGE",
|
||||
message="message is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
context_id = str(data.get("context_id", "")).strip() or None
|
||||
current_context_id = (
|
||||
str(data.get("current_context", data.get("current_context_id", ""))).strip()
|
||||
or None
|
||||
)
|
||||
client_message_id = str(data.get("client_message_id", "")).strip()
|
||||
attachments = list(data.get("attachments", [])) if isinstance(data.get("attachments"), list) else []
|
||||
project_name = str(data.get("project_name", "")).strip() or None
|
||||
agent_profile = str(data.get("agent_profile", "")).strip() or None
|
||||
|
||||
try:
|
||||
context, context_id = await self._resolve_context(
|
||||
context_id=context_id,
|
||||
current_context_id=current_context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
except ConnectorContextError as exc:
|
||||
return WsResult.error(
|
||||
code=exc.code,
|
||||
message=str(exc),
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
except Exception as exc:
|
||||
return WsResult.error(
|
||||
code="BAD_REQUEST",
|
||||
message=str(exc),
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
if context is None or context_id is None:
|
||||
return WsResult.error(
|
||||
code="CONTEXT_NOT_FOUND",
|
||||
message="Unable to resolve or create the requested context",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
if context_id not in subscribed_contexts_for_sid(sid):
|
||||
subscribe_sid_to_context(sid, context_id)
|
||||
events, last_sequence = get_context_log_entries(context_id, after=0)
|
||||
await self.emit_to(
|
||||
sid,
|
||||
"connector_context_snapshot",
|
||||
{
|
||||
"context_id": context_id,
|
||||
"events": events,
|
||||
"last_sequence": last_sequence,
|
||||
},
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
self._start_streaming(sid, context_id, from_sequence=last_sequence)
|
||||
|
||||
message_id = client_message_id or data.get("correlationId") or ""
|
||||
context.log.log(
|
||||
type="user",
|
||||
heading="",
|
||||
content=message,
|
||||
kvps={},
|
||||
id=message_id,
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
self._run_message(
|
||||
context=context,
|
||||
context_id=context_id,
|
||||
message=message,
|
||||
attachments=attachments,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"status": "accepted",
|
||||
"client_message_id": client_message_id or None,
|
||||
}
|
||||
|
||||
def _handle_file_op_result(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
op_id = str(data.get("op_id", "")).strip()
|
||||
if not op_id:
|
||||
return WsResult.error(
|
||||
code="MISSING_OP_ID",
|
||||
message="op_id is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
if not resolve_pending_file_op(op_id, sid=sid, payload=data):
|
||||
return WsResult.error(
|
||||
code="UNKNOWN_OP_ID",
|
||||
message=f"No pending file operation for op_id '{op_id}'",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
return {"op_id": op_id, "accepted": True}
|
||||
|
||||
def _handle_remote_tree_update(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
tree = data.get("tree")
|
||||
root_path = data.get("root_path")
|
||||
tree_hash = data.get("tree_hash")
|
||||
|
||||
if not isinstance(tree, str) or not tree.strip():
|
||||
return WsResult.error(
|
||||
code="INVALID_TREE_PAYLOAD",
|
||||
message="tree is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
if not isinstance(root_path, str) or not root_path.strip():
|
||||
return WsResult.error(
|
||||
code="INVALID_TREE_PAYLOAD",
|
||||
message="root_path is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
if not isinstance(tree_hash, str) or not tree_hash.strip():
|
||||
return WsResult.error(
|
||||
code="INVALID_TREE_PAYLOAD",
|
||||
message="tree_hash is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
snapshot = store_remote_tree_snapshot(sid, data)
|
||||
return {
|
||||
"accepted": True,
|
||||
"sid": sid,
|
||||
"tree_hash": tree_hash,
|
||||
"updated_at": snapshot.updated_at,
|
||||
}
|
||||
|
||||
def _handle_exec_op_result(
|
||||
self,
|
||||
data: dict[str, Any],
|
||||
sid: str,
|
||||
) -> dict[str, Any] | WsResult:
|
||||
op_id = str(data.get("op_id", "")).strip()
|
||||
if not op_id:
|
||||
return WsResult.error(
|
||||
code="MISSING_OP_ID",
|
||||
message="op_id is required",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
if not resolve_pending_exec_op(op_id, sid=sid, payload=data):
|
||||
return WsResult.error(
|
||||
code="UNKNOWN_OP_ID",
|
||||
message=f"No pending exec operation for op_id '{op_id}'",
|
||||
correlation_id=data.get("correlationId"),
|
||||
)
|
||||
|
||||
return {"op_id": op_id, "accepted": True}
|
||||
|
||||
async def _resolve_context(
|
||||
self,
|
||||
*,
|
||||
context_id: str | None,
|
||||
current_context_id: str | None,
|
||||
agent_profile: str | None,
|
||||
project_name: str | None,
|
||||
) -> tuple[AgentContext | None, str | None]:
|
||||
from plugins._a0_connector.helpers.chat_context import (
|
||||
create_context,
|
||||
get_existing_context,
|
||||
)
|
||||
|
||||
if context_id:
|
||||
context = get_existing_context(
|
||||
context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
return context, context_id
|
||||
|
||||
context = create_context(
|
||||
lock=self.lock,
|
||||
current_context_id=current_context_id,
|
||||
agent_profile=agent_profile,
|
||||
project_name=project_name,
|
||||
)
|
||||
context_id = context.id
|
||||
return context, context_id
|
||||
|
||||
async def _run_message(
|
||||
self,
|
||||
*,
|
||||
context: AgentContext,
|
||||
context_id: str,
|
||||
message: str,
|
||||
attachments: list[Any],
|
||||
) -> None:
|
||||
from agent import AgentContext, UserMessage
|
||||
|
||||
try:
|
||||
AgentContext.use(context_id)
|
||||
task = context.communicate(
|
||||
UserMessage(message=message, attachments=attachments)
|
||||
)
|
||||
result = await task.result()
|
||||
except Exception as exc:
|
||||
PrintStyle.error(f"[a0-connector] connector_send_message error: {exc}")
|
||||
await self._emit_context_error(
|
||||
context_id=context_id,
|
||||
code="AGENT_ERROR",
|
||||
message=str(exc),
|
||||
)
|
||||
await self._emit_context_complete(
|
||||
context_id=context_id,
|
||||
payload={"status": "error", "error": str(exc)},
|
||||
)
|
||||
return
|
||||
|
||||
await self._emit_context_complete(
|
||||
context_id=context_id,
|
||||
payload={"status": "completed", "response": result},
|
||||
)
|
||||
|
||||
async def _emit_context_error(
|
||||
self,
|
||||
*,
|
||||
context_id: str,
|
||||
code: str,
|
||||
message: str,
|
||||
) -> None:
|
||||
payload = {
|
||||
"context_id": context_id,
|
||||
"code": code,
|
||||
"message": message,
|
||||
}
|
||||
for target_sid in subscribed_sids_for_context(context_id):
|
||||
try:
|
||||
await self.emit_to(target_sid, "connector_error", payload)
|
||||
except Exception as exc:
|
||||
PrintStyle.error(
|
||||
f"[a0-connector] failed to emit connector_error to {target_sid}: {exc}"
|
||||
)
|
||||
|
||||
async def _emit_context_complete(
|
||||
self,
|
||||
*,
|
||||
context_id: str,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
event_payload = {"context_id": context_id, **payload}
|
||||
for target_sid in subscribed_sids_for_context(context_id):
|
||||
try:
|
||||
await self.emit_to(
|
||||
target_sid,
|
||||
"connector_context_complete",
|
||||
event_payload,
|
||||
)
|
||||
except Exception as exc:
|
||||
PrintStyle.error(
|
||||
f"[a0-connector] failed to emit connector_context_complete to {target_sid}: {exc}"
|
||||
)
|
||||
|
||||
def _start_streaming(self, sid: str, context_id: str, *, from_sequence: int) -> None:
|
||||
key = (sid, context_id)
|
||||
task = self._streaming_tasks.get(key)
|
||||
if task is not None and not task.done():
|
||||
return
|
||||
|
||||
task = asyncio.create_task(
|
||||
self._stream_events(sid, context_id, from_sequence=from_sequence)
|
||||
)
|
||||
self._streaming_tasks[key] = task
|
||||
|
||||
def _cancel_streaming(self, sid: str, context_id: str) -> None:
|
||||
task = self._streaming_tasks.pop((sid, context_id), None)
|
||||
if task is not None and not task.done():
|
||||
task.get_loop().call_soon_threadsafe(task.cancel)
|
||||
|
||||
async def _stream_events(
|
||||
self,
|
||||
sid: str,
|
||||
context_id: str,
|
||||
*,
|
||||
from_sequence: int,
|
||||
) -> None:
|
||||
# `from_sequence` is a log-output cursor (not an event sequence number).
|
||||
cursor = max(int(from_sequence or 0), 0)
|
||||
try:
|
||||
while context_id in subscribed_contexts_for_sid(sid):
|
||||
events, next_cursor = get_context_log_entries(context_id, after=cursor)
|
||||
for event in events:
|
||||
await self.emit_to(sid, "connector_context_event", event)
|
||||
cursor = max(cursor, int(next_cursor or cursor))
|
||||
await asyncio.sleep(0.5)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
PrintStyle.error(
|
||||
f"[a0-connector] stream error sid={sid} context={context_id}: {exc}"
|
||||
)
|
||||
finally:
|
||||
self._streaming_tasks.pop((sid, context_id), None)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from agent import LoopData
|
||||
from helpers.extension import Extension
|
||||
|
||||
from plugins._a0_connector.helpers.ws_runtime import latest_remote_tree_for_context
|
||||
|
||||
|
||||
class IncludeRemoteFileStructure(Extension):
|
||||
async def execute(self, loop_data: LoopData = LoopData(), **kwargs):
|
||||
if not self.agent:
|
||||
return
|
||||
|
||||
context_id = getattr(self.agent.context, "id", "")
|
||||
if not context_id:
|
||||
return
|
||||
|
||||
snapshot = latest_remote_tree_for_context(context_id, max_age_seconds=90.0)
|
||||
if not snapshot:
|
||||
return
|
||||
|
||||
file_structure = str(snapshot.get("tree") or "").strip()
|
||||
if not file_structure:
|
||||
return
|
||||
|
||||
folder = str(snapshot.get("root_path") or "").strip() or "unknown"
|
||||
generated_at = str(snapshot.get("generated_at") or "unknown")
|
||||
updated_at = float(snapshot.get("updated_at") or 0.0)
|
||||
age_seconds = max(0, int(time.time() - updated_at))
|
||||
|
||||
prompt = self.agent.read_prompt(
|
||||
"agent.extras.remote_file_structure.md",
|
||||
folder=folder,
|
||||
generated_at=generated_at,
|
||||
age_seconds=age_seconds,
|
||||
file_structure=file_structure,
|
||||
)
|
||||
loop_data.extras_temporary["remote_file_structure"] = prompt
|
||||
0
plugins/_a0_connector/helpers/__init__.py
Normal file
0
plugins/_a0_connector/helpers/__init__.py
Normal file
115
plugins/_a0_connector/helpers/chat_context.py
Normal file
115
plugins/_a0_connector/helpers/chat_context.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Shared chat-context helpers for connector handlers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ConnectorContextError(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
status_code: int = 400,
|
||||
code: str = "BAD_REQUEST",
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
self.code = code
|
||||
|
||||
|
||||
def get_existing_context(
|
||||
context_id: str,
|
||||
*,
|
||||
agent_profile: str | None = None,
|
||||
project_name: str | None = None,
|
||||
):
|
||||
from agent import AgentContext
|
||||
from helpers import projects
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
raise ConnectorContextError(
|
||||
"Context not found",
|
||||
status_code=404,
|
||||
code="CONTEXT_NOT_FOUND",
|
||||
)
|
||||
|
||||
if agent_profile and getattr(context.agent0.config, "profile", None) != agent_profile:
|
||||
raise ConnectorContextError(
|
||||
"Cannot change agent_profile on existing context",
|
||||
status_code=400,
|
||||
code="INVALID_AGENT_PROFILE",
|
||||
)
|
||||
|
||||
existing_project = context.get_data(projects.CONTEXT_DATA_KEY_PROJECT)
|
||||
if project_name and existing_project and existing_project != project_name:
|
||||
raise ConnectorContextError(
|
||||
"Project can only be set on first message",
|
||||
status_code=400,
|
||||
code="PROJECT_CONFLICT",
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def create_context(
|
||||
*,
|
||||
lock: Any | None = None,
|
||||
current_context_id: str | None = None,
|
||||
agent_profile: str | None = None,
|
||||
project_name: str | None = None,
|
||||
):
|
||||
from agent import AgentContext, AgentContextType
|
||||
from helpers import projects, settings
|
||||
from helpers.state_monitor_integration import mark_dirty_all
|
||||
from initialize import initialize_agent
|
||||
from plugins._model_config.helpers.model_config import is_chat_override_allowed
|
||||
|
||||
override_settings: dict[str, str] = {}
|
||||
if agent_profile:
|
||||
override_settings["agent_profile"] = agent_profile
|
||||
|
||||
with lock if lock is not None else nullcontext():
|
||||
current_context = AgentContext.get(current_context_id or "") if current_context_id else None
|
||||
|
||||
context = AgentContext(
|
||||
config=initialize_agent(override_settings=override_settings),
|
||||
type=AgentContextType.USER,
|
||||
set_current=True,
|
||||
)
|
||||
|
||||
if current_context and settings.get_settings().get("chat_inherit_project", True):
|
||||
current_project = current_context.get_data(projects.CONTEXT_DATA_KEY_PROJECT)
|
||||
if current_project:
|
||||
context.set_data(projects.CONTEXT_DATA_KEY_PROJECT, current_project)
|
||||
|
||||
current_project_output = current_context.get_output_data(
|
||||
projects.CONTEXT_DATA_KEY_PROJECT
|
||||
)
|
||||
if current_project_output:
|
||||
context.set_output_data(
|
||||
projects.CONTEXT_DATA_KEY_PROJECT,
|
||||
current_project_output,
|
||||
)
|
||||
|
||||
if current_context:
|
||||
model_override = current_context.get_data("chat_model_override")
|
||||
if model_override and is_chat_override_allowed(context.agent0):
|
||||
context.set_data("chat_model_override", model_override)
|
||||
|
||||
if project_name:
|
||||
try:
|
||||
try:
|
||||
projects.activate_project(context.id, project_name, mark_dirty=False)
|
||||
except TypeError as exc:
|
||||
if "mark_dirty" not in str(exc):
|
||||
raise
|
||||
projects.activate_project(context.id, project_name)
|
||||
except Exception:
|
||||
AgentContext.remove(context.id)
|
||||
raise
|
||||
|
||||
mark_dirty_all(reason="plugins._a0_connector.chat_context.create_context")
|
||||
return context
|
||||
128
plugins/_a0_connector/helpers/event_bridge.py
Normal file
128
plugins/_a0_connector/helpers/event_bridge.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""Context event streaming bridge for the a0-connector plugin."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncIterator, Callable
|
||||
|
||||
from helpers.print_style import PrintStyle
|
||||
|
||||
|
||||
EVENT_USER_MESSAGE = "user_message"
|
||||
EVENT_ASSISTANT_DELTA = "assistant_delta"
|
||||
EVENT_ASSISTANT_MESSAGE = "assistant_message"
|
||||
EVENT_TOOL_START = "tool_start"
|
||||
EVENT_TOOL_OUTPUT = "tool_output"
|
||||
EVENT_TOOL_END = "tool_end"
|
||||
EVENT_CODE_START = "code_start"
|
||||
EVENT_CODE_OUTPUT = "code_output"
|
||||
EVENT_WARNING = "warning"
|
||||
EVENT_ERROR = "error"
|
||||
EVENT_INFO = "info"
|
||||
EVENT_STATUS = "status"
|
||||
EVENT_UTIL_MESSAGE = "util_message"
|
||||
EVENT_MESSAGE_COMPLETE = "message_complete"
|
||||
EVENT_CONTEXT_UPDATED = "context_updated"
|
||||
|
||||
_LOG_TYPE_MAP: dict[str, str] = {
|
||||
"agent": EVENT_STATUS,
|
||||
"ai_response": EVENT_ASSISTANT_MESSAGE,
|
||||
"browser": EVENT_TOOL_OUTPUT,
|
||||
"code": EVENT_CODE_START,
|
||||
"code_exe": EVENT_CODE_OUTPUT,
|
||||
"code_output": EVENT_CODE_OUTPUT,
|
||||
"error": EVENT_ERROR,
|
||||
"hint": EVENT_STATUS,
|
||||
"info": EVENT_INFO,
|
||||
"input": EVENT_USER_MESSAGE,
|
||||
"mcp": EVENT_TOOL_START,
|
||||
"progress": EVENT_STATUS,
|
||||
"response": EVENT_ASSISTANT_MESSAGE,
|
||||
"subagent": EVENT_STATUS,
|
||||
"tool": EVENT_TOOL_START,
|
||||
"tool_output": EVENT_TOOL_OUTPUT,
|
||||
"user": EVENT_USER_MESSAGE,
|
||||
"util": EVENT_UTIL_MESSAGE,
|
||||
"warning": EVENT_WARNING,
|
||||
}
|
||||
|
||||
|
||||
def log_entry_to_connector_event(
|
||||
log_entry: dict[str, Any],
|
||||
context_id: str,
|
||||
) -> dict[str, Any]:
|
||||
entry_type = str(log_entry.get("type", "")).strip()
|
||||
event_type = _LOG_TYPE_MAP.get(entry_type, EVENT_STATUS)
|
||||
item_no = int(log_entry.get("no", 0) or 0)
|
||||
|
||||
data: dict[str, Any] = {}
|
||||
content = log_entry.get("content")
|
||||
heading = log_entry.get("heading")
|
||||
kvps = log_entry.get("kvps")
|
||||
|
||||
if isinstance(content, str) and content:
|
||||
data["text"] = content
|
||||
if isinstance(heading, str) and heading:
|
||||
data["heading"] = heading
|
||||
if isinstance(kvps, dict) and kvps:
|
||||
data["meta"] = kvps
|
||||
|
||||
return {
|
||||
"context_id": context_id,
|
||||
"sequence": item_no + 1,
|
||||
"event": event_type,
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
def get_context_log_entries(
|
||||
context_id: str,
|
||||
after: int = 0,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Return connector events plus the next log cursor for the context."""
|
||||
try:
|
||||
from agent import AgentContext
|
||||
|
||||
context = AgentContext.get(context_id)
|
||||
if context is None:
|
||||
return [], 0
|
||||
|
||||
log_output = context.log.output(start=max(int(after or 0), 0))
|
||||
events = [
|
||||
log_entry_to_connector_event(entry, context_id)
|
||||
for entry in log_output.items
|
||||
if isinstance(entry, dict)
|
||||
]
|
||||
return events, int(log_output.end)
|
||||
except Exception as exc:
|
||||
PrintStyle.error(
|
||||
f"[a0-connector] event_bridge error for context {context_id}: {exc}"
|
||||
)
|
||||
return [], max(int(after or 0), 0)
|
||||
|
||||
|
||||
async def stream_context_events(
|
||||
context_id: str,
|
||||
from_sequence: int = 0,
|
||||
poll_interval: float = 0.5,
|
||||
timeout: float = 300.0,
|
||||
emit_fn: Callable[[dict[str, Any]], Any] | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
cursor = max(int(from_sequence or 0), 0)
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
while time.monotonic() < deadline:
|
||||
events, next_cursor = get_context_log_entries(context_id, after=cursor)
|
||||
for event in events:
|
||||
if emit_fn is not None:
|
||||
try:
|
||||
result = emit_fn(event)
|
||||
if asyncio.iscoroutine(result):
|
||||
await result
|
||||
except Exception as exc:
|
||||
PrintStyle.error(f"[a0-connector] emit_fn error: {exc}")
|
||||
yield event
|
||||
|
||||
cursor = max(cursor, next_cursor)
|
||||
await asyncio.sleep(poll_interval)
|
||||
295
plugins/_a0_connector/helpers/ws_runtime.py
Normal file
295
plugins/_a0_connector/helpers/ws_runtime.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingFileOperation:
|
||||
sid: str
|
||||
loop: asyncio.AbstractEventLoop
|
||||
future: asyncio.Future[dict[str, Any]]
|
||||
context_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingExecOperation:
|
||||
sid: str
|
||||
loop: asyncio.AbstractEventLoop
|
||||
future: asyncio.Future[dict[str, Any]]
|
||||
context_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoteTreeSnapshot:
|
||||
sid: str
|
||||
payload: dict[str, Any]
|
||||
updated_at: float
|
||||
|
||||
|
||||
_context_subscriptions: dict[str, set[str]] = {}
|
||||
_sid_contexts: dict[str, set[str]] = {}
|
||||
_pending_file_ops: dict[str, PendingFileOperation] = {}
|
||||
_pending_exec_ops: dict[str, PendingExecOperation] = {}
|
||||
_remote_tree_snapshots: dict[str, RemoteTreeSnapshot] = {}
|
||||
_state_lock = threading.RLock()
|
||||
|
||||
|
||||
def register_sid(sid: str) -> None:
|
||||
with _state_lock:
|
||||
_sid_contexts.setdefault(sid, set())
|
||||
|
||||
|
||||
def unregister_sid(sid: str) -> set[str]:
|
||||
with _state_lock:
|
||||
contexts = _sid_contexts.pop(sid, set())
|
||||
_remote_tree_snapshots.pop(sid, None)
|
||||
for context_id in contexts:
|
||||
subscribers = _context_subscriptions.get(context_id)
|
||||
if not subscribers:
|
||||
continue
|
||||
subscribers.discard(sid)
|
||||
if not subscribers:
|
||||
_context_subscriptions.pop(context_id, None)
|
||||
return contexts
|
||||
|
||||
|
||||
def subscribe_sid_to_context(sid: str, context_id: str) -> None:
|
||||
with _state_lock:
|
||||
_sid_contexts.setdefault(sid, set()).add(context_id)
|
||||
_context_subscriptions.setdefault(context_id, set()).add(sid)
|
||||
|
||||
|
||||
def unsubscribe_sid_from_context(sid: str, context_id: str) -> None:
|
||||
with _state_lock:
|
||||
contexts = _sid_contexts.get(sid)
|
||||
if contexts is not None:
|
||||
contexts.discard(context_id)
|
||||
if not contexts:
|
||||
_sid_contexts.pop(sid, None)
|
||||
|
||||
subscribers = _context_subscriptions.get(context_id)
|
||||
if subscribers is not None:
|
||||
subscribers.discard(sid)
|
||||
if not subscribers:
|
||||
_context_subscriptions.pop(context_id, None)
|
||||
|
||||
|
||||
def subscribed_contexts_for_sid(sid: str) -> set[str]:
|
||||
with _state_lock:
|
||||
return set(_sid_contexts.get(sid, set()))
|
||||
|
||||
|
||||
def subscribed_sids_for_context(context_id: str) -> set[str]:
|
||||
with _state_lock:
|
||||
return set(_context_subscriptions.get(context_id, set()))
|
||||
|
||||
|
||||
def store_remote_tree_snapshot(
|
||||
sid: str,
|
||||
payload: dict[str, Any],
|
||||
) -> RemoteTreeSnapshot:
|
||||
snapshot = RemoteTreeSnapshot(
|
||||
sid=sid,
|
||||
payload=dict(payload),
|
||||
updated_at=time.time(),
|
||||
)
|
||||
with _state_lock:
|
||||
_remote_tree_snapshots[sid] = snapshot
|
||||
return snapshot
|
||||
|
||||
|
||||
def clear_remote_tree_snapshot(sid: str) -> None:
|
||||
with _state_lock:
|
||||
_remote_tree_snapshots.pop(sid, None)
|
||||
|
||||
|
||||
def latest_remote_tree_for_context(
|
||||
context_id: str,
|
||||
*,
|
||||
max_age_seconds: float = 90.0,
|
||||
) -> dict[str, Any] | None:
|
||||
now = time.time()
|
||||
with _state_lock:
|
||||
subscribers = _context_subscriptions.get(context_id, set())
|
||||
snapshots = [
|
||||
_remote_tree_snapshots[sid]
|
||||
for sid in subscribers
|
||||
if sid in _remote_tree_snapshots
|
||||
]
|
||||
|
||||
if not snapshots:
|
||||
return None
|
||||
|
||||
snapshots.sort(key=lambda item: item.updated_at, reverse=True)
|
||||
for snapshot in snapshots:
|
||||
if max_age_seconds > 0 and now - snapshot.updated_at > max_age_seconds:
|
||||
continue
|
||||
payload = dict(snapshot.payload)
|
||||
payload["sid"] = snapshot.sid
|
||||
payload["updated_at"] = snapshot.updated_at
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def select_target_sid(context_id: str) -> str | None:
|
||||
with _state_lock:
|
||||
subscribers = _context_subscriptions.get(context_id, set())
|
||||
if not subscribers:
|
||||
return None
|
||||
return sorted(subscribers)[0]
|
||||
|
||||
|
||||
def store_pending_file_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str,
|
||||
future: asyncio.Future[dict[str, Any]],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
context_id: str | None = None,
|
||||
) -> None:
|
||||
with _state_lock:
|
||||
_pending_file_ops[op_id] = PendingFileOperation(
|
||||
sid=sid,
|
||||
loop=loop,
|
||||
future=future,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
|
||||
def clear_pending_file_op(op_id: str) -> None:
|
||||
with _state_lock:
|
||||
_pending_file_ops.pop(op_id, None)
|
||||
|
||||
|
||||
def resolve_pending_file_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str,
|
||||
payload: dict[str, Any],
|
||||
) -> bool:
|
||||
with _state_lock:
|
||||
pending = _pending_file_ops.get(op_id)
|
||||
if pending is None or pending.sid != sid:
|
||||
return False
|
||||
_pending_file_ops.pop(op_id, None)
|
||||
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, dict(payload))
|
||||
return True
|
||||
|
||||
|
||||
def fail_pending_file_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str | None = None,
|
||||
error: str,
|
||||
) -> bool:
|
||||
with _state_lock:
|
||||
pending = _pending_file_ops.get(op_id)
|
||||
if pending is None:
|
||||
return False
|
||||
if sid is not None and pending.sid != sid:
|
||||
return False
|
||||
_pending_file_ops.pop(op_id, None)
|
||||
|
||||
payload = {"op_id": op_id, "ok": False, "error": error}
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, payload)
|
||||
return True
|
||||
|
||||
|
||||
def fail_pending_file_ops_for_sid(sid: str, *, error: str) -> None:
|
||||
with _state_lock:
|
||||
matches = [
|
||||
(op_id, pending)
|
||||
for op_id, pending in _pending_file_ops.items()
|
||||
if pending.sid == sid
|
||||
]
|
||||
for op_id, _pending in matches:
|
||||
_pending_file_ops.pop(op_id, None)
|
||||
|
||||
for op_id, pending in matches:
|
||||
payload = {"op_id": op_id, "ok": False, "error": error}
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, payload)
|
||||
|
||||
|
||||
def store_pending_exec_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str,
|
||||
future: asyncio.Future[dict[str, Any]],
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
context_id: str | None = None,
|
||||
) -> None:
|
||||
with _state_lock:
|
||||
_pending_exec_ops[op_id] = PendingExecOperation(
|
||||
sid=sid,
|
||||
loop=loop,
|
||||
future=future,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
|
||||
def clear_pending_exec_op(op_id: str) -> None:
|
||||
with _state_lock:
|
||||
_pending_exec_ops.pop(op_id, None)
|
||||
|
||||
|
||||
def resolve_pending_exec_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str,
|
||||
payload: dict[str, Any],
|
||||
) -> bool:
|
||||
with _state_lock:
|
||||
pending = _pending_exec_ops.get(op_id)
|
||||
if pending is None or pending.sid != sid:
|
||||
return False
|
||||
_pending_exec_ops.pop(op_id, None)
|
||||
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, dict(payload))
|
||||
return True
|
||||
|
||||
|
||||
def fail_pending_exec_op(
|
||||
op_id: str,
|
||||
*,
|
||||
sid: str | None = None,
|
||||
error: str,
|
||||
) -> bool:
|
||||
with _state_lock:
|
||||
pending = _pending_exec_ops.get(op_id)
|
||||
if pending is None:
|
||||
return False
|
||||
if sid is not None and pending.sid != sid:
|
||||
return False
|
||||
_pending_exec_ops.pop(op_id, None)
|
||||
|
||||
payload = {"op_id": op_id, "ok": False, "error": error}
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, payload)
|
||||
return True
|
||||
|
||||
|
||||
def fail_pending_exec_ops_for_sid(sid: str, *, error: str) -> None:
|
||||
with _state_lock:
|
||||
matches = [
|
||||
(op_id, pending)
|
||||
for op_id, pending in _pending_exec_ops.items()
|
||||
if pending.sid == sid
|
||||
]
|
||||
for op_id, _pending in matches:
|
||||
_pending_exec_ops.pop(op_id, None)
|
||||
|
||||
for op_id, pending in matches:
|
||||
payload = {"op_id": op_id, "ok": False, "error": error}
|
||||
pending.loop.call_soon_threadsafe(_set_future_result, pending.future, payload)
|
||||
|
||||
|
||||
def _set_future_result(
|
||||
future: asyncio.Future[dict[str, Any]],
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
if not future.done():
|
||||
future.set_result(payload)
|
||||
9
plugins/_a0_connector/plugin.yaml
Normal file
9
plugins/_a0_connector/plugin.yaml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
name: _a0_connector
|
||||
title: A0 Connector
|
||||
description: Current Agent Zero connector plugin for HTTP plus /ws integration, using session auth and handler activation through auth.handlers.
|
||||
version: 0.1.0
|
||||
settings_sections:
|
||||
- external
|
||||
- developer
|
||||
per_project_config: false
|
||||
per_agent_config: false
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Remote file structure of connected CLI workspace {{folder}}
|
||||
- this snapshot comes from the frontend machine, not the Agent Zero server filesystem
|
||||
- snapshot age (seconds): {{age_seconds}}
|
||||
- generated at: {{generated_at}}
|
||||
|
||||
## file tree
|
||||
{{file_structure}}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
# code_execution_remote tool
|
||||
|
||||
This tool runs shell-backed execution on the **remote machine where the CLI is running**.
|
||||
It converges onto Agent Zero Core's persistent local-shell model, so the frontend session
|
||||
can execute terminal commands and shell-launched `python` / `nodejs` snippets while keeping
|
||||
session ids stable across calls.
|
||||
|
||||
## Requirements
|
||||
- A CLI client must be connected to this context via the shared `/ws` namespace.
|
||||
- The CLI client must support `connector_exec_op`.
|
||||
- Frontend execution may be locally disabled in the CLI session; in that case the result is
|
||||
a structured `{ok: false}` error and no fallback runtime is used.
|
||||
|
||||
## Arguments
|
||||
- `runtime`: one of `terminal`, `python`, `nodejs`, `output`, `reset`
|
||||
- `runtime=input` is a temporary deprecated compatibility alias for sending one line of
|
||||
keyboard input into a running shell session
|
||||
- `session`: integer session id (default `0`)
|
||||
|
||||
Runtime-specific fields:
|
||||
- `terminal`, `python`, `nodejs`: require `code`
|
||||
- `input`: requires `keyboard` (or `code` as fallback)
|
||||
- `reset`: optional `reason`
|
||||
|
||||
## Usage
|
||||
|
||||
### Execute a terminal command
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "terminal",
|
||||
"session": 0,
|
||||
"code": "pwd && ls -la"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Execute Python through the shell-backed runtime
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "python",
|
||||
"session": 0,
|
||||
"code": "import os\nprint(os.getcwd())"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Execute Node.js through the shell-backed runtime
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "nodejs",
|
||||
"session": 0,
|
||||
"code": "console.log(process.cwd())"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Poll output from a running session
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "output",
|
||||
"session": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Send keyboard input to a running session
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "input",
|
||||
"session": 0,
|
||||
"keyboard": "yes"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Reset a session
|
||||
```json
|
||||
{
|
||||
"tool_name": "code_execution_remote",
|
||||
"tool_args": {
|
||||
"runtime": "reset",
|
||||
"session": 0,
|
||||
"reason": "stuck process"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
- Session state is frontend-local and shell-backed.
|
||||
- `output` is for long-running operations where a prior call returned control before the
|
||||
shell reached a prompt.
|
||||
- The transport uses `connector_exec_op` and `connector_exec_op_result` with shared `op_id`.
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
# text_editor_remote tool
|
||||
|
||||
This tool allows you to read, write, and patch files on the **remote machine where the CLI is running**.
|
||||
This is different from `text_editor` which operates on the Agent Zero server's filesystem.
|
||||
|
||||
Use `text_editor_remote` when the user asks you to edit files on their local machine while connected via the CLI.
|
||||
|
||||
## Requirements
|
||||
- A CLI client must be connected to this context via the shared `/ws` namespace.
|
||||
- The CLI client must have enabled remote file editing support.
|
||||
|
||||
## Operations
|
||||
|
||||
### Read a file
|
||||
```json
|
||||
{
|
||||
"tool_name": "text_editor_remote",
|
||||
"tool_args": {
|
||||
"op": "read",
|
||||
"path": "/path/on/remote/machine/file.py",
|
||||
"line_from": 1,
|
||||
"line_to": 50
|
||||
}
|
||||
}
|
||||
```
|
||||
Returns file content with line numbers. `line_from` and `line_to` are optional.
|
||||
|
||||
### Write a file
|
||||
```json
|
||||
{
|
||||
"tool_name": "text_editor_remote",
|
||||
"tool_args": {
|
||||
"op": "write",
|
||||
"path": "/path/on/remote/machine/file.py",
|
||||
"content": "import os\nprint('hello')\n"
|
||||
}
|
||||
}
|
||||
```
|
||||
Creates or overwrites the file on the remote machine.
|
||||
|
||||
### Patch a file
|
||||
```json
|
||||
{
|
||||
"tool_name": "text_editor_remote",
|
||||
"tool_args": {
|
||||
"op": "patch",
|
||||
"path": "/path/on/remote/machine/file.py",
|
||||
"edits": [
|
||||
{"from": 5, "to": 5, "content": " if x == 2:\n"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
Applies line-range patches to the file. Use the same format as the standard `text_editor:patch` tool.
|
||||
|
||||
## Notes
|
||||
- Always read the file first before patching to get current line numbers.
|
||||
- Paths are evaluated on the **remote machine's filesystem**, not the Agent Zero server.
|
||||
- If no CLI is connected, the tool will return an error message.
|
||||
- The transport uses `connector_file_op` and `connector_file_op_result` with a shared `op_id`.
|
||||
0
plugins/_a0_connector/tools/__init__.py
Normal file
0
plugins/_a0_connector/tools/__init__.py
Normal file
196
plugins/_a0_connector/tools/code_execution_remote.py
Normal file
196
plugins/_a0_connector/tools/code_execution_remote.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""code_execution_remote tool — run shell-backed frontend operations on the CLI machine via `/ws`."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from helpers.tool import Response, Tool
|
||||
from helpers.ws import NAMESPACE
|
||||
from helpers.ws_manager import ConnectionNotFoundError, get_shared_ws_manager
|
||||
|
||||
from plugins._a0_connector.helpers.ws_runtime import (
|
||||
clear_pending_exec_op,
|
||||
select_target_sid,
|
||||
store_pending_exec_op,
|
||||
)
|
||||
|
||||
|
||||
EXEC_OP_TIMEOUT = 120.0
|
||||
EXEC_OP_EVENT = "connector_exec_op"
|
||||
|
||||
|
||||
class CodeExecutionRemote(Tool):
|
||||
"""Send shell-backed frontend execution operations to the connected CLI machine."""
|
||||
|
||||
def get_log_object(self):
|
||||
import uuid
|
||||
|
||||
return self.agent.context.log.log(
|
||||
type="code_exe",
|
||||
heading=self.get_heading(),
|
||||
content="",
|
||||
kvps=self.args,
|
||||
id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
def get_heading(self, text: str = "") -> str:
|
||||
if not text:
|
||||
name = str(getattr(self, "name", "code_execution_remote"))
|
||||
runtime = str(self.args.get("runtime", "unknown") or "unknown")
|
||||
text = f"{name} - {runtime}"
|
||||
|
||||
normalized = " ".join(str(text).split())
|
||||
if len(normalized) > 200:
|
||||
normalized = normalized[:197].rstrip() + "..."
|
||||
|
||||
session = self.args.get("session", None)
|
||||
session_text = f"[{session}] " if session or session == 0 else ""
|
||||
return f"icon://terminal {session_text}{normalized}"
|
||||
|
||||
async def execute(self, **kwargs: Any) -> Response:
|
||||
runtime = str(self.args.get("runtime", "")).strip().lower()
|
||||
if runtime not in {"terminal", "python", "nodejs", "output", "input", "reset"}:
|
||||
return Response(
|
||||
message=(
|
||||
"runtime is required (terminal, python, nodejs, output, reset, "
|
||||
"or input [deprecated compatibility alias])"
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
context_id = self.agent.context.id
|
||||
sid = select_target_sid(context_id)
|
||||
if not sid:
|
||||
return Response(
|
||||
message=(
|
||||
"code_execution_remote: no CLI client connected to this context. "
|
||||
"Make sure the CLI is connected and subscribed."
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
try:
|
||||
session = int(self.args.get("session", 0) or 0)
|
||||
except (TypeError, ValueError):
|
||||
return Response(
|
||||
message="session must be an integer",
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
op_id = str(uuid.uuid4())
|
||||
payload: dict[str, Any] = {
|
||||
"op_id": op_id,
|
||||
"runtime": runtime,
|
||||
"session": session,
|
||||
"context_id": context_id,
|
||||
}
|
||||
|
||||
if runtime in {"terminal", "python", "nodejs"}:
|
||||
code = self.args.get("code")
|
||||
if code is None or not str(code).strip():
|
||||
return Response(
|
||||
message=f"code is required for runtime={runtime}",
|
||||
break_loop=False,
|
||||
)
|
||||
payload["code"] = str(code)
|
||||
|
||||
elif runtime == "input":
|
||||
keyboard = self.args.get("keyboard")
|
||||
if keyboard is None:
|
||||
keyboard = self.args.get("code")
|
||||
if keyboard is None:
|
||||
return Response(
|
||||
message="keyboard is required for runtime=input",
|
||||
break_loop=False,
|
||||
)
|
||||
payload["keyboard"] = str(keyboard)
|
||||
|
||||
elif runtime == "reset":
|
||||
reason = self.args.get("reason")
|
||||
if reason is not None:
|
||||
payload["reason"] = str(reason)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[dict[str, Any]] = loop.create_future()
|
||||
store_pending_exec_op(
|
||||
op_id,
|
||||
sid=sid,
|
||||
future=future,
|
||||
loop=loop,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await get_shared_ws_manager().emit_to(
|
||||
NAMESPACE,
|
||||
sid,
|
||||
EXEC_OP_EVENT,
|
||||
payload,
|
||||
handler_id=f"{self.__class__.__module__}.{self.__class__.__name__}",
|
||||
)
|
||||
result = await asyncio.wait_for(future, timeout=EXEC_OP_TIMEOUT)
|
||||
except ConnectionNotFoundError:
|
||||
clear_pending_exec_op(op_id)
|
||||
return Response(
|
||||
message=(
|
||||
"code_execution_remote: the selected CLI client disconnected before "
|
||||
"the execution request could be delivered"
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
clear_pending_exec_op(op_id)
|
||||
return Response(
|
||||
message=(
|
||||
"code_execution_remote: timed out waiting for CLI to respond "
|
||||
f"to runtime={runtime!r} in session {session}"
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
clear_pending_exec_op(op_id)
|
||||
return Response(
|
||||
message=f"code_execution_remote: error sending exec_op: {exc}",
|
||||
break_loop=False,
|
||||
)
|
||||
finally:
|
||||
clear_pending_exec_op(op_id)
|
||||
|
||||
return Response(
|
||||
message=self._extract_result(result, runtime, session),
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
def _extract_result(self, result: Any, runtime: str, session: int) -> str:
|
||||
if not isinstance(result, dict):
|
||||
return f"Unexpected response format from CLI: {result!r}"
|
||||
|
||||
ok = bool(result.get("ok"))
|
||||
data = result.get("result")
|
||||
error = result.get("error")
|
||||
|
||||
if not ok:
|
||||
return (
|
||||
f"Error (runtime={runtime!r}, session={session}): "
|
||||
f"{error or 'Unknown error'}"
|
||||
)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
|
||||
output = str(data.get("output") or data.get("text") or "").strip()
|
||||
message = str(data.get("message") or "").strip()
|
||||
running = bool(data.get("running"))
|
||||
|
||||
parts: list[str] = []
|
||||
if message:
|
||||
parts.append(message)
|
||||
if output:
|
||||
parts.append(output)
|
||||
|
||||
if not parts:
|
||||
state = "running" if running else "completed"
|
||||
parts.append(f"Session {session} {state}.")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
156
plugins/_a0_connector/tools/text_editor_remote.py
Normal file
156
plugins/_a0_connector/tools/text_editor_remote.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""text_editor_remote tool — edit files on the CLI machine via `/ws`."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from helpers.tool import Response, Tool
|
||||
from helpers.ws import NAMESPACE
|
||||
from helpers.ws_manager import ConnectionNotFoundError, get_shared_ws_manager
|
||||
|
||||
from plugins._a0_connector.helpers.ws_runtime import (
|
||||
clear_pending_file_op,
|
||||
select_target_sid,
|
||||
store_pending_file_op,
|
||||
)
|
||||
|
||||
|
||||
FILE_OP_TIMEOUT = 30.0
|
||||
FILE_OP_EVENT = "connector_file_op"
|
||||
|
||||
|
||||
class TextEditorRemote(Tool):
|
||||
"""Send file-editing operations to the connected CLI machine."""
|
||||
|
||||
async def execute(self, **kwargs: Any) -> Response:
|
||||
op = str(self.args.get("op") or self.args.get("operation", "")).strip().lower()
|
||||
if not op:
|
||||
return Response(
|
||||
message="op is required (read, write, or patch)",
|
||||
break_loop=False,
|
||||
)
|
||||
if op not in {"read", "write", "patch"}:
|
||||
return Response(
|
||||
message=f"Unknown operation: {op!r}. Use read, write, or patch.",
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
path = str(self.args.get("path", "")).strip()
|
||||
if not path:
|
||||
return Response(message="path is required", break_loop=False)
|
||||
|
||||
context_id = self.agent.context.id
|
||||
sid = select_target_sid(context_id)
|
||||
if not sid:
|
||||
return Response(
|
||||
message=(
|
||||
"text_editor_remote: no CLI client connected to this context. "
|
||||
"Make sure the CLI is connected and subscribed."
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
op_id = str(uuid.uuid4())
|
||||
payload: dict[str, Any] = {
|
||||
"op_id": op_id,
|
||||
"op": op,
|
||||
"path": path,
|
||||
"context_id": context_id,
|
||||
}
|
||||
if op == "read":
|
||||
if self.args.get("line_from"):
|
||||
payload["line_from"] = int(self.args["line_from"])
|
||||
if self.args.get("line_to"):
|
||||
payload["line_to"] = int(self.args["line_to"])
|
||||
elif op == "write":
|
||||
content = self.args.get("content")
|
||||
if content is None:
|
||||
return Response(
|
||||
message="content is required for write",
|
||||
break_loop=False,
|
||||
)
|
||||
payload["content"] = content
|
||||
else:
|
||||
edits = self.args.get("edits")
|
||||
if not edits:
|
||||
return Response(
|
||||
message="edits is required for patch",
|
||||
break_loop=False,
|
||||
)
|
||||
payload["edits"] = edits
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[dict[str, Any]] = loop.create_future()
|
||||
store_pending_file_op(
|
||||
op_id,
|
||||
sid=sid,
|
||||
future=future,
|
||||
loop=loop,
|
||||
context_id=context_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await get_shared_ws_manager().emit_to(
|
||||
NAMESPACE,
|
||||
sid,
|
||||
FILE_OP_EVENT,
|
||||
payload,
|
||||
handler_id=f"{self.__class__.__module__}.{self.__class__.__name__}",
|
||||
)
|
||||
result = await asyncio.wait_for(future, timeout=FILE_OP_TIMEOUT)
|
||||
except ConnectionNotFoundError:
|
||||
clear_pending_file_op(op_id)
|
||||
return Response(
|
||||
message=(
|
||||
"text_editor_remote: the selected CLI client disconnected before "
|
||||
"the file operation could be delivered"
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
clear_pending_file_op(op_id)
|
||||
return Response(
|
||||
message=(
|
||||
f"text_editor_remote: timed out waiting for CLI to respond "
|
||||
f"to {op} on {path!r}"
|
||||
),
|
||||
break_loop=False,
|
||||
)
|
||||
except Exception as exc:
|
||||
clear_pending_file_op(op_id)
|
||||
return Response(
|
||||
message=f"text_editor_remote: error sending file_op: {exc}",
|
||||
break_loop=False,
|
||||
)
|
||||
finally:
|
||||
clear_pending_file_op(op_id)
|
||||
|
||||
return Response(
|
||||
message=self._extract_result(result, op, path),
|
||||
break_loop=False,
|
||||
)
|
||||
|
||||
def _extract_result(self, result: Any, op: str, path: str) -> str:
|
||||
if not isinstance(result, dict):
|
||||
return f"Unexpected response format from CLI: {result!r}"
|
||||
|
||||
ok = bool(result.get("ok"))
|
||||
data = result.get("result")
|
||||
error = result.get("error")
|
||||
|
||||
if not ok:
|
||||
return f"Error ({op} {path!r}): {error or 'Unknown error'}"
|
||||
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
|
||||
if op == "read":
|
||||
content = data.get("content", "")
|
||||
total_lines = data.get("total_lines", "?")
|
||||
return f"{path} {total_lines} lines\n>>>\n{content}\n<<<"
|
||||
if op == "write":
|
||||
return data.get("message") or f"{path} written successfully"
|
||||
if op == "patch":
|
||||
return data.get("message") or f"{path} patched successfully"
|
||||
return str(data)
|
||||
Loading…
Add table
Add a link
Reference in a new issue