diff --git a/backend/app/utils/agent.py b/backend/app/utils/agent.py index 15c89af9c..74e70fa3b 100644 --- a/backend/app/utils/agent.py +++ b/backend/app/utils/agent.py @@ -17,7 +17,7 @@ from camel.toolkits import FunctionTool, RegisteredAgentToolkit from camel.types.agents import ToolCallingRecord from app.component.environment import env from app.utils.toolkit.abstract_toolkit import AbstractToolkit -from app.utils.toolkit.hybrid_browser_python_toolkit import HybridBrowserPythonToolkit +from app.utils.toolkit.hybrid_browser_toolkit import HybridBrowserToolkit from app.utils.toolkit.excel_toolkit import ExcelToolkit from app.utils.toolkit.file_write_toolkit import FileWriteToolkit from app.utils.toolkit.google_calendar_toolkit import GoogleCalendarToolkit @@ -666,7 +666,7 @@ def search_agent(options: Chat): message_handler=HumanToolkit(options.task_id, Agents.search_agent).send_message_to_user ) - web_toolkit_custom = HybridBrowserPythonToolkit( + web_toolkit_custom = HybridBrowserToolkit( options.task_id, headless=False, browser_log_to_file=True, @@ -812,7 +812,7 @@ Your capabilities include: prune_tool_calls_from_memory=True, tool_names=[ SearchToolkit.toolkit_name(), - HybridBrowserPythonToolkit.toolkit_name(), + HybridBrowserToolkit.toolkit_name(), HumanToolkit.toolkit_name(), NoteTakingToolkit.toolkit_name(), TerminalToolkit.toolkit_name(), diff --git a/backend/app/utils/toolkit/hybrid_browser_toolkit.py b/backend/app/utils/toolkit/hybrid_browser_toolkit.py index 7ea680454..90e6de814 100644 --- a/backend/app/utils/toolkit/hybrid_browser_toolkit.py +++ b/backend/app/utils/toolkit/hybrid_browser_toolkit.py @@ -1,14 +1,19 @@ import os import subprocess import time -from typing import Any, Dict, List +import asyncio +import json +from typing import Any, Dict, List, Optional +from loguru import logger +import websockets +import websockets.exceptions + from camel.models import BaseModelBackend from camel.toolkits.hybrid_browser_toolkit.hybrid_browser_toolkit_ts import ( HybridBrowserToolkit as BaseHybridBrowserToolkit, ) -from camel.toolkits.hybrid_browser_toolkit.ws_wrapper import WebSocketBrowserWrapper as BaseWebSocketBrowserWrapper -from loguru import logger -import websockets +from camel.toolkits.hybrid_browser_toolkit.ws_wrapper import \ + WebSocketBrowserWrapper as BaseWebSocketBrowserWrapper from app.component.command import bun, uv from app.service.task import Agents from app.utils.listen.toolkit_listen import listen_toolkit @@ -16,6 +21,69 @@ from app.utils.toolkit.abstract_toolkit import AbstractToolkit class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper): + def __init__(self, config: Optional[Dict[str, Any]] = None): + """Initialize wrapper.""" + super().__init__(config) + logger.info(f"WebSocketBrowserWrapper using ts_dir: {self.ts_dir}") + + async def _receive_loop(self): + """Background task to receive messages from WebSocket with enhanced logging.""" + logger.debug("WebSocket receive loop started") + disconnect_reason = None + + try: + while self.websocket: + try: + response_data = await self.websocket.recv() + response = json.loads(response_data) + + message_id = response.get('id') + if message_id and message_id in self._pending_responses: + # Set the result for the waiting coroutine + future = self._pending_responses.pop(message_id) + if not future.done(): + future.set_result(response) + logger.debug( + f"Processed response for message {message_id}") + else: + # Log unexpected messages + logger.warning( + f"Received unexpected message: {response}") + + except asyncio.CancelledError: + disconnect_reason = "Receive loop cancelled" + logger.info(f"WebSocket disconnect: {disconnect_reason}") + break + except websockets.exceptions.ConnectionClosed as e: + disconnect_reason = f"WebSocket closed: code={e.code}, reason={e.reason}" + logger.warning( + f"WebSocket disconnect: {disconnect_reason}") + break + except websockets.exceptions.WebSocketException as e: + disconnect_reason = f"WebSocket error: {type(e).__name__}: {e}" + logger.error( + f"WebSocket disconnect: {disconnect_reason}") + break + except json.JSONDecodeError as e: + logger.error(f"Failed to decode WebSocket message: {e}") + continue # Try to continue on JSON errors + except Exception as e: + disconnect_reason = f"Unexpected error: {type(e).__name__}: {e}" + logger.error( + f"WebSocket disconnect: {disconnect_reason}", + exc_info=True) + # Notify all pending futures of the error + for future in self._pending_responses.values(): + if not future.done(): + future.set_exception(e) + self._pending_responses.clear() + break + finally: + logger.info( + f"WebSocket receive loop terminated. Reason: {disconnect_reason or 'Normal shutdown'}") + # Mark the websocket as None to indicate disconnection + self.websocket = None + async def start(self): # Check if node_modules exists (dependencies installed) node_modules_path = os.path.join(self.ts_dir, "node_modules") @@ -44,7 +112,14 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper): ) if build_result.returncode != 0: logger.error(f"TypeScript build failed: {build_result.stderr}") - raise RuntimeError(f"TypeScript build failed: {build_result.stderr}") + raise RuntimeError( + f"TypeScript build failed: {build_result.stderr}") + else: + # Log warnings but don't fail on them + if build_result.stderr: + logger.warning( + f"TypeScript build warnings: {build_result.stderr}") + logger.info("TypeScript build completed successfully") # Start the WebSocket server self.process = subprocess.Popen( @@ -64,7 +139,8 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper): if self.process.poll() is not None: # Process died stderr = self.process.stderr.read() # type: ignore - raise RuntimeError(f"WebSocket server failed to start: {stderr}") + raise RuntimeError( + f"WebSocket server failed to start: {stderr}") try: line = self.process.stdout.readline() # type: ignore @@ -72,13 +148,15 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper): if line.startswith("SERVER_READY:"): self.server_port = int(line.split(":")[1].strip()) server_ready = True - logger.info(f"WebSocket server ready on port {self.server_port}") + logger.info( + f"WebSocket server ready on port {self.server_port}") except (ValueError, IndexError): continue if not server_ready: self.process.kill() - raise RuntimeError("WebSocket server failed to start within timeout") + raise RuntimeError( + "WebSocket server failed to start within timeout") # Connect to the WebSocket server try: @@ -91,44 +169,195 @@ class WebSocketBrowserWrapper(BaseWebSocketBrowserWrapper): logger.info("Connected to WebSocket server") except Exception as e: self.process.kill() - raise RuntimeError(f"Failed to connect to WebSocket server: {e}") from e + raise RuntimeError( + f"Failed to connect to WebSocket server: {e}") from e + + # Start the background receiver task - THIS WAS MISSING! + self._receive_task = asyncio.create_task(self._receive_loop()) + logger.debug("Started WebSocket receiver task") # Initialize the browser toolkit logger.debug(f"send init {self.config}") - await self._send_command("init", self.config) - logger.debug("WebSocket server initialized successfully") + try: + await self._send_command("init", self.config) + logger.debug("WebSocket server initialized successfully") + except RuntimeError as e: + if "Timeout waiting for response to command: init" in str(e): + logger.warning( + "Init timeout - continuing anyway (CDP connection may be slow)") + # Continue without error - the WebSocket server is likely still initializing + else: + raise + + async def _send_command(self, command: str, params: Dict[str, Any]) -> \ + Dict[str, Any]: + """Send a command to the WebSocket server with enhanced error handling.""" + try: + # First ensure we have a valid connection + if self.websocket is None: + raise RuntimeError("WebSocket connection not established") + + # Check connection state before sending + if hasattr(self.websocket, 'state'): + import websockets.protocol + if self.websocket.state != websockets.protocol.State.OPEN: + raise RuntimeError( + f"WebSocket is in {self.websocket.state} state, not OPEN") + + logger.debug( + f"Sending command '{command}' with params: {params}") + + # Call parent's _send_command + result = await super()._send_command(command, params) + + logger.debug(f"Command '{command}' completed successfully") + return result + + except RuntimeError as e: + logger.error(f"Failed to send command '{command}': {e}") + # Check if it's a connection issue + if "WebSocket" in str(e) or "connection" in str(e).lower(): + # Mark connection as dead + self.websocket = None + raise + except Exception as e: + logger.error( + f"Unexpected error sending command '{command}': {type(e).__name__}: {e}") + raise -websocket_browser_wrapper = None -"""ensure only one instance of websocket_browser_wrapper""" +# WebSocket connection pool +class WebSocketConnectionPool: + """Manage WebSocket browser connections with session-based pooling.""" + + def __init__(self): + self._connections: Dict[str, WebSocketBrowserWrapper] = {} + self._lock = asyncio.Lock() + + async def get_connection(self, session_id: str, config: Dict[ + str, Any]) -> WebSocketBrowserWrapper: + """Get or create a connection for the given session ID.""" + async with self._lock: + # Check if we have an existing connection for this session + if session_id in self._connections: + wrapper = self._connections[session_id] + + # Comprehensive connection health check + is_healthy = False + if wrapper.websocket: + try: + # Check WebSocket state based on available attributes + if hasattr(wrapper.websocket, 'state'): + import websockets.protocol + is_healthy = wrapper.websocket.state == websockets.protocol.State.OPEN + if not is_healthy: + logger.debug( + f"Session {session_id} WebSocket state: {wrapper.websocket.state}") + elif hasattr(wrapper.websocket, 'open'): + is_healthy = wrapper.websocket.open + else: + # Try ping as last resort + try: + await asyncio.wait_for( + wrapper.websocket.ping(), timeout=1.0) + is_healthy = True + except: + is_healthy = False + except Exception as e: + logger.debug( + f"Health check failed for session {session_id}: {e}") + is_healthy = False + + if is_healthy: + logger.debug( + f"Reusing healthy WebSocket connection for session {session_id}") + return wrapper + else: + # Connection is unhealthy, clean it up + logger.info( + f"Removing unhealthy WebSocket connection for session {session_id}") + try: + await wrapper.stop() + except Exception as e: + logger.debug( + f"Error stopping unhealthy wrapper: {e}") + del self._connections[session_id] + + # Create a new connection + logger.info( + f"Creating new WebSocket connection for session {session_id}") + wrapper = WebSocketBrowserWrapper(config) + await wrapper.start() + self._connections[session_id] = wrapper + logger.info( + f"Successfully created WebSocket connection for session {session_id}") + return wrapper + + async def close_connection(self, session_id: str): + """Close and remove a connection for the given session ID.""" + async with self._lock: + if session_id in self._connections: + wrapper = self._connections[session_id] + try: + await wrapper.stop() + except Exception as e: + logger.error( + f"Error closing WebSocket connection for session {session_id}: {e}") + del self._connections[session_id] + logger.info( + f"Closed WebSocket connection for session {session_id}") + + async def _close_connection_unlocked(self, session_id: str): + """Close connection without acquiring lock (for internal use).""" + if session_id in self._connections: + wrapper = self._connections[session_id] + try: + await wrapper.stop() + except Exception as e: + logger.error( + f"Error closing WebSocket connection for session {session_id}: {e}") + del self._connections[session_id] + logger.info( + f"Closed WebSocket connection for session {session_id}") + + async def close_all(self): + """Close all connections in the pool.""" + async with self._lock: + for session_id in list(self._connections.keys()): + await self._close_connection_unlocked(session_id) + logger.info("Closed all WebSocket connections") + + +# Global connection pool instance +websocket_connection_pool = WebSocketConnectionPool() class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): agent_name: str = Agents.search_agent def __init__( - self, - api_task_id: str, - *, - headless: bool = False, - user_data_dir: str | None = None, - stealth: bool = True, - web_agent_model: BaseModelBackend | None = None, - cache_dir: str = "tmp/", - enabled_tools: List[str] | None = None, - browser_log_to_file: bool = False, - session_id: str | None = None, - default_start_url: str = "https://google.com/", - default_timeout: int | None = None, - short_timeout: int | None = None, - navigation_timeout: int | None = None, - network_idle_timeout: int | None = None, - screenshot_timeout: int | None = None, - page_stability_timeout: int | None = None, - dom_content_loaded_timeout: int | None = None, - viewport_limit: bool = False, - connect_over_cdp: bool = False, - cdp_url: str | None = None, + self, + api_task_id: str, + *, + headless: bool = False, + user_data_dir: str | None = None, + stealth: bool = True, + web_agent_model: BaseModelBackend | None = None, + cache_dir: str = "tmp/", + enabled_tools: List[str] | None = None, + browser_log_to_file: bool = False, + session_id: str | None = None, + default_start_url: str = "https://google.com/", + default_timeout: int | None = None, + short_timeout: int | None = None, + navigation_timeout: int | None = None, + network_idle_timeout: int | None = None, + screenshot_timeout: int | None = None, + page_stability_timeout: int | None = None, + dom_content_loaded_timeout: int | None = None, + viewport_limit: bool = False, + connect_over_cdp: bool = True, + cdp_url: str | None = "http://localhost:9222", ) -> None: self.api_task_id = api_task_id super().__init__( @@ -154,15 +383,26 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): ) async def _ensure_ws_wrapper(self): - """Ensure WebSocket wrapper is initialized.""" - if self._ws_wrapper is None: - global websocket_browser_wrapper - if websocket_browser_wrapper is None: - websocket_browser_wrapper = WebSocketBrowserWrapper(self._ws_config) - self._ws_wrapper = websocket_browser_wrapper - await self._ws_wrapper.start() + """Ensure WebSocket wrapper is initialized using connection pool.""" + global websocket_connection_pool - def clone_for_new_session(self, new_session_id: str | None = None) -> "HybridBrowserToolkit": + # Get session ID from config or use default + session_id = self._ws_config.get('session_id', 'default') + + # Get or create connection from pool + self._ws_wrapper = await websocket_connection_pool.get_connection( + session_id, self._ws_config) + + # Additional health check + if self._ws_wrapper.websocket is None: + logger.warning( + f"WebSocket connection for session {session_id} is None after pool retrieval, recreating...") + await websocket_connection_pool.close_connection(session_id) + self._ws_wrapper = await websocket_connection_pool.get_connection( + session_id, self._ws_config) + + def clone_for_new_session(self, + new_session_id: str | None = None) -> "HybridBrowserToolkit": import uuid if new_session_id is None: @@ -195,6 +435,28 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): def toolkit_name(cls) -> str: return "Browser Toolkit" + async def close(self): + """Close the browser toolkit and release WebSocket connection.""" + try: + # Close browser if needed + if self._ws_wrapper: + await super().browser_close() + except Exception as e: + logger.error(f"Error closing browser: {e}") + + # Release connection from pool + session_id = self._ws_config.get('session_id', 'default') + await websocket_connection_pool.close_connection(session_id) + logger.info( + f"Released WebSocket connection for session {session_id}") + + def __del__(self): + """Cleanup when object is garbage collected.""" + if hasattr(self, '_ws_wrapper') and self._ws_wrapper: + session_id = self._ws_config.get('session_id', 'default') + logger.debug( + f"HybridBrowserToolkit for session {session_id} is being garbage collected") + @listen_toolkit(BaseHybridBrowserToolkit.browser_open) async def browser_open(self) -> Dict[str, Any]: return await super().browser_open() @@ -205,7 +467,15 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): @listen_toolkit(BaseHybridBrowserToolkit.browser_visit_page) async def browser_visit_page(self, url: str) -> Dict[str, Any]: - return await super().browser_visit_page(url) + logger.debug(f"browser_visit_page called with URL: {url}") + try: + result = await super().browser_visit_page(url) + logger.debug(f"browser_visit_page succeeded for URL: {url}") + return result + except Exception as e: + logger.error( + f"browser_visit_page failed for URL {url}: {type(e).__name__}: {e}") + raise @listen_toolkit(BaseHybridBrowserToolkit.browser_back) async def browser_back(self) -> Dict[str, Any]: @@ -220,8 +490,10 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): return await super().browser_get_page_snapshot() @listen_toolkit(BaseHybridBrowserToolkit.browser_get_som_screenshot) - async def browser_get_som_screenshot(self, read_image: bool = False, instruction: str | None = None) -> str: - return await super().browser_get_som_screenshot(read_image, instruction) + async def browser_get_som_screenshot(self, read_image: bool = False, + instruction: str | None = None) -> str: + return await super().browser_get_som_screenshot(read_image, + instruction) @listen_toolkit(BaseHybridBrowserToolkit.browser_click) async def browser_click(self, *, ref: str) -> Dict[str, Any]: @@ -232,19 +504,23 @@ class HybridBrowserToolkit(BaseHybridBrowserToolkit, AbstractToolkit): return await super().browser_type(ref=ref, text=text) @listen_toolkit(BaseHybridBrowserToolkit.browser_select) - async def browser_select(self, *, ref: str, value: str) -> Dict[str, Any]: + async def browser_select(self, *, ref: str, value: str) -> Dict[ + str, Any]: return await super().browser_select(ref=ref, value=value) @listen_toolkit(BaseHybridBrowserToolkit.browser_scroll) - async def browser_scroll(self, *, direction: str, amount: int = 500) -> Dict[str, Any]: - return await super().browser_scroll(direction=direction, amount=amount) + async def browser_scroll(self, *, direction: str, amount: int = 500) -> \ + Dict[str, Any]: + return await super().browser_scroll(direction=direction, + amount=amount) @listen_toolkit(BaseHybridBrowserToolkit.browser_enter) async def browser_enter(self) -> Dict[str, Any]: return await super().browser_enter() @listen_toolkit(BaseHybridBrowserToolkit.browser_wait_user) - async def browser_wait_user(self, timeout_sec: float | None = None) -> Dict[str, Any]: + async def browser_wait_user(self, timeout_sec: float | None = None) -> \ + Dict[str, Any]: return await super().browser_wait_user(timeout_sec) @listen_toolkit(BaseHybridBrowserToolkit.browser_switch_tab)