mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
feat: add browser session support to OSS Docker deployment (#4891)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
parent
c6d62e3fa0
commit
328bce3cdd
18 changed files with 337 additions and 106 deletions
|
|
@ -15,7 +15,8 @@ RUN pip install --upgrade pip setuptools wheel
|
|||
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
||||
RUN playwright install-deps
|
||||
RUN playwright install
|
||||
RUN apt-get install -y xauth x11-apps netpbm gpg ca-certificates && apt-get clean
|
||||
RUN apt-get install -y xauth x11-apps netpbm gpg ca-certificates x11vnc && apt-get clean
|
||||
RUN pip install --no-cache-dir websockify
|
||||
|
||||
COPY .nvmrc /app/.nvmrc
|
||||
COPY nodesource-repo.gpg.key /tmp/nodesource-repo.gpg.key
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ services:
|
|||
# comment out if you want to externally call skyvern API
|
||||
ports:
|
||||
- 8000:8000
|
||||
- 6080:6080 # for VNC WebSocket streaming
|
||||
- 9222:9222 # for cdp browser forwarding
|
||||
volumes:
|
||||
- ./artifacts:/data/artifacts
|
||||
|
|
|
|||
|
|
@ -56,6 +56,16 @@ Xvfb :99 -screen 0 1920x1080x16 &
|
|||
xvfb=$!
|
||||
|
||||
DISPLAY=:99 xterm 2>/dev/null &
|
||||
|
||||
echo "Starting x11vnc on display :99..."
|
||||
# VNC runs without a password (-nopw) because port 5900 is not exposed outside
|
||||
# the container. Browser streaming reaches users via websockify on port 6080.
|
||||
mkdir -p /data/log
|
||||
x11vnc -display :99 -forever -nopw -shared -rfbport 5900 -bg -o /dev/null 2>/data/log/x11vnc.err
|
||||
|
||||
echo "Starting websockify on port 6080 -> localhost:5900..."
|
||||
websockify 6080 localhost:5900 --daemon
|
||||
|
||||
python run_streaming.py > /dev/null &
|
||||
|
||||
# Run the command and pass in all three arguments
|
||||
|
|
|
|||
|
|
@ -91,6 +91,12 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
|
|||
except Exception:
|
||||
LOG.exception("Failed to execute api app startup event")
|
||||
|
||||
# Close browser sessions left active by a previous process
|
||||
try:
|
||||
await forge_app.PERSISTENT_SESSIONS_MANAGER.cleanup_stale_sessions()
|
||||
except Exception:
|
||||
LOG.exception("Failed to clean up stale browser sessions")
|
||||
|
||||
# Start cleanup scheduler if enabled
|
||||
cleanup_task = start_cleanup_scheduler()
|
||||
if cleanup_task:
|
||||
|
|
@ -126,6 +132,11 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
|
|||
if redis_client is not None:
|
||||
await redis_client.close()
|
||||
|
||||
# Close all persistent browser sessions
|
||||
from skyvern.webeye.default_persistent_sessions_manager import DefaultPersistentSessionsManager
|
||||
|
||||
await DefaultPersistentSessionsManager.close()
|
||||
|
||||
if forge_app.api_app_shutdown_event:
|
||||
LOG.info("Calling api app shutdown event")
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -216,6 +216,7 @@ def create_forge_app() -> ForgeApp:
|
|||
app.WORKFLOW_SERVICE = WorkflowService()
|
||||
app.AGENT_FUNCTION = AgentFunction()
|
||||
app.PERSISTENT_SESSIONS_MANAGER = DefaultPersistentSessionsManager(database=app.DATABASE)
|
||||
app.PERSISTENT_SESSIONS_MANAGER.watch_session_pool()
|
||||
app.BROWSER_SESSION_RECORDING_SERVICE = BrowserSessionRecordingService()
|
||||
|
||||
app.AZURE_CLIENT_FACTORY = RealAzureClientFactory()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, List, Literal, Sequence, overload
|
||||
|
||||
|
|
@ -5301,6 +5302,7 @@ class AgentDB(BaseAlchemyDB):
|
|||
if persistent_browser_session.completed_at:
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
persistent_browser_session.completed_at = datetime.utcnow()
|
||||
persistent_browser_session.status = "completed"
|
||||
await session.commit()
|
||||
await session.refresh(persistent_browser_session)
|
||||
return PersistentBrowserSession.model_validate(persistent_browser_session)
|
||||
|
|
@ -5315,6 +5317,34 @@ class AgentDB(BaseAlchemyDB):
|
|||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None:
|
||||
"""Suffix browser_address with a unique tag so the unique constraint
|
||||
no longer blocks new sessions that reuse the same local address."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
row = (
|
||||
await session.scalars(
|
||||
select(PersistentBrowserSessionModel)
|
||||
.filter_by(persistent_browser_session_id=session_id)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(deleted_at=None)
|
||||
)
|
||||
).first()
|
||||
|
||||
if not row or not row.browser_address:
|
||||
return
|
||||
if "::closed::" in row.browser_address:
|
||||
return
|
||||
|
||||
row.browser_address = f"{row.browser_address}::closed::{uuid.uuid4().hex}"
|
||||
await session.commit()
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_all_active_persistent_browser_sessions(self) -> List[PersistentBrowserSessionModel]:
|
||||
"""Get all active persistent browser sessions across all organizations."""
|
||||
try:
|
||||
|
|
@ -5328,6 +5358,21 @@ class AgentDB(BaseAlchemyDB):
|
|||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_uncompleted_persistent_browser_sessions(self) -> List[PersistentBrowserSessionModel]:
|
||||
"""Get all browser sessions that have not been completed or deleted."""
|
||||
try:
|
||||
async with self.Session() as session:
|
||||
result = await session.execute(
|
||||
select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
except Exception:
|
||||
LOG.error("UnexpectedError", exc_info=True)
|
||||
raise
|
||||
|
||||
async def create_task_run(
|
||||
self,
|
||||
task_run_type: RunType,
|
||||
|
|
|
|||
|
|
@ -68,33 +68,3 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
|
|||
return None
|
||||
|
||||
return organization_id
|
||||
|
||||
|
||||
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
|
||||
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
|
||||
"""
|
||||
Local dev auth: extracts org_id from API key without strict validation.
|
||||
Falls back to o_temp123 if no key provided.
|
||||
"""
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
except ConnectionClosedOK:
|
||||
LOG.info("WebSocket connection closed cleanly.")
|
||||
return None
|
||||
|
||||
# Try to extract real org_id from the API key
|
||||
if apikey:
|
||||
try:
|
||||
from jose import jwt
|
||||
|
||||
from skyvern.config import settings
|
||||
|
||||
payload = jwt.decode(apikey, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
org_id = payload.get("sub")
|
||||
if org_id:
|
||||
return org_id
|
||||
except Exception:
|
||||
LOG.warning("Local auth: failed to decode API key, falling back to o_temp123")
|
||||
|
||||
return "o_temp123"
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class MessageChannel:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
del_message_channel(self.client_id)
|
||||
del_message_channel(self.client_id, expected=self)
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -221,7 +221,7 @@ class VncChannel:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
del_vnc_channel(self.client_id)
|
||||
del_vnc_channel(self.client_id, expected=self)
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,8 @@ Provides WS endpoints for streaming messages to/from our frontend application.
|
|||
import structlog
|
||||
from fastapi import WebSocket
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.channels.message import (
|
||||
Loops,
|
||||
MessageChannel,
|
||||
|
|
@ -62,7 +60,6 @@ async def messages(
|
|||
client_id: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
|
|
|
|||
|
|
@ -6,12 +6,10 @@ import structlog
|
|||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
HEARTBEAT_INTERVAL = 60
|
||||
|
|
@ -40,7 +38,6 @@ async def _notification_stream_handler(
|
|||
apikey: str | None = None,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
if not organization_id:
|
||||
LOG.info("Notifications: Authentication failed")
|
||||
|
|
|
|||
|
|
@ -35,15 +35,22 @@ def add_vnc_channel(vnc_channel: VncChannel) -> None:
|
|||
vnc_channels[vnc_channel.client_id] = vnc_channel
|
||||
|
||||
|
||||
def get_vnc_channel(client_id: str) -> t.Union[VncChannel, None]:
|
||||
return vnc_channels.get(client_id, None)
|
||||
def get_vnc_channel(client_id: str) -> VncChannel | None:
|
||||
return vnc_channels.get(client_id)
|
||||
|
||||
|
||||
def del_vnc_channel(client_id: str) -> None:
|
||||
try:
|
||||
del vnc_channels[client_id]
|
||||
except KeyError:
|
||||
pass
|
||||
def del_vnc_channel(client_id: str, *, expected: VncChannel | None = None) -> None:
|
||||
candidate = vnc_channels.get(client_id)
|
||||
|
||||
if candidate is None:
|
||||
return
|
||||
|
||||
# Prevent stale channel shutdown from deleting a newer channel that reused
|
||||
# the same client_id during route transitions/reconnects.
|
||||
if expected is not None and candidate is not expected:
|
||||
return
|
||||
|
||||
del vnc_channels[client_id]
|
||||
|
||||
|
||||
# a registry for message channels, keyed by `client_id`
|
||||
|
|
@ -54,25 +61,32 @@ def add_message_channel(message_channel: MessageChannel) -> None:
|
|||
message_channels[message_channel.client_id] = message_channel
|
||||
|
||||
|
||||
def get_message_channel(client_id: str) -> t.Union[MessageChannel, None]:
|
||||
candidate = message_channels.get(client_id, None)
|
||||
def get_message_channel(client_id: str) -> MessageChannel | None:
|
||||
candidate = message_channels.get(client_id)
|
||||
|
||||
if candidate and candidate.is_open:
|
||||
if candidate is None:
|
||||
return None
|
||||
|
||||
if candidate.is_open:
|
||||
return candidate
|
||||
|
||||
if candidate:
|
||||
LOG.info(
|
||||
"MessageChannel: message channel is not open; deleting it",
|
||||
client_id=candidate.client_id,
|
||||
)
|
||||
|
||||
del_message_channel(candidate.client_id)
|
||||
|
||||
LOG.info(
|
||||
"MessageChannel: message channel is not open; deleting it",
|
||||
client_id=candidate.client_id,
|
||||
)
|
||||
del_message_channel(candidate.client_id, expected=candidate)
|
||||
return None
|
||||
|
||||
|
||||
def del_message_channel(client_id: str) -> None:
|
||||
try:
|
||||
del message_channels[client_id]
|
||||
except KeyError:
|
||||
pass
|
||||
def del_message_channel(client_id: str, *, expected: MessageChannel | None = None) -> None:
|
||||
candidate = message_channels.get(client_id)
|
||||
|
||||
if candidate is None:
|
||||
return
|
||||
|
||||
# Prevent stale channel shutdown from deleting a newer channel that reused
|
||||
# the same client_id during route transitions/reconnects.
|
||||
if expected is not None and candidate is not expected:
|
||||
return
|
||||
|
||||
del message_channels[client_id]
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from datetime import datetime, timedelta
|
|||
|
||||
import structlog
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge import app
|
||||
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession, is_final_status
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
|
@ -45,18 +44,6 @@ async def verify_browser_session(
|
|||
"""
|
||||
Verify the browser session exists, and is usable.
|
||||
"""
|
||||
if settings.ENV == "local":
|
||||
dummy_browser_session = AddressablePersistentBrowserSession(
|
||||
persistent_browser_session_id=browser_session_id,
|
||||
organization_id=organization_id,
|
||||
browser_address="0.0.0.0:9223",
|
||||
ip_address="localhost",
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
return dummy_browser_session
|
||||
|
||||
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session(browser_session_id, organization_id)
|
||||
|
||||
if not browser_session:
|
||||
|
|
@ -184,28 +171,6 @@ async def verify_workflow_run(
|
|||
with it.
|
||||
"""
|
||||
|
||||
if settings.ENV == "local":
|
||||
dummy_workflow_run = WorkflowRun(
|
||||
workflow_id="123",
|
||||
workflow_permanent_id="wpid_123",
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
status=WorkflowRunStatus.running,
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
dummy_browser_session = AddressablePersistentBrowserSession(
|
||||
persistent_browser_session_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
browser_address="0.0.0.0:9223",
|
||||
ip_address="localhost",
|
||||
created_at=datetime.now(),
|
||||
modified_at=datetime.now(),
|
||||
)
|
||||
|
||||
return dummy_workflow_run, dummy_browser_session
|
||||
|
||||
workflow_run = await app.DATABASE.get_workflow_run(
|
||||
workflow_run_id=workflow_run_id,
|
||||
organization_id=organization_id,
|
||||
|
|
|
|||
|
|
@ -13,10 +13,8 @@ import structlog
|
|||
from fastapi import WebSocket
|
||||
from websockets.exceptions import ConnectionClosedOK
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
|
||||
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
|
||||
from skyvern.forge.sdk.routes.streaming.auth import auth
|
||||
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
|
||||
Loops,
|
||||
VncChannel,
|
||||
|
|
@ -89,7 +87,6 @@ async def stream(
|
|||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
auth = local_auth if settings.ENV == "local" else real_auth
|
||||
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
|
||||
|
||||
if not organization_id:
|
||||
|
|
|
|||
29
skyvern/webeye/cdp_ports.py
Normal file
29
skyvern/webeye/cdp_ports.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""CDP port allocation for local browser sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
_CDP_PORT_RANGE_START = 9223
|
||||
_CDP_PORT_RANGE_END = 9322
|
||||
_allocated_ports: set[int] = set()
|
||||
|
||||
|
||||
def _allocate_cdp_port() -> int:
|
||||
"""Find an available port in the CDP port range for a browser session."""
|
||||
for port in range(_CDP_PORT_RANGE_START, _CDP_PORT_RANGE_END + 1):
|
||||
if port in _allocated_ports:
|
||||
continue
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", port))
|
||||
_allocated_ports.add(port)
|
||||
return port
|
||||
except OSError:
|
||||
pass
|
||||
raise RuntimeError(f"No available CDP ports in range {_CDP_PORT_RANGE_START}-{_CDP_PORT_RANGE_END}")
|
||||
|
||||
|
||||
def _release_cdp_port(port: int) -> None:
|
||||
"""Return a CDP port to the available pool."""
|
||||
_allocated_ports.discard(port)
|
||||
|
|
@ -7,6 +7,7 @@ from pathlib import Path
|
|||
|
||||
import structlog
|
||||
from playwright._impl._errors import TargetClosedError
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from skyvern.config import settings
|
||||
from skyvern.exceptions import BrowserSessionNotRenewable, MissingBrowserAddressError
|
||||
|
|
@ -21,8 +22,11 @@ from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
|
|||
is_final_status,
|
||||
)
|
||||
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
|
||||
from skyvern.webeye.browser_factory import BrowserContextFactory
|
||||
from skyvern.webeye.browser_state import BrowserState
|
||||
from skyvern.webeye.cdp_ports import _allocate_cdp_port, _release_cdp_port
|
||||
from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager
|
||||
from skyvern.webeye.real_browser_state import RealBrowserState
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
|
@ -30,6 +34,7 @@ LOG = structlog.get_logger()
|
|||
@dataclass
|
||||
class BrowserSession:
|
||||
browser_state: BrowserState
|
||||
cdp_port: int | None = None
|
||||
|
||||
|
||||
async def validate_session_for_renewal(
|
||||
|
|
@ -177,7 +182,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
"""Default (OSS) implementation of PersistentSessionsManager protocol."""
|
||||
|
||||
instance: DefaultPersistentSessionsManager | None = None
|
||||
_browser_sessions: dict[str, BrowserSession] = dict()
|
||||
_browser_sessions: dict[str, BrowserSession] = {}
|
||||
database: AgentDB
|
||||
|
||||
def __new__(cls, database: AgentDB) -> DefaultPersistentSessionsManager:
|
||||
|
|
@ -191,7 +196,72 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
return cls.instance
|
||||
|
||||
def watch_session_pool(self) -> None:
|
||||
return None
|
||||
"""No-op in OSS: browsers run in-process, no external pool to monitor."""
|
||||
|
||||
async def _launch_browser_for_session(
|
||||
self,
|
||||
session_id: str,
|
||||
organization_id: str,
|
||||
) -> None:
|
||||
"""Launch a browser process and register it as a persistent session."""
|
||||
cdp_port = _allocate_cdp_port()
|
||||
LOG.info("Launching browser for persistent session", session_id=session_id, cdp_port=cdp_port)
|
||||
|
||||
pw = None
|
||||
browser_state = None
|
||||
try:
|
||||
pw = await async_playwright().start()
|
||||
browser_context, browser_artifacts, browser_cleanup = await BrowserContextFactory.create_browser_context(
|
||||
pw,
|
||||
organization_id=organization_id,
|
||||
cdp_port=cdp_port,
|
||||
)
|
||||
|
||||
browser_state = RealBrowserState(
|
||||
pw=pw,
|
||||
browser_context=browser_context,
|
||||
page=None,
|
||||
browser_artifacts=browser_artifacts,
|
||||
browser_cleanup=browser_cleanup,
|
||||
)
|
||||
await browser_state.get_or_create_page(organization_id=organization_id)
|
||||
|
||||
self._browser_sessions[session_id] = BrowserSession(
|
||||
browser_state=browser_state,
|
||||
cdp_port=cdp_port,
|
||||
)
|
||||
|
||||
browser_address = f"http://127.0.0.1:{cdp_port}"
|
||||
await self.database.set_persistent_browser_session_browser_address(
|
||||
browser_session_id=session_id,
|
||||
browser_address=browser_address,
|
||||
ip_address="127.0.0.1",
|
||||
ecs_task_arn=None,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
await self.database.update_persistent_browser_session(
|
||||
session_id,
|
||||
organization_id=organization_id,
|
||||
status=PersistentBrowserSessionStatus.running,
|
||||
)
|
||||
|
||||
LOG.info("Browser launched for persistent session", session_id=session_id, browser_address=browser_address)
|
||||
except BaseException:
|
||||
_release_cdp_port(cdp_port)
|
||||
# Close whichever resource was successfully created.
|
||||
# browser_state.close() stops playwright internally, so only fall
|
||||
# back to pw.stop() when no browser_state was created.
|
||||
if browser_state is not None:
|
||||
try:
|
||||
await browser_state.close()
|
||||
except Exception:
|
||||
LOG.warning("Failed to close browser_state during cleanup", exc_info=True)
|
||||
elif pw is not None:
|
||||
try:
|
||||
await pw.stop()
|
||||
except Exception:
|
||||
LOG.warning("Failed to stop playwright during cleanup", exc_info=True)
|
||||
raise
|
||||
|
||||
async def begin_session(
|
||||
self,
|
||||
|
|
@ -275,7 +345,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
"Creating new browser session",
|
||||
organization_id=organization_id,
|
||||
)
|
||||
return await self.database.create_persistent_browser_session(
|
||||
session = await self.database.create_persistent_browser_session(
|
||||
organization_id=organization_id,
|
||||
runnable_type=runnable_type,
|
||||
runnable_id=runnable_id,
|
||||
|
|
@ -286,6 +356,27 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
browser_profile_id=browser_profile_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._launch_browser_for_session(
|
||||
session_id=session.persistent_browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
# Re-fetch to get updated browser_address/ip_address/started_at
|
||||
updated = await self.database.get_persistent_browser_session(
|
||||
session_id=session.persistent_browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
if updated:
|
||||
return updated
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
"Failed to launch browser for session, session will have no browser",
|
||||
session_id=session.persistent_browser_session_id,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
async def occupy_browser_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
@ -371,6 +462,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
)
|
||||
|
||||
self._browser_sessions.pop(browser_session_id, None)
|
||||
if browser_session.cdp_port is not None:
|
||||
_release_cdp_port(browser_session.cdp_port)
|
||||
|
||||
try:
|
||||
await browser_session.browser_state.close()
|
||||
|
|
@ -395,6 +488,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
)
|
||||
|
||||
await self.database.close_persistent_browser_session(browser_session_id, organization_id)
|
||||
if settings.ENV == "local":
|
||||
await self.database.archive_browser_session_address(browser_session_id, organization_id)
|
||||
|
||||
async def close_all_sessions(self, organization_id: str) -> None:
|
||||
"""Close all browser sessions for an organization."""
|
||||
|
|
@ -402,6 +497,24 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
|
|||
for browser_session in browser_sessions:
|
||||
await self.close_session(organization_id, browser_session.persistent_browser_session_id)
|
||||
|
||||
async def cleanup_stale_sessions(self) -> None:
|
||||
"""Close sessions left active by a previous process."""
|
||||
if settings.ENV != "local":
|
||||
return
|
||||
stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions()
|
||||
for db_session in stale_sessions:
|
||||
LOG.info(
|
||||
"Closing stale browser session from previous run",
|
||||
session_id=db_session.persistent_browser_session_id,
|
||||
organization_id=db_session.organization_id,
|
||||
)
|
||||
await self.database.close_persistent_browser_session(
|
||||
db_session.persistent_browser_session_id, db_session.organization_id
|
||||
)
|
||||
await self.database.archive_browser_session_address(
|
||||
db_session.persistent_browser_session_id, db_session.organization_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def close(cls) -> None:
|
||||
"""Close all browser sessions across all organizations."""
|
||||
|
|
|
|||
|
|
@ -105,6 +105,10 @@ class PersistentSessionsManager(Protocol):
|
|||
"""Close all browser sessions for an organization."""
|
||||
...
|
||||
|
||||
async def cleanup_stale_sessions(self) -> None:
|
||||
"""Clean up sessions left active by a previous process."""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
async def close(cls) -> None:
|
||||
"""Close all browser sessions across all organizations."""
|
||||
|
|
|
|||
76
tests/unit_tests/test_cdp_port_allocator.py
Normal file
76
tests/unit_tests/test_cdp_port_allocator.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from skyvern.webeye.cdp_ports import (
|
||||
_CDP_PORT_RANGE_END,
|
||||
_CDP_PORT_RANGE_START,
|
||||
_allocate_cdp_port,
|
||||
_allocated_ports,
|
||||
_release_cdp_port,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_allocated_ports():
|
||||
"""Ensure the global allocated-ports set is empty before and after each test."""
|
||||
_allocated_ports.clear()
|
||||
yield
|
||||
_allocated_ports.clear()
|
||||
|
||||
|
||||
class TestAllocateCdpPort:
|
||||
def test_returns_port_in_range(self):
|
||||
port = _allocate_cdp_port()
|
||||
assert _CDP_PORT_RANGE_START <= port <= _CDP_PORT_RANGE_END
|
||||
|
||||
def test_port_is_tracked(self):
|
||||
port = _allocate_cdp_port()
|
||||
assert port in _allocated_ports
|
||||
|
||||
def test_consecutive_calls_return_different_ports(self):
|
||||
p1 = _allocate_cdp_port()
|
||||
p2 = _allocate_cdp_port()
|
||||
assert p1 != p2
|
||||
|
||||
def test_skips_already_allocated_ports(self):
|
||||
first = _allocate_cdp_port()
|
||||
second = _allocate_cdp_port()
|
||||
assert second != first
|
||||
assert {first, second}.issubset(_allocated_ports)
|
||||
|
||||
def test_skips_ports_bound_by_other_processes(self):
|
||||
# Bind the first port in the range so the allocator must skip it.
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(("127.0.0.1", _CDP_PORT_RANGE_START))
|
||||
try:
|
||||
port = _allocate_cdp_port()
|
||||
assert port != _CDP_PORT_RANGE_START
|
||||
assert port in _allocated_ports
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
def test_raises_when_range_exhausted(self):
|
||||
# Mark every port in the range as allocated.
|
||||
for p in range(_CDP_PORT_RANGE_START, _CDP_PORT_RANGE_END + 1):
|
||||
_allocated_ports.add(p)
|
||||
|
||||
with pytest.raises(RuntimeError, match="No available CDP ports"):
|
||||
_allocate_cdp_port()
|
||||
|
||||
|
||||
class TestReleaseCdpPort:
|
||||
def test_release_removes_from_tracking(self):
|
||||
port = _allocate_cdp_port()
|
||||
_release_cdp_port(port)
|
||||
assert port not in _allocated_ports
|
||||
|
||||
def test_release_allows_reallocation(self):
|
||||
p1 = _allocate_cdp_port()
|
||||
_release_cdp_port(p1)
|
||||
p2 = _allocate_cdp_port()
|
||||
assert p2 == p1
|
||||
|
||||
def test_release_nonexistent_port_is_noop(self):
|
||||
_release_cdp_port(99999)
|
||||
assert 99999 not in _allocated_ports
|
||||
Loading…
Add table
Add a link
Reference in a new issue