Merge remote-tracking branch 'upstream/settings-refactor' into schedulercmp

This commit is contained in:
Alessandro 2025-12-22 14:08:14 +01:00
commit 06a6632997
10 changed files with 175 additions and 85 deletions

View file

@ -5,6 +5,7 @@ from agent import AgentContext, UserMessage, AgentContextType
from python.helpers.api import ApiHandler, Request, Response
from python.helpers import files, projects
from python.helpers.print_style import PrintStyle
from python.helpers.projects import activate_project
from werkzeug.utils import secure_filename
from initialize import initialize_agent
import threading
@ -33,7 +34,13 @@ class ApiMessage(ApiHandler):
message = input.get("message", "")
attachments = input.get("attachments", [])
lifetime_hours = input.get("lifetime_hours", 24) # Default 24 hours
project = input.get("project", None) # Optional project name
project_name = input.get("project_name", None)
agent_profile = input.get("agent_profile", None)
# Set an agent if profile provided
override_settings = {}
if agent_profile:
override_settings["agent_profile"] = agent_profile
if not message:
return Response('{"error": "Message is required"}', status=400, mimetype="application/json")
@ -73,20 +80,38 @@ class ApiMessage(ApiHandler):
if not context:
return Response('{"error": "Context not found"}', status=404, mimetype="application/json")
# Validation: if agent profile is provided, it must match the exising
if agent_profile and context.agent0.config.profile != agent_profile:
return Response('{"error": "Cannot override agent profile on existing context"}', status=400, mimetype="application/json")
# Validation: if project is provided but context already has different project
existing_project = context.get_data(projects.CONTEXT_DATA_KEY_PROJECT)
if project and existing_project and existing_project != project:
if project_name and existing_project and existing_project != project_name:
return Response('{"error": "Project can only be set on first message"}', status=400, mimetype="application/json")
else:
config = initialize_agent()
config = initialize_agent(override_settings=override_settings)
context = AgentContext(config=config, type=AgentContextType.USER)
AgentContext.use(context.id)
context_id = context.id
# Activate project if provided
if project_name:
try:
activate_project(context_id, project_name)
except Exception as e:
# Handle project or context errors more gracefully
error_msg = str(e)
PrintStyle.error(f"Failed to activate project '{project_name}' for context '{context_id}': {error_msg}")
return Response(
f'{{"error": "Failed to activate project \\"{project_name}\\""}}',
status=500,
mimetype="application/json",
)
# Activate project if provided
if project:
if project_name:
try:
projects.activate_project(context_id, project)
projects.activate_project(context_id, project_name)
except Exception as e:
return Response(f'{{"error": "Failed to activate project: {str(e)}"}}', status=400, mimetype="application/json")

View file

@ -11,6 +11,8 @@ from python.helpers.api import (
from python.helpers import runtime, dotenv, login
import fnmatch
ALLOWED_ORIGINS_KEY = "ALLOWED_ORIGINS"
class GetCsrfToken(ApiHandler):
@ -44,9 +46,11 @@ class GetCsrfToken(ApiHandler):
}
async def check_allowed_origin(self, request: Request):
# if login is required, this che
# if login is required, this check is unnecessary
if login.is_login_required():
return {"ok": True, "origin": "", "allowed_origins": ""}
# initialize allowed origins if not yet set
self.initialize_allowed_origins(request)
# otherwise, check if the origin is allowed
return await self.is_allowed_origin(request)
@ -66,6 +70,7 @@ class GetCsrfToken(ApiHandler):
)
return {"ok": match, "origin": origin, "allowed_origins": allowed_origins}
def get_origin_from_request(self, request: Request):
# get from origin
r = request.headers.get("Origin") or request.environ.get("HTTP_ORIGIN")
@ -88,7 +93,7 @@ class GetCsrfToken(ApiHandler):
# get the allowed origins from the environment
allowed_origins = [
origin.strip()
for origin in (dotenv.get_dotenv_value("ALLOWED_ORIGINS") or "").split(",")
for origin in (dotenv.get_dotenv_value(ALLOWED_ORIGINS_KEY) or "").split(",")
if origin.strip()
]
@ -110,3 +115,34 @@ class GetCsrfToken(ApiHandler):
def get_default_allowed_origins(self) -> list[str]:
return ["*://localhost:*", "*://127.0.0.1:*", "*://0.0.0.0:*"]
def initialize_allowed_origins(self, request: Request):
"""
If A0 is hosted on a server, add the first visit origin to ALLOWED_ORIGINS.
This simplifies deployment process as users can access their new instance without
additional setup while keeping it secure.
"""
# dotenv value is already set, do nothing
denv = dotenv.get_dotenv_value(ALLOWED_ORIGINS_KEY)
if denv:
return
# get the origin from the request
req_origin = self.get_origin_from_request(request)
if not req_origin:
return
# check if the origin is allowed by default
allowed_origins = self.get_default_allowed_origins()
match = any(
fnmatch.fnmatch(req_origin, allowed_origin)
for allowed_origin in allowed_origins
)
if match:
return
# if not, add it to the allowed origins
allowed_origins.append(req_origin)
dotenv.save_dotenv_value(ALLOWED_ORIGINS_KEY, ",".join(allowed_origins))