@@ -146,7 +183,7 @@ function WorkflowRunStream(props?: Props) {
);
}
- if (workflowRun?.status === Status.Running && streamImgSrc.length === 0) {
+ if (isRunningOrPaused && streamImgSrc.length === 0) {
return (
Starting the stream...
@@ -154,29 +191,26 @@ function WorkflowRunStream(props?: Props) {
);
}
- if (workflowRun?.status === Status.Running && streamImgSrc.length > 0) {
+ const hasStream =
+ (isRunningOrPaused || alwaysShowStream) && streamImgSrc.length > 0;
+
+ if (hasStream) {
return (
-
-
-
+
);
}
if (alwaysShowStream) {
- if (streamImgSrc?.length > 0) {
- return (
-
-
-
- );
- }
-
return (
Waiting for stream...
diff --git a/skyvern-frontend/src/util/env.ts b/skyvern-frontend/src/util/env.ts
index dbcfb1da0..e27d18d1a 100644
--- a/skyvern-frontend/src/util/env.ts
+++ b/skyvern-frontend/src/util/env.ts
@@ -10,6 +10,9 @@ if (!environment) {
console.warn("environment environment variable was not set");
}
+const browserStreamingMode =
+ (import.meta.env.VITE_BROWSER_STREAMING_MODE as string) ?? "vnc";
+
const buildTimeApiKey: string | null =
typeof import.meta.env.VITE_SKYVERN_API_KEY === "string"
? import.meta.env.VITE_SKYVERN_API_KEY
@@ -94,6 +97,19 @@ function clearRuntimeApiKey(): void {
}
}
+async function getCredentialParam(
+ credentialGetter: (() => Promise) | null,
+): Promise {
+ if (credentialGetter) {
+ const token = await credentialGetter();
+ if (token) {
+ return `token=Bearer ${token}`;
+ }
+ }
+ const apiKey = getRuntimeApiKey();
+ return apiKey ? `apikey=${apiKey}` : "";
+}
+
const useNewRunsUrl = true as const;
const enable2faNotifications =
@@ -103,11 +119,13 @@ export {
apiBaseUrl,
runsApiBaseUrl,
environment,
+ browserStreamingMode,
artifactApiBaseUrl,
apiPathPrefix,
lsKeys,
wssBaseUrl,
newWssBaseUrl,
+ getCredentialParam,
getRuntimeApiKey,
persistRuntimeApiKey,
clearRuntimeApiKey,
diff --git a/skyvern/config.py b/skyvern/config.py
index 9cfd35a23..254097088 100644
--- a/skyvern/config.py
+++ b/skyvern/config.py
@@ -78,6 +78,7 @@ class Settings(BaseSettings):
TASK_RESPONSE_ACTION_SCREENSHOT_COUNT: int = 3
ENV: str = "local"
+ BROWSER_STREAMING_MODE: str = "vnc"
EXECUTE_ALL_STEPS: bool = True
JSON_LOGGING: bool = False
LOG_RAW_API_REQUESTS: bool = True
diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py
index 42e9c88a7..630a6c3d6 100644
--- a/skyvern/forge/sdk/db/agent_db.py
+++ b/skyvern/forge/sdk/db/agent_db.py
@@ -5847,6 +5847,7 @@ class AgentDB(BaseAlchemyDB):
timeout_minutes: int | None = None,
organization_id: str | None = None,
completed_at: datetime | None = None,
+ started_at: datetime | None = None,
) -> PersistentBrowserSession:
try:
async with self.Session() as session:
@@ -5867,6 +5868,8 @@ class AgentDB(BaseAlchemyDB):
persistent_browser_session.timeout_minutes = timeout_minutes
if completed_at is not None:
persistent_browser_session.completed_at = completed_at
+ if started_at:
+ persistent_browser_session.started_at = started_at
await session.commit()
await session.refresh(persistent_browser_session)
diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py
index a24204fca..4c06f74a3 100644
--- a/skyvern/forge/sdk/routes/__init__.py
+++ b/skyvern/forge/sdk/routes/__init__.py
@@ -10,6 +10,7 @@ from skyvern.forge.sdk.routes import scripts # noqa: F401
from skyvern.forge.sdk.routes import sdk # noqa: F401
from skyvern.forge.sdk.routes import webhooks # noqa: F401
from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401
+from skyvern.forge.sdk.routes.streaming import cdp_input # noqa: F401
from skyvern.forge.sdk.routes.streaming import messages # noqa: F401
from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401
from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401
diff --git a/skyvern/forge/sdk/routes/debug_sessions.py b/skyvern/forge/sdk/routes/debug_sessions.py
index 2a5027839..ee1dd9f73 100644
--- a/skyvern/forge/sdk/routes/debug_sessions.py
+++ b/skyvern/forge/sdk/routes/debug_sessions.py
@@ -70,6 +70,19 @@ async def get_or_create_debug_session_by_user_and_workflow_permanent_id(
workflow_permanent_id=workflow_permanent_id,
)
+ # Skip renewal for sessions that haven't started yet (browser still launching)
+ session = await app.DATABASE.get_persistent_browser_session(
+ debug_session.browser_session_id,
+ current_org.organization_id,
+ )
+ if session and session.started_at is None and session.completed_at is None:
+ created_at_utc = (
+ session.created_at.replace(tzinfo=timezone.utc) if session.created_at.tzinfo is None else session.created_at
+ )
+ age_seconds = (datetime.now(timezone.utc) - created_at_utc).total_seconds()
+ if age_seconds < 120:
+ return debug_session
+
try:
await app.PERSISTENT_SESSIONS_MANAGER.renew_or_close_session(
debug_session.browser_session_id,
diff --git a/skyvern/forge/sdk/routes/streaming/auth.py b/skyvern/forge/sdk/routes/streaming/auth.py
index ed5a6b966..c469c6456 100644
--- a/skyvern/forge/sdk/routes/streaming/auth.py
+++ b/skyvern/forge/sdk/routes/streaming/auth.py
@@ -2,6 +2,8 @@
Streaming auth.
"""
+import typing as t
+
import structlog
from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
@@ -13,6 +15,13 @@ from skyvern.forge.sdk.services.org_auth_service import get_current_org
LOG = structlog.get_logger()
+def require_client_id(client_id: str | None, **log_kwargs: t.Any) -> bool:
+ if client_id:
+ return True
+ LOG.error("No client_id provided", **log_kwargs)
+ return False
+
+
class Constants:
MISSING_API_KEY = ""
@@ -35,7 +44,7 @@ async def get_x_api_key(organization_id: str) -> str:
return x_api_key
-async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None:
+async def auth(apikey: str | None, token: str | None, websocket: WebSocket, **log_kwargs: t.Any) -> str | None:
"""
Accepts the websocket connection.
@@ -49,7 +58,7 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
await websocket.close(code=1002)
return None
except ConnectionClosedOK:
- LOG.info("WebSocket connection closed cleanly.")
+ LOG.info("WebSocket connection closed cleanly.", **log_kwargs)
return None
try:
@@ -60,11 +69,11 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s
await websocket.close(code=1002)
return None
except Exception:
- LOG.exception("Error occurred while retrieving organization information.")
+ LOG.exception("Error occurred while retrieving organization information.", **log_kwargs)
try:
await websocket.close(code=1002)
except ConnectionClosedOK:
- LOG.info("WebSocket connection closed due to invalid credentials.")
+ LOG.info("WebSocket connection closed due to invalid credentials.", **log_kwargs)
return None
return organization_id
diff --git a/skyvern/forge/sdk/routes/streaming/cdp_input.py b/skyvern/forge/sdk/routes/streaming/cdp_input.py
new file mode 100644
index 000000000..b755d59e3
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming/cdp_input.py
@@ -0,0 +1,406 @@
+"""
+CDP input channel for interactive browser control via Chrome DevTools Protocol.
+"""
+
+import asyncio
+import dataclasses
+import json
+import time
+import typing as t
+
+import structlog
+from fastapi import WebSocket, WebSocketDisconnect
+from playwright.async_api import CDPSession
+from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
+
+from skyvern.forge import app
+from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
+from skyvern.forge.sdk.routes.streaming.auth import auth, require_client_id
+from skyvern.forge.sdk.routes.streaming.registries import (
+ add_cdp_input_channel,
+ del_cdp_input_channel,
+ stream_ref_dec,
+ stream_ref_inc,
+)
+from skyvern.forge.sdk.routes.streaming.screencast import wait_for_browser_state
+from skyvern.forge.sdk.schemas.persistent_browser_sessions import is_final_status
+from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
+
+LOG = structlog.get_logger()
+
+_VALID_MOUSE_TYPES = {"mousePressed", "mouseReleased", "mouseMoved"}
+_VALID_MOUSE_BUTTONS = {"left", "middle", "right", "none"}
+_VALID_KEY_TYPES = {"keyDown", "keyUp"}
+_MAX_COORD = 10000
+_MAX_DELTA = 10000
+_MAX_KEY_LEN = 32
+_MAX_CODE_LEN = 32
+_MODIFIER_MASK = 0xF
+
+
+@dataclasses.dataclass
+class CdpInputChannel:
+ client_id: str
+ organization_id: str
+ websocket: WebSocket
+ interactor: t.Literal["agent", "user"] = "agent"
+
+ def __post_init__(self) -> None:
+ add_cdp_input_channel(self)
+
+ async def close(self) -> None:
+ del_cdp_input_channel(self.client_id)
+
+
+def _validated_modifiers(msg: dict) -> int:
+ modifiers = msg.get("modifiers", 0)
+ if not isinstance(modifiers, int):
+ return 0
+ return modifiers & _MODIFIER_MASK
+
+
+def _validated_coords(msg: dict) -> tuple[int, int] | None:
+ x = msg.get("x")
+ y = msg.get("y")
+ if not isinstance(x, (int, float)) or not isinstance(y, (int, float)):
+ return None
+ return (
+ max(0, min(int(x), _MAX_COORD)),
+ max(0, min(int(y), _MAX_COORD)),
+ )
+
+
+def _validate_mouse_event(msg: dict) -> dict | None:
+ event_type = msg.get("eventType")
+ if event_type not in _VALID_MOUSE_TYPES:
+ return None
+
+ coords = _validated_coords(msg)
+ if coords is None:
+ return None
+ x, y = coords
+
+ button = msg.get("button", "none")
+ if button not in _VALID_MOUSE_BUTTONS:
+ button = "none"
+
+ click_count = msg.get("clickCount", 0)
+ if not isinstance(click_count, int):
+ click_count = 0
+ click_count = max(0, min(click_count, 3))
+
+ return {
+ "type": event_type,
+ "x": x,
+ "y": y,
+ "button": button,
+ "clickCount": click_count,
+ "modifiers": _validated_modifiers(msg),
+ }
+
+
+def _validate_key_event(msg: dict) -> dict | None:
+ event_type = msg.get("eventType")
+ if event_type not in _VALID_KEY_TYPES:
+ return None
+
+ key = msg.get("key", "")
+ if not isinstance(key, str) or len(key) > _MAX_KEY_LEN:
+ return None
+
+ code = msg.get("code", "")
+ if not isinstance(code, str) or len(code) > _MAX_CODE_LEN:
+ return None
+
+ result: dict[str, t.Any] = {
+ "type": event_type,
+ "key": key,
+ "code": code,
+ "modifiers": _validated_modifiers(msg),
+ }
+
+ # Only include text for printable single characters on keyDown
+ text = msg.get("text", "")
+ if isinstance(text, str) and len(text) == 1 and text.isprintable() and event_type == "keyDown":
+ result["text"] = text
+
+ return result
+
+
+def _validate_wheel_event(msg: dict) -> dict | None:
+ coords = _validated_coords(msg)
+ if coords is None:
+ return None
+ x, y = coords
+
+ delta_x = msg.get("deltaX", 0)
+ delta_y = msg.get("deltaY", 0)
+ if not isinstance(delta_x, (int, float)):
+ delta_x = 0
+ if not isinstance(delta_y, (int, float)):
+ delta_y = 0
+ delta_x = max(-_MAX_DELTA, min(int(delta_x), _MAX_DELTA))
+ delta_y = max(-_MAX_DELTA, min(int(delta_y), _MAX_DELTA))
+
+ return {
+ "type": "mouseWheel",
+ "x": x,
+ "y": y,
+ "deltaX": delta_x,
+ "deltaY": delta_y,
+ "modifiers": _validated_modifiers(msg),
+ }
+
+
+async def _close_ws_safely(websocket: WebSocket, code: int, reason: str = "") -> None:
+ try:
+ await websocket.close(code=code, reason=reason)
+ except Exception:
+ pass
+
+
+_EVENT_DISPATCH_MAP: dict[str, tuple[t.Callable[[dict], dict | None], str]] = {
+ "mouseEvent": (_validate_mouse_event, "Input.dispatchMouseEvent"),
+ "keyEvent": (_validate_key_event, "Input.dispatchKeyEvent"),
+ "wheelEvent": (_validate_wheel_event, "Input.dispatchMouseEvent"),
+}
+
+
+async def _dispatch_event(
+ cdp_session: CDPSession,
+ kind: str,
+ msg: dict,
+ log_id_key: str,
+ log_id_value: str,
+) -> None:
+ entry = _EVENT_DISPATCH_MAP.get(kind)
+ if entry is None:
+ return
+ validator, cdp_method = entry
+ validated = validator(msg)
+ if validated:
+ await cdp_session.send(cdp_method, validated)
+ else:
+ LOG.warning(
+ "CDP input: validation failed",
+ **{log_id_key: log_id_value},
+ kind=kind,
+ raw_event_type=msg.get("eventType"),
+ )
+
+
+async def _run_input_loop(
+ websocket: WebSocket,
+ channel: CdpInputChannel,
+ cdp_session: CDPSession,
+ log_id_key: str,
+ log_id_value: str,
+) -> None:
+ dropped_log_count = 0
+ while True:
+ try:
+ raw = await websocket.receive_text()
+ except WebSocketDisconnect:
+ break
+
+ try:
+ msg = json.loads(raw)
+ except json.JSONDecodeError:
+ LOG.warning("CDP input: malformed JSON", **{log_id_key: log_id_value})
+ continue
+
+ kind = msg.get("kind") or msg.get("type")
+
+ if kind == "take-control":
+ channel.interactor = "user"
+ LOG.info("CDP input: take-control received", **{log_id_key: log_id_value}, client_id=channel.client_id)
+ continue
+ if kind == "cede-control":
+ channel.interactor = "agent"
+ LOG.info("CDP input: cede-control received", **{log_id_key: log_id_value}, client_id=channel.client_id)
+ continue
+
+ if channel.interactor != "user":
+ if dropped_log_count < 5:
+ LOG.info(
+ "CDP input: event dropped",
+ interactor=channel.interactor,
+ **{log_id_key: log_id_value},
+ kind=kind,
+ )
+ dropped_log_count += 1
+ continue
+
+ try:
+ await _dispatch_event(cdp_session, kind, msg, log_id_key, log_id_value)
+ except Exception:
+ LOG.warning(
+ "CDP input: failed to dispatch event; closing input channel",
+ **{log_id_key: log_id_value},
+ kind=kind,
+ exc_info=True,
+ )
+ await websocket.close(code=4411, reason="dispatch_failed")
+ break
+
+
+@legacy_base_router.websocket("/stream/cdp_input/workflow_run/{workflow_run_id}")
+async def cdp_input_stream(
+ websocket: WebSocket,
+ workflow_run_id: str,
+ client_id: str | None = None,
+ apikey: str | None = None,
+ token: str | None = None,
+) -> None:
+ organization_id = await auth(apikey=apikey, token=token, websocket=websocket, workflow_run_id=workflow_run_id)
+ if organization_id is None:
+ return
+
+ if not require_client_id(client_id, workflow_run_id=workflow_run_id):
+ await _close_ws_safely(websocket, 1002)
+ return
+ assert client_id is not None
+
+ channel = CdpInputChannel(
+ client_id=client_id,
+ organization_id=organization_id,
+ websocket=websocket,
+ )
+
+ cdp_session: CDPSession | None = None
+ try:
+ deadline = time.monotonic() + 120
+ while True:
+ workflow_run = await app.DATABASE.get_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ if not workflow_run or workflow_run.organization_id != organization_id:
+ LOG.info("CDP input: workflow run not found", workflow_run_id=workflow_run_id)
+ await websocket.close(code=4404, reason="workflow_run_not_found")
+ return
+ if workflow_run.status == WorkflowRunStatus.running or workflow_run.status.is_final():
+ break
+ if workflow_run.status == WorkflowRunStatus.paused:
+ break
+ if time.monotonic() >= deadline:
+ LOG.warning("CDP input: timed out waiting for running status", workflow_run_id=workflow_run_id)
+ await websocket.close(code=4408, reason="wait_timeout")
+ return
+ await asyncio.sleep(1)
+
+ browser_state = await wait_for_browser_state(workflow_run_id, "workflow_run")
+ if browser_state is None:
+ LOG.warning("CDP input: timed out waiting for browser state", workflow_run_id=workflow_run_id)
+ await websocket.close(code=4408, reason="browser_state_timeout")
+ return
+
+ page = await browser_state.get_working_page()
+ if page is None:
+ LOG.warning("CDP input: no working page", workflow_run_id=workflow_run_id)
+ await websocket.close(code=4410, reason="no_working_page")
+ return
+
+ cdp_session = await page.context.new_cdp_session(page)
+ stream_ref_inc(workflow_run_id)
+
+ LOG.info("CDP input channel ready", workflow_run_id=workflow_run_id, client_id=client_id)
+ await websocket.send_json({"kind": "ready"})
+
+ await _run_input_loop(websocket, channel, cdp_session, "workflow_run_id", workflow_run_id)
+
+ except ConnectionClosedOK:
+ LOG.info("CDP input: WS closed cleanly", workflow_run_id=workflow_run_id)
+ except ConnectionClosedError:
+ LOG.warning("CDP input: WS connection error", workflow_run_id=workflow_run_id)
+ except WebSocketDisconnect:
+ LOG.info("CDP input: WS disconnected", workflow_run_id=workflow_run_id)
+ except Exception:
+ LOG.warning("CDP input: unexpected error", workflow_run_id=workflow_run_id, exc_info=True)
+ finally:
+ if cdp_session is not None:
+ await stream_ref_dec(workflow_run_id)
+ try:
+ await cdp_session.detach()
+ except Exception:
+ pass
+ await channel.close()
+ LOG.info("CDP input channel closed", workflow_run_id=workflow_run_id, client_id=client_id)
+
+
+@base_router.websocket("/stream/cdp_input/browser_session/{browser_session_id}")
+async def cdp_input_browser_session_stream(
+ websocket: WebSocket,
+ browser_session_id: str,
+ client_id: str | None = None,
+ apikey: str | None = None,
+ token: str | None = None,
+) -> None:
+ organization_id = await auth(apikey=apikey, token=token, websocket=websocket, browser_session_id=browser_session_id)
+ if organization_id is None:
+ return
+
+ if not require_client_id(client_id, browser_session_id=browser_session_id):
+ await _close_ws_safely(websocket, 1002)
+ return
+ assert client_id is not None
+
+ channel = CdpInputChannel(
+ client_id=client_id,
+ organization_id=organization_id,
+ websocket=websocket,
+ )
+
+ cdp_session: CDPSession | None = None
+ try:
+ session = await app.PERSISTENT_SESSIONS_MANAGER.get_session(
+ session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+ if not session:
+ LOG.info("CDP input: browser session not found", browser_session_id=browser_session_id)
+ await websocket.close(code=4404, reason="browser_session_not_found")
+ return
+ if is_final_status(session.status):
+ LOG.info("CDP input: browser session already finalized", browser_session_id=browser_session_id)
+ await websocket.close(code=4404, reason="browser_session_finalized")
+ return
+
+ browser_state = await wait_for_browser_state(browser_session_id, "browser_session")
+ if browser_state is None:
+ LOG.warning("CDP input: timed out waiting for browser state", browser_session_id=browser_session_id)
+ await websocket.close(code=4408, reason="browser_state_timeout")
+ return
+
+ page = await browser_state.get_working_page()
+ if page is None:
+ LOG.warning("CDP input: no working page", browser_session_id=browser_session_id)
+ await websocket.close(code=4410, reason="no_working_page")
+ return
+
+ cdp_session = await page.context.new_cdp_session(page)
+ # stream_ref_inc/dec is intentionally omitted for browser sessions.
+ # Browser state lives in PersistentSessionsManager._browser_sessions,
+ # not BrowserManager.pages, so there is no entry to protect from eviction.
+
+ LOG.info("CDP input channel ready", browser_session_id=browser_session_id, client_id=client_id)
+ await websocket.send_json({"kind": "ready"})
+
+ await _run_input_loop(websocket, channel, cdp_session, "browser_session_id", browser_session_id)
+
+ except ConnectionClosedOK:
+ LOG.info("CDP input: WS closed cleanly", browser_session_id=browser_session_id)
+ except ConnectionClosedError:
+ LOG.warning("CDP input: WS connection error", browser_session_id=browser_session_id)
+ except WebSocketDisconnect:
+ LOG.info("CDP input: WS disconnected", browser_session_id=browser_session_id)
+ except Exception:
+ LOG.warning("CDP input: unexpected error", browser_session_id=browser_session_id, exc_info=True)
+ finally:
+ if cdp_session is not None:
+ try:
+ await cdp_session.detach()
+ except Exception:
+ pass
+ await channel.close()
+ LOG.info("CDP input channel closed", browser_session_id=browser_session_id, client_id=client_id)
diff --git a/skyvern/forge/sdk/routes/streaming/registries.py b/skyvern/forge/sdk/routes/streaming/registries.py
index 12143a893..91e879588 100644
--- a/skyvern/forge/sdk/routes/streaming/registries.py
+++ b/skyvern/forge/sdk/routes/streaming/registries.py
@@ -4,7 +4,7 @@ Contains registries for coordinating active WS connections (aka "channels", see
NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps,
so that an individual frontend app instance is guaranteed to always connect to
-the same backend api instance. This is beccause the two registries here are
+the same backend api instance. This is because the two registries here are
tied together via a `client_id` string.
The tale-of-the-tape is this:
@@ -12,6 +12,11 @@ The tale-of-the-tape is this:
- one dedicated to streaming VNC's RFB protocol
- the other dedicated to messaging (JSON)
- both of these channels are stateful and need to coordinate with one another
+
+Additionally, this module manages:
+ - CDP input channels for interactive browser control
+ - Stream reference counts that defer browser state cleanup while active
+ CDP streams hold references to a workflow run's browser
"""
from __future__ import annotations
@@ -21,6 +26,7 @@ import typing as t
import structlog
if t.TYPE_CHECKING:
+ from skyvern.forge.sdk.routes.streaming.cdp_input import CdpInputChannel
from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel
from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel
@@ -90,3 +96,55 @@ def del_message_channel(client_id: str, *, expected: MessageChannel | None = Non
return
del message_channels[client_id]
+
+
+# Stream reference counts per workflow_run_id.
+_stream_refcounts: dict[str, int] = {}
+_deferred_close_params: dict[str, bool] = {}
+
+
+def stream_ref_inc(workflow_run_id: str) -> None:
+ _stream_refcounts[workflow_run_id] = _stream_refcounts.get(workflow_run_id, 0) + 1
+
+
+async def stream_ref_dec(workflow_run_id: str) -> None:
+ count = _stream_refcounts.get(workflow_run_id, 0) - 1
+ if count <= 0:
+ _stream_refcounts.pop(workflow_run_id, None)
+ close_on_completion = _deferred_close_params.pop(workflow_run_id, None)
+ if close_on_completion is not None:
+ from skyvern.forge import app
+
+ browser_state = app.BROWSER_MANAGER.pages.get(workflow_run_id)
+ if browser_state is not None:
+ try:
+ await browser_state.close(close_browser_on_completion=close_on_completion)
+ except Exception:
+ LOG.warning(
+ "stream_ref_dec: error closing deferred browser state",
+ workflow_run_id=workflow_run_id,
+ exc_info=True,
+ )
+ app.BROWSER_MANAGER.evict_page(workflow_run_id)
+ else:
+ _stream_refcounts[workflow_run_id] = count
+
+
+def stream_ref_active(workflow_run_id: str) -> bool:
+ return _stream_refcounts.get(workflow_run_id, 0) > 0
+
+
+def set_deferred_close_params(workflow_run_id: str, close_browser_on_completion: bool) -> None:
+ _deferred_close_params[workflow_run_id] = close_browser_on_completion
+
+
+# a registry for CDP input channels, keyed by `client_id`
+cdp_input_channels: dict[str, CdpInputChannel] = {}
+
+
+def add_cdp_input_channel(channel: CdpInputChannel) -> None:
+ cdp_input_channels[channel.client_id] = channel
+
+
+def del_cdp_input_channel(client_id: str) -> None:
+ cdp_input_channels.pop(client_id, None)
diff --git a/skyvern/forge/sdk/routes/streaming/screencast.py b/skyvern/forge/sdk/routes/streaming/screencast.py
new file mode 100644
index 000000000..b99135b51
--- /dev/null
+++ b/skyvern/forge/sdk/routes/streaming/screencast.py
@@ -0,0 +1,202 @@
+"""
+CDP screencast loop for local-mode browser streaming.
+
+Uses Chrome's Page.startScreencast() to stream JPEG frames from the browser
+over a WebSocket connection.
+"""
+
+import asyncio
+from collections.abc import Awaitable, Callable
+
+import structlog
+from fastapi import WebSocket
+from playwright.async_api import CDPSession
+
+from skyvern.forge import app
+from skyvern.webeye.browser_state import BrowserState
+
+LOG = structlog.get_logger()
+
+DEFAULT_WIDTH = 1280
+DEFAULT_HEIGHT = 720
+
+
+async def wait_for_browser_state(
+ entity_id: str,
+ entity_type: str,
+ workflow_run_id: str | None = None,
+ timeout: float = 120,
+ poll_interval: float = 1.0,
+) -> BrowserState | None:
+ elapsed = 0.0
+ while elapsed < timeout:
+ browser_state = await _resolve_browser_state(entity_id, entity_type, workflow_run_id)
+
+ if browser_state is not None:
+ page = await browser_state.get_working_page()
+ if page is not None:
+ return browser_state
+
+ await asyncio.sleep(poll_interval)
+ elapsed += poll_interval
+
+ return None
+
+
+async def _resolve_browser_state(
+ entity_id: str,
+ entity_type: str,
+ workflow_run_id: str | None = None,
+) -> BrowserState | None:
+ if entity_type == "workflow_run":
+ return app.BROWSER_MANAGER.get_for_workflow_run(entity_id)
+ if entity_type == "task":
+ return app.BROWSER_MANAGER.get_for_task(entity_id, workflow_run_id)
+ if entity_type == "browser_session":
+ return await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(entity_id)
+ return None
+
+
+async def start_screencast_loop(
+ websocket: WebSocket,
+ browser_state: BrowserState,
+ entity_id: str,
+ entity_type: str,
+ check_finalized: Callable[[], Awaitable[bool]],
+) -> None:
+ id_key = f"{entity_type}_id"
+ cdp_session: CDPSession | None = None
+ frame_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=2)
+ viewport_info: dict[str, int] = {"width": DEFAULT_WIDTH, "height": DEFAULT_HEIGHT}
+
+ async def _ack_frame(session_id: int) -> None:
+ if cdp_session is None:
+ return
+ try:
+ await cdp_session.send("Page.screencastFrameAck", {"sessionId": session_id})
+ except Exception:
+ pass
+
+ def _update_viewport_from_metadata(metadata: dict) -> None:
+ device_width = metadata.get("deviceWidth")
+ device_height = metadata.get("deviceHeight")
+ if isinstance(device_width, (int, float)) and device_width > 0:
+ viewport_info["width"] = int(device_width)
+ if isinstance(device_height, (int, float)) and device_height > 0:
+ viewport_info["height"] = int(device_height)
+
+ async def _on_frame(params: dict) -> None:
+ data = params.get("data", "")
+ session_id = params.get("sessionId", 0)
+ metadata = params.get("metadata", {})
+ if metadata:
+ _update_viewport_from_metadata(metadata)
+ asyncio.create_task(_ack_frame(session_id))
+ if not data:
+ return
+ # Drop oldest frame if queue is full to keep latency low
+ if frame_queue.full():
+ try:
+ frame_queue.get_nowait()
+ except asyncio.QueueEmpty:
+ pass
+ try:
+ frame_queue.put_nowait(data)
+ except asyncio.QueueFull:
+ pass
+
+ async def _frame_forwarding_loop() -> None:
+ while True:
+ data = await frame_queue.get()
+ current_url = ""
+ try:
+ page = await browser_state.get_working_page()
+ if page is not None:
+ current_url = page.url
+ except Exception:
+ pass
+ try:
+ await websocket.send_json(
+ {
+ id_key: entity_id,
+ "status": "running",
+ "screenshot": data,
+ "format": "jpeg",
+ "viewport_width": viewport_info["width"],
+ "viewport_height": viewport_info["height"],
+ "url": current_url,
+ }
+ )
+ except Exception:
+ break
+
+ async def _completion_polling_loop() -> None:
+ while True:
+ await asyncio.sleep(2)
+ try:
+ if await check_finalized():
+ return
+ except Exception:
+ LOG.warning(
+ "Error checking finalization status",
+ entity_id=entity_id,
+ entity_type=entity_type,
+ exc_info=True,
+ )
+
+ try:
+ page = await browser_state.get_working_page()
+ if page is None:
+ raise RuntimeError("No working page available for screencast")
+
+ cdp_session = await page.context.new_cdp_session(page)
+ cdp_session.on("Page.screencastFrame", _on_frame)
+ await cdp_session.send(
+ "Page.startScreencast",
+ {
+ "format": "jpeg",
+ "quality": 60,
+ "maxWidth": DEFAULT_WIDTH,
+ "maxHeight": DEFAULT_HEIGHT,
+ },
+ )
+ LOG.info("CDP screencast started", entity_id=entity_id, entity_type=entity_type)
+
+ forward_task = asyncio.create_task(_frame_forwarding_loop())
+ poll_task = asyncio.create_task(_completion_polling_loop())
+
+ done, pending = await asyncio.wait(
+ [forward_task, poll_task],
+ return_when=asyncio.FIRST_COMPLETED,
+ )
+ for task in pending:
+ task.cancel()
+ try:
+ await task
+ except (asyncio.CancelledError, Exception):
+ pass
+
+ # Re-raise forwarding errors (e.g. WebSocket disconnect)
+ for task in done:
+ exc = task.exception()
+ if task is forward_task and exc is not None:
+ raise exc
+
+ except Exception:
+ LOG.info(
+ "Screencast loop ended",
+ entity_id=entity_id,
+ entity_type=entity_type,
+ exc_info=True,
+ )
+ finally:
+ if cdp_session is not None:
+ try:
+ await cdp_session.send("Page.stopScreencast", {})
+ except Exception:
+ pass
+ try:
+ await cdp_session.detach()
+ except Exception:
+ pass
+ LOG.info("CDP screencast cleaned up", entity_id=entity_id, entity_type=entity_type)
diff --git a/skyvern/forge/sdk/routes/streaming/screenshot.py b/skyvern/forge/sdk/routes/streaming/screenshot.py
index 1d46ad632..5a06bd37d 100644
--- a/skyvern/forge/sdk/routes/streaming/screenshot.py
+++ b/skyvern/forge/sdk/routes/streaming/screenshot.py
@@ -4,6 +4,7 @@ Provides WS endpoints for streaming screenshots.
Screenshot streaming is created on the basis of one of these database entities:
- task (run)
- workflow run
+ - browser session
Screenshot streaming is used for a run that is invoked without a browser session.
Otherwise, VNC streaming is used.
@@ -11,6 +12,8 @@ Otherwise, VNC streaming is used.
import asyncio
import base64
+import time
+from collections.abc import Awaitable, Callable
from datetime import datetime
import structlog
@@ -18,14 +21,18 @@ from fastapi import HTTPException, WebSocket, WebSocketDisconnect
from pydantic import ValidationError
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
+from skyvern.config import settings
from skyvern.forge import app
-from skyvern.forge.sdk.routes.routers import legacy_base_router
+from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
+from skyvern.forge.sdk.routes.streaming.screencast import start_screencast_loop, wait_for_browser_state
+from skyvern.forge.sdk.schemas.persistent_browser_sessions import is_final_status
from skyvern.forge.sdk.schemas.tasks import TaskStatus
from skyvern.forge.sdk.services.org_auth_service import get_current_org
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
LOG = structlog.get_logger()
STREAMING_TIMEOUT = 300
+WAIT_FOR_RUNNING_TIMEOUT = 120
@legacy_base_router.websocket("/stream/tasks/{task_id}")
@@ -56,6 +63,11 @@ async def task_stream(
return
LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id)
+
+ if settings.BROWSER_STREAMING_MODE == "cdp":
+ await _local_screencast_for_task(websocket, task_id, organization_id)
+ return
+
# timestamp last time when streaming activity happens
last_activity_timestamp = datetime.utcnow()
@@ -173,6 +185,11 @@ async def workflow_run_streaming(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
+
+ if settings.BROWSER_STREAMING_MODE == "cdp":
+ await _local_screencast_for_workflow_run(websocket, workflow_run_id, organization_id)
+ return
+
# timestamp last time when streaming activity happens
last_activity_timestamp = datetime.utcnow()
@@ -229,7 +246,7 @@ async def workflow_run_streaming(
)
return
- if workflow_run.status == WorkflowRunStatus.running:
+ if workflow_run.status in (WorkflowRunStatus.running, WorkflowRunStatus.paused):
file_name = f"{workflow_run_id}.png"
screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name)
if screenshot:
@@ -280,3 +297,242 @@ async def workflow_run_streaming(
organization_id=organization_id,
)
return
+
+
+@base_router.websocket("/stream/browser_sessions/{browser_session_id}")
+async def browser_session_streaming(
+ websocket: WebSocket,
+ browser_session_id: str,
+ apikey: str | None = None,
+ token: str | None = None,
+) -> None:
+ try:
+ await websocket.accept()
+ if not token and not apikey:
+ await websocket.send_text("No valid credential provided")
+ await websocket.close()
+ return
+ except ConnectionClosedOK:
+ LOG.info("BrowserSession Streaming: ConnectionClosedOK error. Streaming won't start")
+ return
+
+ try:
+ organization = await get_current_org(x_api_key=apikey, authorization=token)
+ organization_id = organization.organization_id
+ except Exception:
+ LOG.exception("Error while getting organization", browser_session_id=browser_session_id)
+ try:
+ await websocket.send_text("Invalid credential provided")
+ except ConnectionClosedOK:
+ LOG.info("BrowserSession Streaming: ConnectionClosedOK error while sending invalid credential message")
+ return
+
+ LOG.info(
+ "BrowserSession Streaming: Started",
+ browser_session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+
+ if settings.BROWSER_STREAMING_MODE == "cdp":
+ await _local_screencast_for_browser_session(websocket, browser_session_id, organization_id)
+ return
+
+ await websocket.close(code=4001, reason="use-vnc-streaming")
+ return
+
+
+async def _send_status(websocket: WebSocket, id_key: str, entity_id: str, status: str) -> None:
+ await websocket.send_json({id_key: entity_id, "status": status})
+
+
+async def _run_local_screencast(
+ websocket: WebSocket,
+ entity_id: str,
+ entity_type: str,
+ wait_for_running: Callable[[], Awaitable[str | None]],
+ check_finalized: Callable[[], Awaitable[bool]],
+ get_current_status: Callable[[], Awaitable[str | None]],
+ get_workflow_run_id: Callable[[], str | None] | None = None,
+) -> None:
+ id_key = f"{entity_type}_id"
+ try:
+ early_exit_status = await wait_for_running()
+ if early_exit_status is not None:
+ await _send_status(websocket, id_key, entity_id, early_exit_status)
+ return
+
+ workflow_run_id = get_workflow_run_id() if get_workflow_run_id else None
+ browser_state = await wait_for_browser_state(
+ entity_id,
+ entity_type,
+ workflow_run_id=workflow_run_id,
+ )
+ if browser_state is None:
+ LOG.warning("Timed out waiting for browser state", **{id_key: entity_id})
+ await _send_status(websocket, id_key, entity_id, "timeout")
+ return
+
+ await start_screencast_loop(
+ websocket=websocket,
+ browser_state=browser_state,
+ entity_id=entity_id,
+ entity_type=entity_type,
+ check_finalized=check_finalized,
+ )
+
+ final_status = await get_current_status()
+ if final_status is not None:
+ try:
+ await _send_status(websocket, id_key, entity_id, final_status)
+ except Exception:
+ LOG.debug("Could not send final status (WebSocket likely closed)", **{id_key: entity_id})
+
+ except (WebSocketDisconnect, ConnectionClosedOK):
+ LOG.info("WebSocket closed during local screencast", **{id_key: entity_id})
+ except ConnectionClosedError:
+ LOG.warning("WebSocket connection error during local screencast", **{id_key: entity_id})
+ except Exception:
+ LOG.warning("Error in local screencast", **{id_key: entity_id}, exc_info=True)
+
+
+async def _local_screencast_for_workflow_run(
+ websocket: WebSocket,
+ workflow_run_id: str,
+ organization_id: str,
+) -> None:
+ async def wait_for_running() -> str | None:
+ deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
+ while True:
+ workflow_run = await app.DATABASE.get_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ if not workflow_run or workflow_run.organization_id != organization_id:
+ return "not_found"
+ if workflow_run.status.is_final():
+ return workflow_run.status
+ if workflow_run.status in (WorkflowRunStatus.running, WorkflowRunStatus.paused):
+ return None
+ if time.monotonic() >= deadline:
+ LOG.warning(
+ "Timed out waiting for running status",
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return "timeout"
+ await asyncio.sleep(1)
+
+ async def check_finalized() -> bool:
+ wr = await app.DATABASE.get_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return wr is None or wr.status.is_final()
+
+ async def get_current_status() -> str | None:
+ wr = await app.DATABASE.get_workflow_run(
+ workflow_run_id=workflow_run_id,
+ organization_id=organization_id,
+ )
+ return wr.status if wr else None
+
+ await _run_local_screencast(
+ websocket=websocket,
+ entity_id=workflow_run_id,
+ entity_type="workflow_run",
+ wait_for_running=wait_for_running,
+ check_finalized=check_finalized,
+ get_current_status=get_current_status,
+ )
+
+
+async def _local_screencast_for_task(
+ websocket: WebSocket,
+ task_id: str,
+ organization_id: str,
+) -> None:
+ task_workflow_run_id: str | None = None
+
+ async def wait_for_running() -> str | None:
+ nonlocal task_workflow_run_id
+ deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT
+ while True:
+ task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
+ if not task:
+ return "not_found"
+ if task.status.is_final():
+ return task.status
+ if task.status == TaskStatus.running:
+ task_workflow_run_id = task.workflow_run_id
+ return None
+ if time.monotonic() >= deadline:
+ LOG.warning(
+ "Timed out waiting for running status",
+ task_id=task_id,
+ organization_id=organization_id,
+ )
+ return "timeout"
+ await asyncio.sleep(1)
+
+ async def check_finalized() -> bool:
+ task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
+ return task is None or task.status.is_final()
+
+ async def get_current_status() -> str | None:
+ task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id)
+ return task.status if task else None
+
+ await _run_local_screencast(
+ websocket=websocket,
+ entity_id=task_id,
+ entity_type="task",
+ wait_for_running=wait_for_running,
+ check_finalized=check_finalized,
+ get_current_status=get_current_status,
+ get_workflow_run_id=lambda: task_workflow_run_id,
+ )
+
+
+async def _local_screencast_for_browser_session(
+ websocket: WebSocket,
+ browser_session_id: str,
+ organization_id: str,
+) -> None:
+ async def wait_for_running() -> str | None:
+ session = await app.PERSISTENT_SESSIONS_MANAGER.get_session(
+ session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+ if not session:
+ LOG.warning(
+ "Browser session not found for organization",
+ browser_session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+ return "not_found"
+ if is_final_status(session.status):
+ return session.status
+ return None
+
+ async def check_finalized() -> bool:
+ s = await app.PERSISTENT_SESSIONS_MANAGER.get_session(
+ session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+ return s is None or is_final_status(s.status)
+
+ async def get_current_status() -> str | None:
+ s = await app.PERSISTENT_SESSIONS_MANAGER.get_session(
+ session_id=browser_session_id,
+ organization_id=organization_id,
+ )
+ return s.status if s else None
+
+ await _run_local_screencast(
+ websocket=websocket,
+ entity_id=browser_session_id,
+ entity_type="browser_session",
+ wait_for_running=wait_for_running,
+ check_finalized=check_finalized,
+ get_current_status=get_current_status,
+ )
diff --git a/skyvern/forge/sdk/routes/streaming/vnc.py b/skyvern/forge/sdk/routes/streaming/vnc.py
index 512beb820..ab2e81bb6 100644
--- a/skyvern/forge/sdk/routes/streaming/vnc.py
+++ b/skyvern/forge/sdk/routes/streaming/vnc.py
@@ -14,7 +14,7 @@ from fastapi import WebSocket
from websockets.exceptions import ConnectionClosedOK
from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router
-from skyvern.forge.sdk.routes.streaming.auth import auth
+from skyvern.forge.sdk.routes.streaming.auth import auth, require_client_id
from skyvern.forge.sdk.routes.streaming.channels.vnc import (
Loops,
VncChannel,
@@ -70,14 +70,14 @@ async def stream(
token: str | None = None,
workflow_run_id: str | None = None,
) -> None:
- if not client_id:
- LOG.error(
- "Client ID not provided for vnc stream.",
- browser_session_id=browser_session_id,
- task_id=task_id,
- workflow_run_id=workflow_run_id,
- )
+ if not require_client_id(
+ client_id,
+ browser_session_id=browser_session_id,
+ task_id=task_id,
+ workflow_run_id=workflow_run_id,
+ ):
return
+ assert client_id is not None
LOG.debug(
"Starting vnc stream.",
diff --git a/skyvern/webeye/browser_manager.py b/skyvern/webeye/browser_manager.py
index 2fca1f85f..03c420646 100644
--- a/skyvern/webeye/browser_manager.py
+++ b/skyvern/webeye/browser_manager.py
@@ -48,6 +48,8 @@ class BrowserManager(Protocol):
browser_session_id: str | None = None,
) -> BrowserState: ...
+ def evict_page(self, page_id: str) -> None: ...
+
def get_for_task(self, task_id: str, workflow_run_id: str | None = None) -> BrowserState | None: ...
def get_for_workflow_run(
diff --git a/skyvern/webeye/default_persistent_sessions_manager.py b/skyvern/webeye/default_persistent_sessions_manager.py
index 8cd187735..d4f661f83 100644
--- a/skyvern/webeye/default_persistent_sessions_manager.py
+++ b/skyvern/webeye/default_persistent_sessions_manager.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import asyncio
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from math import floor
@@ -7,7 +8,6 @@ from pathlib import Path
import structlog
from playwright._impl._errors import TargetClosedError
-from playwright.async_api import async_playwright
from skyvern.config import settings
from skyvern.exceptions import BrowserSessionNotRenewable, MissingBrowserAddressError
@@ -22,11 +22,10 @@ from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
is_final_status,
)
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
-from skyvern.webeye.browser_factory import BrowserContextFactory
from skyvern.webeye.browser_state import BrowserState
-from skyvern.webeye.cdp_ports import _allocate_cdp_port, _release_cdp_port
+from skyvern.webeye.cdp_ports import _release_cdp_port
from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager
-from skyvern.webeye.real_browser_state import RealBrowserState
+from skyvern.webeye.real_browser_manager import RealBrowserManager
LOG = structlog.get_logger()
@@ -182,7 +181,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
"""Default (OSS) implementation of PersistentSessionsManager protocol."""
instance: DefaultPersistentSessionsManager | None = None
- _browser_sessions: dict[str, BrowserSession] = {}
+ _browser_sessions: dict[str, BrowserSession] = dict()
+ _background_tasks: set[asyncio.Task[None]] = set()
database: AgentDB
def __new__(cls, database: AgentDB) -> DefaultPersistentSessionsManager:
@@ -198,71 +198,6 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
def watch_session_pool(self) -> None:
"""No-op in OSS: browsers run in-process, no external pool to monitor."""
- async def _launch_browser_for_session(
- self,
- session_id: str,
- organization_id: str,
- ) -> None:
- """Launch a browser process and register it as a persistent session."""
- cdp_port = _allocate_cdp_port()
- LOG.info("Launching browser for persistent session", session_id=session_id, cdp_port=cdp_port)
-
- pw = None
- browser_state = None
- try:
- pw = await async_playwright().start()
- browser_context, browser_artifacts, browser_cleanup = await BrowserContextFactory.create_browser_context(
- pw,
- organization_id=organization_id,
- cdp_port=cdp_port,
- )
-
- browser_state = RealBrowserState(
- pw=pw,
- browser_context=browser_context,
- page=None,
- browser_artifacts=browser_artifacts,
- browser_cleanup=browser_cleanup,
- )
- await browser_state.get_or_create_page(organization_id=organization_id)
-
- self._browser_sessions[session_id] = BrowserSession(
- browser_state=browser_state,
- cdp_port=cdp_port,
- )
-
- browser_address = f"http://127.0.0.1:{cdp_port}"
- await self.database.set_persistent_browser_session_browser_address(
- browser_session_id=session_id,
- browser_address=browser_address,
- ip_address="127.0.0.1",
- ecs_task_arn=None,
- organization_id=organization_id,
- )
- await self.database.update_persistent_browser_session(
- session_id,
- organization_id=organization_id,
- status=PersistentBrowserSessionStatus.running,
- )
-
- LOG.info("Browser launched for persistent session", session_id=session_id, browser_address=browser_address)
- except BaseException:
- _release_cdp_port(cdp_port)
- # Close whichever resource was successfully created.
- # browser_state.close() stops playwright internally, so only fall
- # back to pw.stop() when no browser_state was created.
- if browser_state is not None:
- try:
- await browser_state.close()
- except Exception:
- LOG.warning("Failed to close browser_state during cleanup", exc_info=True)
- elif pw is not None:
- try:
- await pw.stop()
- except Exception:
- LOG.warning("Failed to stop playwright during cleanup", exc_info=True)
- raise
-
async def begin_session(
self,
*,
@@ -356,26 +291,78 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
browser_profile_id=browser_profile_id,
)
+ # In local mode, launch the browser immediately for standalone sessions
+ # so the screencast/CDP input endpoints can connect.
+ if settings.BROWSER_STREAMING_MODE == "cdp" and runnable_id is None:
+ session_id = session.persistent_browser_session_id
+ task = asyncio.create_task(
+ self._launch_browser_for_session(session_id, organization_id, proxy_location, url)
+ )
+ self._background_tasks.add(task)
+ task.add_done_callback(self._background_tasks.discard)
+
+ return session
+
+ async def _launch_browser_for_session(
+ self,
+ session_id: str,
+ organization_id: str,
+ proxy_location: ProxyLocationInput | None = None,
+ url: str | None = None,
+ ) -> None:
try:
- await self._launch_browser_for_session(
- session_id=session.persistent_browser_session_id,
+ browser_state = await RealBrowserManager._create_browser_state(
+ proxy_location=proxy_location,
+ url=url,
organization_id=organization_id,
)
- # Re-fetch to get updated browser_address/ip_address/started_at
- updated = await self.database.get_persistent_browser_session(
- session_id=session.persistent_browser_session_id,
- organization_id=organization_id,
- )
- if updated:
- return updated
- except Exception:
- LOG.exception(
- "Failed to launch browser for session, session will have no browser",
- session_id=session.persistent_browser_session_id,
+ await browser_state.get_or_create_page(
+ url=url or "about:blank",
+ proxy_location=proxy_location,
organization_id=organization_id,
)
- return session
+ session = await self.get_session(session_id, organization_id)
+ if session is None or is_final_status(session.status) or session.completed_at is not None:
+ LOG.info(
+ "Session closed during browser launch, discarding browser",
+ browser_session_id=session_id,
+ )
+ await browser_state.close()
+ return
+
+ if session_id in self._browser_sessions:
+ LOG.info(
+ "Session already has browser state, discarding duplicate",
+ browser_session_id=session_id,
+ )
+ await browser_state.close()
+ return
+
+ self._browser_sessions[session_id] = BrowserSession(browser_state=browser_state)
+
+ result = await self.update_status(session_id, organization_id, PersistentBrowserSessionStatus.running)
+ if result is None:
+ self._browser_sessions.pop(session_id, None)
+ await browser_state.close()
+ return
+ # Set started_at so renewal knows the browser is live
+ await self.database.update_persistent_browser_session(
+ session_id,
+ organization_id=organization_id,
+ started_at=datetime.now(timezone.utc),
+ )
+ LOG.info(
+ "Browser launched for standalone session",
+ browser_session_id=session_id,
+ organization_id=organization_id,
+ )
+ except Exception:
+ LOG.exception(
+ "Failed to launch browser for standalone session",
+ browser_session_id=session_id,
+ organization_id=organization_id,
+ )
async def occupy_browser_session(
self,
@@ -396,7 +383,21 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
try:
return await renew_session(self.database, session_id, organization_id)
except BrowserSessionNotRenewable:
- await self.close_session(organization_id, session_id)
+ session = await self.get_session(session_id, organization_id)
+ # Don't close sessions that haven't started yet (browser still launching)
+ # unless they're stuck (older than 120s)
+ if session is not None and session.started_at is None and session.completed_at is None:
+ created_at_utc = (
+ session.created_at.replace(tzinfo=timezone.utc)
+ if session.created_at.tzinfo is None
+ else session.created_at
+ )
+ age_seconds = (datetime.now(timezone.utc) - created_at_utc).total_seconds()
+ if age_seconds < 120:
+ raise
+ # Session doesn't exist, has started, is completed, or is stuck — close it
+ if session is None or session.completed_at is None:
+ await self.close_session(organization_id, session_id)
raise
async def update_status(
@@ -488,7 +489,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
)
await self.database.close_persistent_browser_session(browser_session_id, organization_id)
- if settings.ENV == "local":
+ if settings.BROWSER_STREAMING_MODE == "cdp":
await self.database.archive_browser_session_address(browser_session_id, organization_id)
async def close_all_sessions(self, organization_id: str) -> None:
@@ -499,7 +500,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager):
async def cleanup_stale_sessions(self) -> None:
"""Close sessions left active by a previous process."""
- if settings.ENV != "local":
+ if settings.BROWSER_STREAMING_MODE != "cdp":
return
stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions()
for db_session in stale_sessions:
diff --git a/skyvern/webeye/real_browser_manager.py b/skyvern/webeye/real_browser_manager.py
index e15b9bfe3..224dc16c1 100644
--- a/skyvern/webeye/real_browser_manager.py
+++ b/skyvern/webeye/real_browser_manager.py
@@ -62,6 +62,9 @@ class RealBrowserManager(BrowserManager):
browser_cleanup=browser_cleanup,
)
+ def evict_page(self, page_id: str) -> None:
+ self.pages.pop(page_id, None)
+
def get_for_task(self, task_id: str, workflow_run_id: str | None = None) -> BrowserState | None:
if task_id in self.pages:
return self.pages[task_id]
@@ -389,7 +392,12 @@ class RealBrowserManager(BrowserManager):
organization_id: str | None = None,
) -> BrowserState | None:
LOG.info("Cleaning up for workflow run")
- browser_state_to_close = self.pages.pop(workflow_run_id, None)
+ browser_state_to_close = self.pages.get(workflow_run_id)
+
+ from skyvern.forge.sdk.routes.streaming.registries import set_deferred_close_params, stream_ref_active
+
+ streams_active = stream_ref_active(workflow_run_id)
+
if browser_state_to_close:
# If another workflow run still references this browser state (e.g. a
# parent whose in-memory browser was shared via use_parent_browser_session),
@@ -414,10 +422,21 @@ class RealBrowserManager(BrowserManager):
await browser_state_to_close.browser_context.tracing.stop(path=trace_path)
LOG.info("Stopped tracing", trace_path=trace_path)
- await browser_state_to_close.close(close_browser_on_completion=effective_close)
+ if streams_active:
+ # Defer close until the last stream disconnects
+ LOG.info(
+ "Deferring browser close — active CDP streams",
+ workflow_run_id=workflow_run_id,
+ )
+ set_deferred_close_params(workflow_run_id, close_browser_on_completion)
+ else:
+ await browser_state_to_close.close(close_browser_on_completion=effective_close)
+
+ if not streams_active:
+ self.pages.pop(workflow_run_id, None)
for task_id in task_ids:
task_browser_state = self.pages.pop(task_id, None)
- if task_browser_state is None:
+ if task_browser_state is None or streams_active:
continue
# Same shared-state check for task-level entries
shared = any(bs is task_browser_state for bs in self.pages.values())
diff --git a/tests/unit_tests/_stub_streaming.py b/tests/unit_tests/_stub_streaming.py
new file mode 100644
index 000000000..583df92cf
--- /dev/null
+++ b/tests/unit_tests/_stub_streaming.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+import sys
+from types import ModuleType
+from typing import Sequence
+from unittest.mock import MagicMock
+
+# Importing a streaming route module triggers skyvern.forge.sdk.routes.__init__,
+# which eagerly imports sibling route modules with heavy side effects (e.g. AWS clients).
+# Stub those modules so streaming-only tests stay fast and deterministic.
+_BASE_STUB_MODULES: list[str] = [
+ "skyvern.forge.sdk.api.aws",
+ "skyvern.forge.sdk.routes.agent_protocol",
+ "skyvern.forge.sdk.routes.browser_profiles",
+ "skyvern.forge.sdk.routes.browser_sessions",
+ "skyvern.forge.sdk.routes.credentials",
+ "skyvern.forge.sdk.routes.debug_sessions",
+ "skyvern.forge.sdk.routes.prompts",
+ "skyvern.forge.sdk.routes.pylon",
+ "skyvern.forge.sdk.routes.run_blocks",
+ "skyvern.forge.sdk.routes.scripts",
+ "skyvern.forge.sdk.routes.sdk",
+ "skyvern.forge.sdk.routes.webhooks",
+ "skyvern.forge.sdk.routes.workflow_copilot",
+ "skyvern.forge.sdk.routes.streaming.cdp_input",
+ "skyvern.forge.sdk.routes.streaming.messages",
+ "skyvern.forge.sdk.routes.streaming.notifications",
+ "skyvern.forge.sdk.routes.streaming.vnc",
+]
+
+
+def import_with_stubs(module_path: str, extra_stubs: Sequence[str] = ()) -> ModuleType:
+ """Import a streaming module after temporarily stubbing heavy dependencies."""
+ all_stubs = list(_BASE_STUB_MODULES) + list(extra_stubs)
+ installed: dict[str, MagicMock] = {}
+ for mod in all_stubs:
+ if mod not in sys.modules:
+ installed[mod] = sys.modules[mod] = MagicMock()
+
+ try:
+ __import__(module_path)
+ return sys.modules[module_path]
+ finally:
+ for mod in installed:
+ del sys.modules[mod]
diff --git a/tests/unit_tests/test_streaming_screencast.py b/tests/unit_tests/test_streaming_screencast.py
new file mode 100644
index 000000000..3856ee7a4
--- /dev/null
+++ b/tests/unit_tests/test_streaming_screencast.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, Mock
+
+import pytest
+
+from tests.unit_tests._stub_streaming import import_with_stubs
+
+screencast = import_with_stubs(
+ "skyvern.forge.sdk.routes.streaming.screencast",
+ extra_stubs=["skyvern.forge.sdk.routes.streaming.screenshot"],
+)
+
+
+def _make_app(browser_manager=None, persistent_sessions_manager=None):
+ """Build a fake app namespace to replace screencast.app (an AppHolder proxy)."""
+ return SimpleNamespace(
+ BROWSER_MANAGER=browser_manager or SimpleNamespace(),
+ PERSISTENT_SESSIONS_MANAGER=persistent_sessions_manager or SimpleNamespace(),
+ )
+
+
+@pytest.mark.asyncio
+async def test_resolve_browser_state_for_workflow_run(monkeypatch: pytest.MonkeyPatch) -> None:
+ expected_state = object()
+ fake_app = _make_app(
+ browser_manager=SimpleNamespace(get_for_workflow_run=Mock(return_value=expected_state), get_for_task=Mock()),
+ persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()),
+ )
+ monkeypatch.setattr(screencast, "app", fake_app)
+
+ result = await screencast._resolve_browser_state("wr_123", "workflow_run")
+
+ assert result is expected_state
+ fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_called_once_with("wr_123")
+ fake_app.BROWSER_MANAGER.get_for_task.assert_not_called()
+ fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_resolve_browser_state_for_task(monkeypatch: pytest.MonkeyPatch) -> None:
+ expected_state = object()
+ fake_app = _make_app(
+ browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock(return_value=expected_state)),
+ persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()),
+ )
+ monkeypatch.setattr(screencast, "app", fake_app)
+
+ result = await screencast._resolve_browser_state("task_123", "task", workflow_run_id="wr_123")
+
+ assert result is expected_state
+ fake_app.BROWSER_MANAGER.get_for_task.assert_called_once_with("task_123", "wr_123")
+ fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called()
+ fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_resolve_browser_state_for_browser_session(monkeypatch: pytest.MonkeyPatch) -> None:
+ expected_state = object()
+ fake_app = _make_app(
+ browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock()),
+ persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock(return_value=expected_state)),
+ )
+ monkeypatch.setattr(screencast, "app", fake_app)
+
+ result = await screencast._resolve_browser_state("bs_123", "browser_session")
+
+ assert result is expected_state
+ fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_awaited_once_with("bs_123")
+ fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called()
+ fake_app.BROWSER_MANAGER.get_for_task.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_resolve_browser_state_unknown_entity_type(monkeypatch: pytest.MonkeyPatch) -> None:
+ fake_app = _make_app(
+ browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock()),
+ persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()),
+ )
+ monkeypatch.setattr(screencast, "app", fake_app)
+
+ result = await screencast._resolve_browser_state("id_123", "unknown")
+
+ assert result is None
+ fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called()
+ fake_app.BROWSER_MANAGER.get_for_task.assert_not_called()
+ fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_wait_for_browser_state_returns_when_working_page_is_available(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ browser_state = SimpleNamespace(get_working_page=AsyncMock(return_value=object()))
+ resolve_mock = AsyncMock(return_value=browser_state)
+ sleep_mock = AsyncMock()
+ monkeypatch.setattr(screencast, "_resolve_browser_state", resolve_mock)
+ monkeypatch.setattr(screencast.asyncio, "sleep", sleep_mock)
+
+ result = await screencast.wait_for_browser_state("wr_123", "workflow_run", timeout=1, poll_interval=0.1)
+
+ assert result is browser_state
+ resolve_mock.assert_awaited_once_with("wr_123", "workflow_run", None)
+ browser_state.get_working_page.assert_awaited_once()
+ sleep_mock.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_wait_for_browser_state_returns_none_on_timeout(monkeypatch: pytest.MonkeyPatch) -> None:
+ browser_state = SimpleNamespace(get_working_page=AsyncMock(return_value=None))
+ resolve_mock = AsyncMock(return_value=browser_state)
+ sleep_mock = AsyncMock()
+ monkeypatch.setattr(screencast, "_resolve_browser_state", resolve_mock)
+ monkeypatch.setattr(screencast.asyncio, "sleep", sleep_mock)
+
+ result = await screencast.wait_for_browser_state(
+ "bs_123",
+ "browser_session",
+ timeout=0.3,
+ poll_interval=0.1,
+ )
+
+ assert result is None
+ assert resolve_mock.await_count == 3
+ assert browser_state.get_working_page.await_count == 3
+ assert sleep_mock.await_count == 3
diff --git a/tests/unit_tests/test_streaming_screenshot_local.py b/tests/unit_tests/test_streaming_screenshot_local.py
new file mode 100644
index 000000000..a646fb788
--- /dev/null
+++ b/tests/unit_tests/test_streaming_screenshot_local.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+from unittest.mock import AsyncMock
+
+import pytest
+
+from tests.unit_tests._stub_streaming import import_with_stubs
+
+screenshot = import_with_stubs("skyvern.forge.sdk.routes.streaming.screenshot")
+
+
+@pytest.mark.asyncio
+async def test_run_local_screencast_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
+ websocket = object()
+ browser_state = object()
+ wait_for_running = AsyncMock(return_value=None)
+ check_finalized = AsyncMock(return_value=False)
+ get_current_status = AsyncMock(return_value="completed")
+ wait_for_browser_state_mock = AsyncMock(return_value=browser_state)
+ start_screencast_loop_mock = AsyncMock()
+ send_status_mock = AsyncMock()
+ monkeypatch.setattr(screenshot, "wait_for_browser_state", wait_for_browser_state_mock)
+ monkeypatch.setattr(screenshot, "start_screencast_loop", start_screencast_loop_mock)
+ monkeypatch.setattr(screenshot, "_send_status", send_status_mock)
+
+ await screenshot._run_local_screencast(
+ websocket=websocket,
+ entity_id="task_123",
+ entity_type="task",
+ wait_for_running=wait_for_running,
+ check_finalized=check_finalized,
+ get_current_status=get_current_status,
+ get_workflow_run_id=lambda: "wr_123",
+ )
+
+ wait_for_running.assert_awaited_once()
+ wait_for_browser_state_mock.assert_awaited_once_with("task_123", "task", workflow_run_id="wr_123")
+ start_screencast_loop_mock.assert_awaited_once_with(
+ websocket=websocket,
+ browser_state=browser_state,
+ entity_id="task_123",
+ entity_type="task",
+ check_finalized=check_finalized,
+ )
+ get_current_status.assert_awaited_once()
+ send_status_mock.assert_awaited_once_with(websocket, "task_id", "task_123", "completed")
+
+
+@pytest.mark.asyncio
+async def test_run_local_screencast_timeout_when_browser_state_not_available(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ websocket = object()
+ wait_for_running = AsyncMock(return_value=None)
+ check_finalized = AsyncMock(return_value=False)
+ get_current_status = AsyncMock(return_value="completed")
+ wait_for_browser_state_mock = AsyncMock(return_value=None)
+ start_screencast_loop_mock = AsyncMock()
+ send_status_mock = AsyncMock()
+ monkeypatch.setattr(screenshot, "wait_for_browser_state", wait_for_browser_state_mock)
+ monkeypatch.setattr(screenshot, "start_screencast_loop", start_screencast_loop_mock)
+ monkeypatch.setattr(screenshot, "_send_status", send_status_mock)
+
+ await screenshot._run_local_screencast(
+ websocket=websocket,
+ entity_id="bs_123",
+ entity_type="browser_session",
+ wait_for_running=wait_for_running,
+ check_finalized=check_finalized,
+ get_current_status=get_current_status,
+ )
+
+ wait_for_running.assert_awaited_once()
+ wait_for_browser_state_mock.assert_awaited_once_with(
+ "bs_123",
+ "browser_session",
+ workflow_run_id=None,
+ )
+ start_screencast_loop_mock.assert_not_awaited()
+ get_current_status.assert_not_awaited()
+ send_status_mock.assert_awaited_once_with(websocket, "browser_session_id", "bs_123", "timeout")