feat: add browser session support to OSS Docker deployment (#4891)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Andrew Neilson 2026-03-03 23:58:26 -08:00 committed by GitHub
parent c6d62e3fa0
commit 328bce3cdd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 337 additions and 106 deletions

View file

@ -15,7 +15,8 @@ RUN pip install --upgrade pip setuptools wheel
RUN pip install --no-cache-dir --upgrade -r requirements.txt
RUN playwright install-deps
RUN playwright install
RUN apt-get install -y xauth x11-apps netpbm gpg ca-certificates && apt-get clean
RUN apt-get install -y xauth x11-apps netpbm gpg ca-certificates x11vnc && apt-get clean
RUN pip install --no-cache-dir websockify
COPY .nvmrc /app/.nvmrc
COPY nodesource-repo.gpg.key /tmp/nodesource-repo.gpg.key

View file

@ -25,6 +25,7 @@ services:
# comment out if you want to externally call skyvern API
ports:
- 8000:8000
- 6080:6080 # for VNC WebSocket streaming
- 9222:9222 # for cdp browser forwarding
volumes:
- ./artifacts:/data/artifacts

View file

@ -56,6 +56,16 @@ Xvfb :99 -screen 0 1920x1080x16 &
xvfb=$!
DISPLAY=:99 xterm 2>/dev/null &
echo "Starting x11vnc on display :99..."
# VNC runs without a password (-nopw) because port 5900 is not exposed outside
# the container. Browser streaming reaches users via websockify on port 6080.
mkdir -p /data/log
x11vnc -display :99 -forever -nopw -shared -rfbport 5900 -bg -o /dev/null 2>/data/log/x11vnc.err
echo "Starting websockify on port 6080 -> localhost:5900..."
websockify 6080 localhost:5900 --daemon
python run_streaming.py > /dev/null &
# Run the command and pass in all three arguments

View file

@ -91,6 +91,12 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
except Exception:
LOG.exception("Failed to execute api app startup event")
# Close browser sessions left active by a previous process
try:
await forge_app.PERSISTENT_SESSIONS_MANAGER.cleanup_stale_sessions()
except Exception:
LOG.exception("Failed to clean up stale browser sessions")
# Start cleanup scheduler if enabled
cleanup_task = start_cleanup_scheduler()
if cleanup_task:
@ -126,6 +132,11 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncGenerator[None, Any]:
if redis_client is not None:
await redis_client.close()
# Close all persistent browser sessions
from skyvern.webeye.default_persistent_sessions_manager import DefaultPersistentSessionsManager
await DefaultPersistentSessionsManager.close()
if forge_app.api_app_shutdown_event:
LOG.info("Calling api app shutdown event")
try:

View file

@ -216,6 +216,7 @@ def create_forge_app() -> ForgeApp:
app.WORKFLOW_SERVICE = WorkflowService()
app.AGENT_FUNCTION = AgentFunction()
app.PERSISTENT_SESSIONS_MANAGER = DefaultPersistentSessionsManager(database=app.DATABASE)
app.PERSISTENT_SESSIONS_MANAGER.watch_session_pool()
app.BROWSER_SESSION_RECORDING_SERVICE = BrowserSessionRecordingService()
app.AZURE_CLIENT_FACTORY = RealAzureClientFactory()

View file

@ -1,4 +1,5 @@
import json
import uuid
from datetime import datetime, timedelta
from typing import Any, List, Literal, Sequence, overload
@ -5301,6 +5302,7 @@ class AgentDB(BaseAlchemyDB):
if persistent_browser_session.completed_at:
return PersistentBrowserSession.model_validate(persistent_browser_session)
persistent_browser_session.completed_at = datetime.utcnow()
persistent_browser_session.status = "completed"
await session.commit()
await session.refresh(persistent_browser_session)
return PersistentBrowserSession.model_validate(persistent_browser_session)
@ -5315,6 +5317,34 @@ class AgentDB(BaseAlchemyDB):
LOG.error("UnexpectedError", exc_info=True)
raise
async def archive_browser_session_address(self, session_id: str, organization_id: str) -> None:
    """Retire a session's ``browser_address`` so the same address can be stored again.

    The stored address is rewritten as ``<address>::closed::<random hex>``, so the
    unique constraint on ``browser_address`` no longer blocks a new session that
    reuses the same local address. Nothing is written when no matching,
    non-deleted row exists, when the row has no address, or when the address
    already contains the ``::closed::`` tag.

    Args:
        session_id: ``persistent_browser_session_id`` of the row to update.
        organization_id: Organization the session must belong to.

    Raises:
        SQLAlchemyError: Logged and re-raised when the database operation fails.
        Exception: Any other failure is logged and re-raised.
    """
    try:
        async with self.Session() as session:
            query = (
                select(PersistentBrowserSessionModel)
                .filter_by(persistent_browser_session_id=session_id)
                .filter_by(organization_id=organization_id)
                .filter_by(deleted_at=None)
            )
            record = (await session.scalars(query)).first()
            current_address = record.browser_address if record else None
            # Skip rows that are missing, have no address, or were already archived.
            if not current_address or "::closed::" in current_address:
                return
            record.browser_address = f"{current_address}::closed::{uuid.uuid4().hex}"
            await session.commit()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def get_all_active_persistent_browser_sessions(self) -> List[PersistentBrowserSessionModel]:
"""Get all active persistent browser sessions across all organizations."""
try:
@ -5328,6 +5358,21 @@ class AgentDB(BaseAlchemyDB):
LOG.error("UnexpectedError", exc_info=True)
raise
async def get_uncompleted_persistent_browser_sessions(self) -> List[PersistentBrowserSessionModel]:
    """Return every browser session row whose ``deleted_at`` and ``completed_at`` are both unset.

    The query is not scoped to an organization.

    Raises:
        SQLAlchemyError: Logged and re-raised when the database operation fails.
        Exception: Any other failure is logged and re-raised.
    """
    try:
        async with self.Session() as session:
            query = select(PersistentBrowserSessionModel).filter_by(deleted_at=None).filter_by(completed_at=None)
            rows = await session.execute(query)
            return rows.scalars().all()
    except SQLAlchemyError:
        LOG.error("SQLAlchemyError", exc_info=True)
        raise
    except Exception:
        LOG.error("UnexpectedError", exc_info=True)
        raise
async def create_task_run(
self,
task_run_type: RunType,

View file

@ -68,33 +68,3 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
return None
return organization_id
# NOTE(jdo:streaming-local-dev): use this instead of the above `auth`
async def _auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
"""
Local dev auth: extracts org_id from API key without strict validation.
Falls back to o_temp123 if no key provided.
"""
try:
await websocket.accept()
except ConnectionClosedOK:
LOG.info("WebSocket connection closed cleanly.")
return None
# Try to extract real org_id from the API key
if apikey:
try:
from jose import jwt
from skyvern.config import settings
payload = jwt.decode(apikey, settings.SECRET_KEY, algorithms=["HS256"])
org_id = payload.get("sub")
if org_id:
return org_id
except Exception:
LOG.warning("Local auth: failed to decode API key, falling back to o_temp123")
return "o_temp123"

View file

@ -222,7 +222,7 @@ class MessageChannel:
except Exception:
pass
del_message_channel(self.client_id)
del_message_channel(self.client_id, expected=self)
return self

View file

@ -221,7 +221,7 @@ class VncChannel:
except Exception:
pass
del_vnc_channel(self.client_id)
del_vnc_channel(self.client_id, expected=self)
return self

View file

@ -5,10 +5,8 @@ Provides WS endpoints for streaming messages to/from our frontend application.
import structlog
from fastapi import WebSocket
from skyvern.config import settings
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.channels.message import (
Loops,
MessageChannel,
@ -62,7 +60,6 @@ async def messages(
client_id: str | None = None,
token: str | None = None,
) -> None:
auth = local_auth if settings.ENV == "local" else real_auth
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:

View file

@ -6,12 +6,10 @@ import structlog
from fastapi import WebSocket, WebSocketDisconnect
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.notification.factory import NotificationRegistryFactory
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
from skyvern.forge.sdk.routes.streaming.auth import auth
LOG = structlog.get_logger()
HEARTBEAT_INTERVAL = 60
@ -40,7 +38,6 @@ async def _notification_stream_handler(
apikey: str | None = None,
token: str | None = None,
) -> None:
auth = local_auth if settings.ENV == "local" else real_auth
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:
LOG.info("Notifications: Authentication failed")

View file

@ -35,15 +35,22 @@ def add_vnc_channel(vnc_channel: VncChannel) -> None:
vnc_channels[vnc_channel.client_id] = vnc_channel
def get_vnc_channel(client_id: str) -> t.Union[VncChannel, None]:
return vnc_channels.get(client_id, None)
def get_vnc_channel(client_id: str) -> VncChannel | None:
return vnc_channels.get(client_id)
def del_vnc_channel(client_id: str) -> None:
try:
def del_vnc_channel(client_id: str, *, expected: VncChannel | None = None) -> None:
candidate = vnc_channels.get(client_id)
if candidate is None:
return
# Prevent stale channel shutdown from deleting a newer channel that reused
# the same client_id during route transitions/reconnects.
if expected is not None and candidate is not expected:
return
del vnc_channels[client_id]
except KeyError:
pass
# a registry for message channels, keyed by `client_id`
@ -54,25 +61,32 @@ def add_message_channel(message_channel: MessageChannel) -> None:
message_channels[message_channel.client_id] = message_channel
def get_message_channel(client_id: str) -> t.Union[MessageChannel, None]:
candidate = message_channels.get(client_id, None)
def get_message_channel(client_id: str) -> MessageChannel | None:
candidate = message_channels.get(client_id)
if candidate and candidate.is_open:
if candidate is None:
return None
if candidate.is_open:
return candidate
if candidate:
LOG.info(
"MessageChannel: message channel is not open; deleting it",
client_id=candidate.client_id,
)
del_message_channel(candidate.client_id)
del_message_channel(candidate.client_id, expected=candidate)
return None
def del_message_channel(client_id: str) -> None:
try:
def del_message_channel(client_id: str, *, expected: MessageChannel | None = None) -> None:
candidate = message_channels.get(client_id)
if candidate is None:
return
# Prevent stale channel shutdown from deleting a newer channel that reused
# the same client_id during route transitions/reconnects.
if expected is not None and candidate is not expected:
return
del message_channels[client_id]
except KeyError:
pass

View file

@ -21,7 +21,6 @@ from datetime import datetime, timedelta
import structlog
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession, is_final_status
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
@ -45,18 +44,6 @@ async def verify_browser_session(
"""
Verify the browser session exists, and is usable.
"""
if settings.ENV == "local":
dummy_browser_session = AddressablePersistentBrowserSession(
persistent_browser_session_id=browser_session_id,
organization_id=organization_id,
browser_address="0.0.0.0:9223",
ip_address="localhost",
created_at=datetime.now(),
modified_at=datetime.now(),
)
return dummy_browser_session
browser_session = await app.PERSISTENT_SESSIONS_MANAGER.get_session(browser_session_id, organization_id)
if not browser_session:
@ -184,28 +171,6 @@ async def verify_workflow_run(
with it.
"""
if settings.ENV == "local":
dummy_workflow_run = WorkflowRun(
workflow_id="123",
workflow_permanent_id="wpid_123",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
status=WorkflowRunStatus.running,
created_at=datetime.now(),
modified_at=datetime.now(),
)
dummy_browser_session = AddressablePersistentBrowserSession(
persistent_browser_session_id=workflow_run_id,
organization_id=organization_id,
browser_address="0.0.0.0:9223",
ip_address="localhost",
created_at=datetime.now(),
modified_at=datetime.now(),
)
return dummy_workflow_run, dummy_browser_session
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,

View file

@ -13,10 +13,8 @@ import structlog
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
from skyvern.config import settings
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
from skyvern.forge.sdk.routes.streaming.auth import _auth as local_auth
from skyvern.forge.sdk.routes.streaming.auth import auth as real_auth
from skyvern.forge.sdk.routes.streaming.auth import auth
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
Loops,
VncChannel,
@ -89,7 +87,6 @@ async def stream(
workflow_run_id=workflow_run_id,
)
auth = local_auth if settings.ENV == "local" else real_auth
organization_id = await auth(apikey=apikey, token=token, websocket=websocket)
if not organization_id:

View file

@ -0,0 +1,29 @@
"""CDP port allocation for local browser sessions."""
from __future__ import annotations
import socket
# Inclusive bounds of the local port range scanned for browser CDP endpoints.
_CDP_PORT_RANGE_START = 9223
_CDP_PORT_RANGE_END = 9322

# Ports handed out by this process and not yet released.
_allocated_ports: set[int] = set()


def _is_port_bindable(port: int) -> bool:
    """Return True when a TCP socket can currently bind 127.0.0.1:<port>."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("127.0.0.1", port))
    except OSError:
        return False
    return True


def _allocate_cdp_port() -> int:
    """Reserve and return the lowest usable port in the CDP range.

    A port is usable when this process has not already handed it out and a
    probe socket can bind it on 127.0.0.1. The returned port is recorded in
    ``_allocated_ports`` until ``_release_cdp_port`` is called for it.

    Raises:
        RuntimeError: If no port in the range is usable.
    """
    for candidate in range(_CDP_PORT_RANGE_START, _CDP_PORT_RANGE_END + 1):
        if candidate in _allocated_ports or not _is_port_bindable(candidate):
            continue
        _allocated_ports.add(candidate)
        return candidate
    raise RuntimeError(f"No available CDP ports in range {_CDP_PORT_RANGE_START}-{_CDP_PORT_RANGE_END}")


def _release_cdp_port(port: int) -> None:
    """Mark ``port`` as no longer reserved; unknown ports are ignored."""
    _allocated_ports.discard(port)

View file

@ -7,6 +7,7 @@ from pathlib import Path
import structlog
from playwright._impl._errors import TargetClosedError
from playwright.async_api import async_playwright
from skyvern.config import settings
from skyvern.exceptions import BrowserSessionNotRenewable, MissingBrowserAddressError
@ -21,8 +22,11 @@ from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
is_final_status,
)
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
from skyvern.webeye.browser_factory import BrowserContextFactory
from skyvern.webeye.browser_state import BrowserState
from skyvern.webeye.cdp_ports import _allocate_cdp_port, _release_cdp_port
from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager
from skyvern.webeye.real_browser_state import RealBrowserState
LOG = structlog.get_logger()
@ -30,6 +34,7 @@ LOG = structlog.get_logger()
@dataclass
class BrowserSession:
    """In-memory record of one persistent browser session held by the manager."""

    # Browser state object backing this session.
    browser_state: BrowserState
    # CDP port reserved for this session's browser; None when no port was reserved.
    cdp_port: int | None = None
async def validate_session_for_renewal(
@ -177,7 +182,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Default (OSS) implementation of PersistentSessionsManager protocol."""
instance: DefaultPersistentSessionsManager | None = None
_browser_sessions: dict[str, BrowserSession] = dict()
_browser_sessions: dict[str, BrowserSession] = {}
database: AgentDB
def __new__(cls, database: AgentDB) -> DefaultPersistentSessionsManager:
@ -191,7 +196,72 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
return cls.instance
def watch_session_pool(self) -> None:
return None
"""No-op in OSS: browsers run in-process, no external pool to monitor."""
async def _launch_browser_for_session(
    self,
    session_id: str,
    organization_id: str,
) -> None:
    """Launch a browser process and register it as a persistent session.

    Reserves a CDP port, starts Playwright and a browser context on that port,
    opens an initial page, registers the session in ``_browser_sessions``, and
    then records the browser address and ``running`` status in the database.

    On any failure the partially created state is rolled back: the registry
    entry is removed, the browser (or bare Playwright instance) is shut down,
    and the CDP port is released. The original error is then re-raised.

    Args:
        session_id: Identifier of the persistent browser session being started.
        organization_id: Organization that owns the session.

    Raises:
        RuntimeError: If no CDP port is available; raised before anything is started.
        BaseException: Any error raised while launching or recording the session.
    """
    cdp_port = _allocate_cdp_port()
    LOG.info("Launching browser for persistent session", session_id=session_id, cdp_port=cdp_port)

    pw = None
    browser_state = None
    try:
        pw = await async_playwright().start()
        browser_context, browser_artifacts, browser_cleanup = await BrowserContextFactory.create_browser_context(
            pw,
            organization_id=organization_id,
            cdp_port=cdp_port,
        )

        browser_state = RealBrowserState(
            pw=pw,
            browser_context=browser_context,
            page=None,
            browser_artifacts=browser_artifacts,
            browser_cleanup=browser_cleanup,
        )
        await browser_state.get_or_create_page(organization_id=organization_id)

        self._browser_sessions[session_id] = BrowserSession(
            browser_state=browser_state,
            cdp_port=cdp_port,
        )

        browser_address = f"http://127.0.0.1:{cdp_port}"
        await self.database.set_persistent_browser_session_browser_address(
            browser_session_id=session_id,
            browser_address=browser_address,
            ip_address="127.0.0.1",
            ecs_task_arn=None,
            organization_id=organization_id,
        )
        await self.database.update_persistent_browser_session(
            session_id,
            organization_id=organization_id,
            status=PersistentBrowserSessionStatus.running,
        )

        LOG.info("Browser launched for persistent session", session_id=session_id, browser_address=browser_address)
    except BaseException:
        # The session may already be registered when a database write fails.
        # Remove it so the registry never points at a closed browser, and so a
        # later close of this session cannot release a port that has since
        # been handed to another session.
        self._browser_sessions.pop(session_id, None)
        try:
            # Close whichever resource was successfully created.
            # browser_state.close() stops playwright internally, so only fall
            # back to pw.stop() when no browser_state was created.
            if browser_state is not None:
                try:
                    await browser_state.close()
                except Exception:
                    LOG.warning("Failed to close browser_state during cleanup", exc_info=True)
            elif pw is not None:
                try:
                    await pw.stop()
                except Exception:
                    LOG.warning("Failed to stop playwright during cleanup", exc_info=True)
        finally:
            # Release the reservation only after shutdown has been attempted,
            # and do so even if the shutdown itself is interrupted.
            _release_cdp_port(cdp_port)
        raise
async def begin_session(
self,
@ -275,7 +345,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"Creating new browser session",
organization_id=organization_id,
)
return await self.database.create_persistent_browser_session(
session = await self.database.create_persistent_browser_session(
organization_id=organization_id,
runnable_type=runnable_type,
runnable_id=runnable_id,
@ -286,6 +356,27 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
browser_profile_id=browser_profile_id,
)
try:
await self._launch_browser_for_session(
session_id=session.persistent_browser_session_id,
organization_id=organization_id,
)
# Re-fetch to get updated browser_address/ip_address/started_at
updated = await self.database.get_persistent_browser_session(
session_id=session.persistent_browser_session_id,
organization_id=organization_id,
)
if updated:
return updated
except Exception:
LOG.exception(
"Failed to launch browser for session, session will have no browser",
session_id=session.persistent_browser_session_id,
organization_id=organization_id,
)
return session
async def occupy_browser_session(
self,
session_id: str,
@ -371,6 +462,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
)
self._browser_sessions.pop(browser_session_id, None)
if browser_session.cdp_port is not None:
_release_cdp_port(browser_session.cdp_port)
try:
await browser_session.browser_state.close()
@ -395,6 +488,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
)
await self.database.close_persistent_browser_session(browser_session_id, organization_id)
if settings.ENV == "local":
await self.database.archive_browser_session_address(browser_session_id, organization_id)
async def close_all_sessions(self, organization_id: str) -> None:
"""Close all browser sessions for an organization."""
@ -402,6 +497,24 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
for browser_session in browser_sessions:
await self.close_session(organization_id, browser_session.persistent_browser_session_id)
async def cleanup_stale_sessions(self) -> None:
    """Close sessions left active by a previous process.

    Does nothing unless ``settings.ENV`` is ``"local"``. Otherwise every
    session row that is neither completed nor deleted is marked closed and its
    browser address is archived so the address can be reused.

    Each session is handled independently: a failure on one session is logged
    and the remaining sessions are still processed, so a single bad row cannot
    leave all later stale sessions (and their addresses) in place.

    Raises:
        Exception: If the list of stale sessions cannot be fetched.
    """
    if settings.ENV != "local":
        return

    stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions()
    for db_session in stale_sessions:
        session_id = db_session.persistent_browser_session_id
        organization_id = db_session.organization_id
        LOG.info(
            "Closing stale browser session from previous run",
            session_id=session_id,
            organization_id=organization_id,
        )
        try:
            await self.database.close_persistent_browser_session(session_id, organization_id)
            await self.database.archive_browser_session_address(session_id, organization_id)
        except Exception:
            # Best-effort cleanup: keep going so the other stale sessions are still closed.
            LOG.exception(
                "Failed to close stale browser session",
                session_id=session_id,
                organization_id=organization_id,
            )
@classmethod
async def close(cls) -> None:
"""Close all browser sessions across all organizations."""

View file

@ -105,6 +105,10 @@ class PersistentSessionsManager(Protocol):
"""Close all browser sessions for an organization."""
...
async def cleanup_stale_sessions(self) -> None:
    """Clean up sessions left active by a previous process.

    The API app's lifespan hook awaits this once during startup and logs,
    rather than propagates, any exception it raises.
    """
    ...
@classmethod
async def close(cls) -> None:
"""Close all browser sessions across all organizations."""

View file

@ -0,0 +1,76 @@
import socket
import pytest
from skyvern.webeye.cdp_ports import (
_CDP_PORT_RANGE_END,
_CDP_PORT_RANGE_START,
_allocate_cdp_port,
_allocated_ports,
_release_cdp_port,
)
@pytest.fixture(autouse=True)
def _clean_allocated_ports():
    """Ensure the global allocated-ports set is empty before and after each test."""
    # The set is module-level state shared by every test, so reset it on entry...
    _allocated_ports.clear()
    yield
    # ...and again on exit so reservations never leak into the next test.
    _allocated_ports.clear()
class TestAllocateCdpPort:
    """Tests for ``_allocate_cdp_port``: range, tracking, skipping, and exhaustion."""

    def test_returns_port_in_range(self):
        port = _allocate_cdp_port()
        assert _CDP_PORT_RANGE_START <= port <= _CDP_PORT_RANGE_END

    def test_port_is_tracked(self):
        port = _allocate_cdp_port()
        assert port in _allocated_ports

    def test_consecutive_calls_return_different_ports(self):
        p1 = _allocate_cdp_port()
        p2 = _allocate_cdp_port()
        assert p1 != p2

    def test_skips_already_allocated_ports(self):
        first = _allocate_cdp_port()
        second = _allocate_cdp_port()
        assert second != first
        assert {first, second}.issubset(_allocated_ports)

    def test_skips_ports_bound_by_other_processes(self):
        # Hold the first port in the range so the allocator must skip it.
        # The context manager closes the socket even when bind() itself fails,
        # so a busy port cannot leak a descriptor into later tests.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", _CDP_PORT_RANGE_START))
            port = _allocate_cdp_port()
            assert port != _CDP_PORT_RANGE_START
            assert port in _allocated_ports

    def test_raises_when_range_exhausted(self):
        # Mark every port in the range as allocated.
        _allocated_ports.update(range(_CDP_PORT_RANGE_START, _CDP_PORT_RANGE_END + 1))
        with pytest.raises(RuntimeError, match="No available CDP ports"):
            _allocate_cdp_port()
class TestReleaseCdpPort:
    """Tests for ``_release_cdp_port``: untracking, reuse, and unknown ports."""

    def test_release_removes_from_tracking(self):
        reserved = _allocate_cdp_port()
        _release_cdp_port(reserved)
        assert reserved not in _allocated_ports

    def test_release_allows_reallocation(self):
        # A released port is the lowest free one again, so it is handed out next.
        original = _allocate_cdp_port()
        _release_cdp_port(original)
        replacement = _allocate_cdp_port()
        assert replacement == original

    def test_release_nonexistent_port_is_noop(self):
        # Releasing a port that was never allocated must not raise.
        _release_cdp_port(99999)
        assert 99999 not in _allocated_ports