agent-zero/helpers/api.py
keyboardstaff 1d81f72a31 refactor: Backend core rewrite - WsHandler + WsManager + handler migration
- Add WsHandler base class, WsManager (connection tracking / event routing / buffering), WsResult
- Extract network.py (is_loopback_address) and context_utils.py (use_context) to eliminate duplication
- Migrate three handlers to api/ following the ws_* py naming convention
- Simplify run_ui.py WebSocket init from ~170 lines to ~10
- Update import paths in api.py, plugins.py, state_monitor.py
2026-03-26 00:58:01 -07:00

269 lines
8.5 KiB
Python

from abc import abstractmethod
import json
import threading
from functools import wraps
from pathlib import Path
from typing import Union, Dict, Any
from flask import (
Request,
Response,
jsonify,
Flask,
session,
request,
send_file,
redirect,
url_for,
)
from werkzeug.wrappers.response import Response as BaseResponse
from agent import AgentContext
from helpers.print_style import PrintStyle
from helpers.errors import format_error
from helpers import files, cache
ThreadLockType = Union[threading.Lock, threading.RLock]
CACHE_AREA = "api_handlers(api)"
cache.toggle_area(CACHE_AREA, False) # cache off for now
Input = dict
Output = Union[Dict[str, Any], Response]
class ApiHandler:
def __init__(self, app: Flask, thread_lock: ThreadLockType):
self.app = app
self.thread_lock = thread_lock
@classmethod
def requires_loopback(cls) -> bool:
return False
@classmethod
def requires_api_key(cls) -> bool:
return False
@classmethod
def requires_auth(cls) -> bool:
return True
@classmethod
def get_methods(cls) -> list[str]:
return ["POST"]
@classmethod
def requires_csrf(cls) -> bool:
return cls.requires_auth()
@abstractmethod
async def process(self, input: Input, request: Request) -> Output:
pass
async def handle_request(self, request: Request) -> Response:
try:
# input data from request based on type
input_data: Input = {}
if request.is_json:
try:
if request.data: # Check if there's any data
input_data = request.get_json()
# If empty or not valid JSON, use empty dict
except Exception as e:
# Just log the error and continue with empty input
PrintStyle().print(f"Error parsing JSON: {str(e)}")
input_data = {}
else:
# input_data = {"data": request.get_data(as_text=True)}
input_data = {}
# process via handler
output = await self.process(input_data, request)
# return output based on type
if isinstance(output, Response):
return output
else:
response_json = json.dumps(output)
return Response(
response=response_json, status=200, mimetype="application/json"
)
# return exceptions with 500
except Exception as e:
error = format_error(e)
PrintStyle.error(f"API error: {error}")
return Response(response=error, status=500, mimetype="text/plain")
# get context to run agent zero in
def use_context(self, ctxid: str, create_if_not_exists: bool = True):
from helpers.context_utils import use_context as _use_context
return _use_context(self.thread_lock, ctxid, create_if_not_exists)
from helpers.network import is_loopback_address
def requires_api_key(f):
@wraps(f)
async def decorated(*args, **kwargs):
from helpers.settings import get_settings
valid_api_key = get_settings()["mcp_server_token"]
if api_key := request.headers.get("X-API-KEY"):
if api_key != valid_api_key:
return Response("Invalid API key", 401)
elif request.json and request.json.get("api_key"):
api_key = request.json.get("api_key")
if api_key != valid_api_key:
return Response("Invalid API key", 401)
else:
return Response("API key required", 401)
return await f(*args, **kwargs)
return decorated
def requires_loopback(f):
@wraps(f)
async def decorated(*args, **kwargs):
if not is_loopback_address(str(request.remote_addr)):
return Response("Access denied.", 403, {})
return await f(*args, **kwargs)
return decorated
def requires_auth(f):
@wraps(f)
async def decorated(*args, **kwargs):
from helpers import login
user_pass_hash = login.get_credentials_hash()
if not user_pass_hash:
return await f(*args, **kwargs)
if session.get("authentication") != user_pass_hash:
return redirect(url_for("login_handler"))
return await f(*args, **kwargs)
return decorated
def csrf_protect(f):
@wraps(f)
async def decorated(*args, **kwargs):
from helpers import runtime
token = session.get("csrf_token")
header = request.headers.get("X-CSRF-Token")
cookie = request.cookies.get("csrf_token_" + runtime.get_runtime_id())
sent = header or cookie
if not token or not sent or token != sent:
return Response("CSRF token missing or invalid", 403)
return await f(*args, **kwargs)
return decorated
def register_api_route(app: Flask, lock: ThreadLockType) -> None:
from helpers.modules import load_classes_from_file
from helpers import plugins
async def _dispatch(path: str) -> BaseResponse:
# Return cached wrapped handler if available
cached = cache.get(CACHE_AREA, path)
if cached is not None:
return await cached()
# Resolve file path for the handler
# Try built-in api folder first, then plugin api folders
handler_cls: type[ApiHandler] | None = None
# Check built-in python/api/<path>.py
builtin_file = files.get_abs_path(f"api/{path}.py")
if files.is_in_dir(builtin_file, files.get_abs_path("api")) and files.exists(
builtin_file
):
classes = load_classes_from_file(builtin_file, ApiHandler)
if classes:
handler_cls = classes[0]
# Check plugin api folders: path format plugins/<plugin_name>/<handler>
if handler_cls is None and path.startswith("plugins/"):
parts = path.split("/", 2)
if len(parts) == 3:
_, plugin_name, handler_name = parts
plugin_dir = plugins.find_plugin_dir(plugin_name)
if plugin_dir:
plugin_file = Path(plugin_dir) / "api" / f"{handler_name}.py"
if plugin_file.is_file():
classes = load_classes_from_file(str(plugin_file), ApiHandler)
if classes:
handler_cls = classes[0]
if handler_cls is None:
return Response(f"API endpoint not found: {path}", 404)
# Check method is allowed
if request.method not in handler_cls.get_methods():
return Response(f"Method {request.method} not allowed for: {path}", 405)
# Build handler call, wrapping with security decorators as required
async def call_handler() -> BaseResponse:
instance = handler_cls(app, lock)
return await instance.handle_request(request=request)
handler_fn = call_handler
if handler_cls.requires_csrf():
handler_fn = csrf_protect(handler_fn)
if handler_cls.requires_api_key():
handler_fn = requires_api_key(handler_fn)
if handler_cls.requires_auth():
handler_fn = requires_auth(handler_fn)
if handler_cls.requires_loopback():
handler_fn = requires_loopback(handler_fn)
cache.add(CACHE_AREA, path, handler_fn)
return await handler_fn()
app.add_url_rule(
"/api/<path:path>",
"api_dispatch",
_dispatch,
methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
)
def register_watchdogs():
from helpers import watchdog
def on_api_change(items: list[watchdog.WatchItem]):
PrintStyle.debug("API endpoint watchdog triggered:", items)
cache.clear(CACHE_AREA)
watchdog.add_watchdog(
"api_handlers",
roots=[
files.get_abs_path(files.API_DIR),
files.get_abs_path(files.USER_DIR, files.API_DIR),
],
patterns=["*.py"],
handler=on_api_change,
)
# WS handler cache shares the same watched directories (api/, usr/api/)
from helpers.ws import CACHE_AREA as WS_CACHE_AREA
def on_ws_change(items: list[watchdog.WatchItem]):
PrintStyle.debug("WS handler watchdog triggered:", items)
cache.clear(WS_CACHE_AREA)
watchdog.add_watchdog(
"ws_handlers",
roots=[
files.get_abs_path(files.API_DIR),
files.get_abs_path(files.USER_DIR, files.API_DIR),
],
patterns=["ws_*.py"],
handler=on_ws_change,
)