mirror of
https://github.com/agent0ai/agent-zero.git
synced 2026-05-20 17:53:37 +00:00
Merge remote-tracking branch 'upstream/settings-refactor' into schedulercmp
This commit is contained in:
commit
06a6632997
10 changed files with 175 additions and 85 deletions
|
|
@ -5,6 +5,7 @@ from agent import AgentContext, UserMessage, AgentContextType
|
|||
from python.helpers.api import ApiHandler, Request, Response
|
||||
from python.helpers import files, projects
|
||||
from python.helpers.print_style import PrintStyle
|
||||
from python.helpers.projects import activate_project
|
||||
from werkzeug.utils import secure_filename
|
||||
from initialize import initialize_agent
|
||||
import threading
|
||||
|
|
@ -33,7 +34,13 @@ class ApiMessage(ApiHandler):
|
|||
message = input.get("message", "")
|
||||
attachments = input.get("attachments", [])
|
||||
lifetime_hours = input.get("lifetime_hours", 24) # Default 24 hours
|
||||
project = input.get("project", None) # Optional project name
|
||||
project_name = input.get("project_name", None)
|
||||
agent_profile = input.get("agent_profile", None)
|
||||
|
||||
# Set an agent if profile provided
|
||||
override_settings = {}
|
||||
if agent_profile:
|
||||
override_settings["agent_profile"] = agent_profile
|
||||
|
||||
if not message:
|
||||
return Response('{"error": "Message is required"}', status=400, mimetype="application/json")
|
||||
|
|
@ -73,20 +80,38 @@ class ApiMessage(ApiHandler):
|
|||
if not context:
|
||||
return Response('{"error": "Context not found"}', status=404, mimetype="application/json")
|
||||
|
||||
# Validation: if agent profile is provided, it must match the exising
|
||||
if agent_profile and context.agent0.config.profile != agent_profile:
|
||||
return Response('{"error": "Cannot override agent profile on existing context"}', status=400, mimetype="application/json")
|
||||
|
||||
|
||||
# Validation: if project is provided but context already has different project
|
||||
existing_project = context.get_data(projects.CONTEXT_DATA_KEY_PROJECT)
|
||||
if project and existing_project and existing_project != project:
|
||||
if project_name and existing_project and existing_project != project_name:
|
||||
return Response('{"error": "Project can only be set on first message"}', status=400, mimetype="application/json")
|
||||
else:
|
||||
config = initialize_agent()
|
||||
config = initialize_agent(override_settings=override_settings)
|
||||
context = AgentContext(config=config, type=AgentContextType.USER)
|
||||
AgentContext.use(context.id)
|
||||
context_id = context.id
|
||||
# Activate project if provided
|
||||
if project_name:
|
||||
try:
|
||||
activate_project(context_id, project_name)
|
||||
except Exception as e:
|
||||
# Handle project or context errors more gracefully
|
||||
error_msg = str(e)
|
||||
PrintStyle.error(f"Failed to activate project '{project_name}' for context '{context_id}': {error_msg}")
|
||||
return Response(
|
||||
f'{{"error": "Failed to activate project \\"{project_name}\\""}}',
|
||||
status=500,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
# Activate project if provided
|
||||
if project:
|
||||
if project_name:
|
||||
try:
|
||||
projects.activate_project(context_id, project)
|
||||
projects.activate_project(context_id, project_name)
|
||||
except Exception as e:
|
||||
return Response(f'{{"error": "Failed to activate project: {str(e)}"}}', status=400, mimetype="application/json")
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from python.helpers.api import (
|
|||
from python.helpers import runtime, dotenv, login
|
||||
import fnmatch
|
||||
|
||||
ALLOWED_ORIGINS_KEY = "ALLOWED_ORIGINS"
|
||||
|
||||
|
||||
class GetCsrfToken(ApiHandler):
|
||||
|
||||
|
|
@ -44,9 +46,11 @@ class GetCsrfToken(ApiHandler):
|
|||
}
|
||||
|
||||
async def check_allowed_origin(self, request: Request):
|
||||
# if login is required, this che
|
||||
# if login is required, this check is unnecessary
|
||||
if login.is_login_required():
|
||||
return {"ok": True, "origin": "", "allowed_origins": ""}
|
||||
# initialize allowed origins if not yet set
|
||||
self.initialize_allowed_origins(request)
|
||||
# otherwise, check if the origin is allowed
|
||||
return await self.is_allowed_origin(request)
|
||||
|
||||
|
|
@ -66,6 +70,7 @@ class GetCsrfToken(ApiHandler):
|
|||
)
|
||||
return {"ok": match, "origin": origin, "allowed_origins": allowed_origins}
|
||||
|
||||
|
||||
def get_origin_from_request(self, request: Request):
|
||||
# get from origin
|
||||
r = request.headers.get("Origin") or request.environ.get("HTTP_ORIGIN")
|
||||
|
|
@ -88,7 +93,7 @@ class GetCsrfToken(ApiHandler):
|
|||
# get the allowed origins from the environment
|
||||
allowed_origins = [
|
||||
origin.strip()
|
||||
for origin in (dotenv.get_dotenv_value("ALLOWED_ORIGINS") or "").split(",")
|
||||
for origin in (dotenv.get_dotenv_value(ALLOWED_ORIGINS_KEY) or "").split(",")
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
|
|
@ -110,3 +115,34 @@ class GetCsrfToken(ApiHandler):
|
|||
|
||||
def get_default_allowed_origins(self) -> list[str]:
|
||||
return ["*://localhost:*", "*://127.0.0.1:*", "*://0.0.0.0:*"]
|
||||
|
||||
def initialize_allowed_origins(self, request: Request):
|
||||
"""
|
||||
If A0 is hosted on a server, add the first visit origin to ALLOWED_ORIGINS.
|
||||
This simplifies deployment process as users can access their new instance without
|
||||
additional setup while keeping it secure.
|
||||
"""
|
||||
# dotenv value is already set, do nothing
|
||||
denv = dotenv.get_dotenv_value(ALLOWED_ORIGINS_KEY)
|
||||
if denv:
|
||||
return
|
||||
|
||||
# get the origin from the request
|
||||
req_origin = self.get_origin_from_request(request)
|
||||
if not req_origin:
|
||||
return
|
||||
|
||||
# check if the origin is allowed by default
|
||||
allowed_origins = self.get_default_allowed_origins()
|
||||
match = any(
|
||||
fnmatch.fnmatch(req_origin, allowed_origin)
|
||||
for allowed_origin in allowed_origins
|
||||
)
|
||||
if match:
|
||||
return
|
||||
|
||||
# if not, add it to the allowed origins
|
||||
allowed_origins.append(req_origin)
|
||||
dotenv.save_dotenv_value(ALLOWED_ORIGINS_KEY, ",".join(allowed_origins))
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue