diff --git a/api/api_message.py b/api/api_message.py index fc4e50706..1deb59589 100644 --- a/api/api_message.py +++ b/api/api_message.py @@ -1,7 +1,7 @@ import base64 import os import uuid -from datetime import datetime, timedelta +from datetime import datetime, timezone from agent import AgentContext, UserMessage, AgentContextType from helpers.api import ApiHandler, Request, Response from helpers import files, projects @@ -9,14 +9,9 @@ from helpers.print_style import PrintStyle from helpers.projects import activate_project from helpers.security import safe_filename from initialize import initialize_agent -import threading class ApiMessage(ApiHandler): - # Track chat lifetimes for cleanup - _chat_lifetimes = {} - _cleanup_lock = threading.Lock() - @classmethod def requires_auth(cls) -> bool: return False # No web auth required @@ -37,6 +32,16 @@ class ApiMessage(ApiHandler): lifetime_hours = input.get("lifetime_hours", 24) # Default 24 hours project_name = input.get("project_name", None) agent_profile = input.get("agent_profile", None) + try: + lifetime_hours = float(lifetime_hours) + if lifetime_hours <= 0: + raise ValueError("lifetime_hours must be greater than 0") + except (TypeError, ValueError): + return Response( + '{"error": "lifetime_hours must be a positive number"}', + status=400, + mimetype="application/json", + ) # Set an agent if profile provided override_settings = {} @@ -116,9 +121,9 @@ class ApiMessage(ApiHandler): except Exception as e: return Response(f'{{"error": "Failed to activate project: {str(e)}"}}', status=400, mimetype="application/json") - # Update chat lifetime - with self._cleanup_lock: - self._chat_lifetimes[context_id] = datetime.now() + timedelta(hours=lifetime_hours) + # Persist API chat lifetime in context data so cleanup survives restarts. + context.set_data("lifetime_hours", lifetime_hours) + context.last_message = datetime.now(timezone.utc) # Process message try: @@ -148,9 +153,6 @@ class ApiMessage(ApiHandler): task = context.communicate(UserMessage(message=message, attachments=attachment_paths, id=msg_id)) result = await task.result() - # Clean up expired chats - self._cleanup_expired_chats() - return { "context_id": context_id, "response": result @@ -159,24 +161,3 @@ class ApiMessage(ApiHandler): except Exception as e: PrintStyle.error(f"External API error: {e}") return Response(f'{{"error": "{str(e)}"}}', status=500, mimetype="application/json") - - @classmethod - def _cleanup_expired_chats(cls): - """Clean up expired chats""" - with cls._cleanup_lock: - now = datetime.now() - expired_contexts = [ - context_id for context_id, expiry in cls._chat_lifetimes.items() - if now > expiry - ] - - for context_id in expired_contexts: - try: - context = AgentContext.get(context_id) - if context: - context.reset() - AgentContext.remove(context_id) - del cls._chat_lifetimes[context_id] - PrintStyle().print(f"Cleaned up expired chat: {context_id}") - except Exception as e: - PrintStyle.error(f"Failed to cleanup chat {context_id}: {e}") diff --git a/extensions/python/job_loop/_20_cleanup_expired_api_chats.py b/extensions/python/job_loop/_20_cleanup_expired_api_chats.py new file mode 100644 index 000000000..b8d9da0a0 --- /dev/null +++ b/extensions/python/job_loop/_20_cleanup_expired_api_chats.py @@ -0,0 +1,63 @@ +from datetime import datetime, timedelta, timezone +from typing import Any + +from agent import AgentContext +from helpers import persist_chat +from helpers.extension import Extension +from helpers.print_style import PrintStyle +from helpers.state_monitor_integration import mark_dirty_all + + +CHECK_INTERVAL = timedelta(hours=1) +LIFETIME_KEY = "lifetime_hours" + + +class CleanupExpiredApiChats(Extension): + _last_check: datetime | None = None + + async def execute(self, data: dict[str, Any] | None = None, **kwargs): + now = datetime.now(timezone.utc) + if type(self)._last_check and now - type(self)._last_check < CHECK_INTERVAL: + return + type(self)._last_check = now + + removed = 0 + for context in list(AgentContext.all()): + lifetime_hours = context.get_data(LIFETIME_KEY) + if lifetime_hours is None: + continue + + try: + lifetime = timedelta(hours=float(lifetime_hours)) + except (TypeError, ValueError): + PrintStyle.error( + f"Invalid chat lifetime for {context.id}: {lifetime_hours}" + ) + continue + + if lifetime <= timedelta(0) or context.is_running(): + continue + + last_message = _as_utc(context.last_message) + if now - last_message <= lifetime: + continue + + try: + context.reset() + AgentContext.remove(context.id) + persist_chat.remove_chat(context.id) + removed += 1 + PrintStyle().print(f"Cleaned up expired API chat: {context.id}") + except Exception as e: + PrintStyle.error(f"Failed to cleanup expired API chat {context.id}: {e}") + + if removed: + mark_dirty_all(reason="job_loop.CleanupExpiredApiChats") + + +def _as_utc(value: datetime | None) -> datetime: + if value is None: + return datetime.fromtimestamp(0, timezone.utc) + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + return value.astimezone(timezone.utc) diff --git a/tests/test_api_chat_lifetime.py b/tests/test_api_chat_lifetime.py new file mode 100644 index 000000000..2e893537b --- /dev/null +++ b/tests/test_api_chat_lifetime.py @@ -0,0 +1,84 @@ +from datetime import datetime, timedelta, timezone +import json +from pathlib import Path +import sys +import threading + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from agent import AgentContext +from initialize import initialize_agent + + +class _CompletedTask: + async def result(self): + return "ok" + + +@pytest.mark.asyncio +async def test_api_message_persists_lifetime_hours_in_context_data(monkeypatch): + from api.api_message import ApiMessage + from helpers import persist_chat + + monkeypatch.setattr(AgentContext, "communicate", lambda self, msg: _CompletedTask()) + + handler = ApiMessage(app=None, thread_lock=threading.RLock()) # type: ignore[arg-type] + output = await handler.process( + { + "message": "hello", + "lifetime_hours": 1, + }, + request=None, # type: ignore[arg-type] + ) + + context_id = output["context_id"] # type: ignore[index] + context = AgentContext.get(context_id) + restored = None + try: + assert context is not None + assert context.get_data("lifetime_hours") == 1.0 + + serialized = json.loads(persist_chat.export_json_chat(context)) + assert serialized["data"]["lifetime_hours"] == 1.0 + + AgentContext.remove(context_id) + restored = persist_chat._deserialize_context(serialized) + assert restored.get_data("lifetime_hours") == 1.0 + finally: + AgentContext.remove(context_id) + if restored: + AgentContext.remove(restored.id) + + +@pytest.mark.asyncio +async def test_job_loop_removes_expired_lifetime_chat(monkeypatch): + from extensions.python.job_loop._20_cleanup_expired_api_chats import ( + CleanupExpiredApiChats, + ) + import extensions.python.job_loop._20_cleanup_expired_api_chats as cleanup_module + + removed_chats = [] + dirty_reasons = [] + monkeypatch.setattr(cleanup_module.persist_chat, "remove_chat", removed_chats.append) + monkeypatch.setattr( + cleanup_module, + "mark_dirty_all", + lambda reason: dirty_reasons.append(reason), + ) + + context = AgentContext( + config=initialize_agent(), + last_message=datetime.now(timezone.utc) - timedelta(hours=2), + ) + context.set_data("lifetime_hours", 1) + CleanupExpiredApiChats._last_check = None + + await CleanupExpiredApiChats(agent=None).execute() + + assert AgentContext.get(context.id) is None + assert removed_chats == [context.id] + assert dirty_reasons == ["job_loop.CleanupExpiredApiChats"]