diff --git a/.env.example b/.env.example index dd5d44599..8871a2dff 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,9 @@ # Environment that the agent will run in. ENV=local +# Browser streaming mode: "cdp" for CDP screencast, "vnc" (default) for VNC streaming. +BROWSER_STREAMING_MODE=vnc + # LLM Provider Configurations: # ENABLE_OPENAI: Set to true to enable OpenAI as a language model provider. ENABLE_OPENAI=false @@ -143,4 +146,11 @@ SKYVERN_AUTH_BITWARDEN_CLIENT_SECRET=your-client-secret-here # BITWARDEN_TIMEOUT_SECONDS=60 # Shared Redis URL used by any service that needs Redis (pub/sub, cache, etc.) -# REDIS_URL=redis://localhost:6379/0 \ No newline at end of file +# REDIS_URL=redis://localhost:6379/0 + +# Notification registry type: "local" (default, in-process) or "redis" (multi-pod) +# NOTIFICATION_REGISTRY_TYPE=local + +# Optional: override Redis URL specifically for notifications (falls back to REDIS_URL) +# NOTIFICATION_REDIS_URL= +# REDIS_URL=redis://localhost:6379/0 diff --git a/skyvern-frontend/.env.example b/skyvern-frontend/.env.example index c125b069a..3086cf6a4 100644 --- a/skyvern-frontend/.env.example +++ b/skyvern-frontend/.env.example @@ -1,3 +1,8 @@ +# Browser streaming mode: +# - cdp: use CDP screencast (for local development without VNC) +# - vnc (default): use VNC streaming +VITE_BROWSER_STREAMING_MODE=vnc + VITE_API_BASE_URL=http://localhost:8000/api/v1 # server to load artifacts from file URIs diff --git a/skyvern-frontend/src/components/BrowserStream.tsx b/skyvern-frontend/src/components/BrowserStream.tsx index 4c91d8fb7..8da472678 100644 --- a/skyvern-frontend/src/components/BrowserStream.tsx +++ b/skyvern-frontend/src/components/BrowserStream.tsx @@ -257,10 +257,9 @@ function BrowserStream({ useEffect(() => { if (prevMessageConnectedRef.current && !isMessageConnected) { setMessagesDisconnectedTrigger((x) => x + 1); - onClose?.(); } prevMessageConnectedRef.current = isMessageConnected; - }, [isMessageConnected, onClose]); + }, 
[isMessageConnected]); // vnc socket useEffect( diff --git a/skyvern-frontend/src/routes/browserSessions/BrowserSession.tsx b/skyvern-frontend/src/routes/browserSessions/BrowserSession.tsx index bc90540f0..58e2c1acd 100644 --- a/skyvern-frontend/src/routes/browserSessions/BrowserSession.tsx +++ b/skyvern-frontend/src/routes/browserSessions/BrowserSession.tsx @@ -1,5 +1,5 @@ import { ReloadIcon, StopIcon } from "@radix-ui/react-icons"; -import { useState } from "react"; +import { useEffect, useState } from "react"; import { Outlet, useLocation, useParams } from "react-router-dom"; import { useQuery } from "@tanstack/react-query"; @@ -23,9 +23,13 @@ import { useCredentialGetter } from "@/hooks/useCredentialGetter"; import { useCloseBrowserSessionMutation } from "@/routes/browserSessions/hooks/useCloseBrowserSessionMutation"; import { CopyText } from "@/routes/workflows/editor/Workspace"; import { type BrowserSession as BrowserSessionType } from "@/routes/workflows/types/browserSessionTypes"; +import { browserStreamingMode } from "@/util/env"; import { BrowserSessionDownloads } from "./BrowserSessionDownloads"; import { BrowserSessionVideo } from "./BrowserSessionVideo"; +import { BrowserSessionStream } from "./BrowserSessionStream"; + +const isCdpMode = browserStreamingMode === "cdp"; type TabName = "stream" | "recordings" | "downloads"; @@ -38,6 +42,11 @@ function BrowserSession() { ? "downloads" : "stream"; const [isDialogOpen, setIsDialogOpen] = useState(false); + const [vncFailed, setVncFailed] = useState(false); + + useEffect(() => { + setVncFailed(false); + }, [browserSessionId]); const credentialGetter = useCredentialGetter(); @@ -189,12 +198,25 @@ function BrowserSession() { pointerEvents: activeTab === "stream" ? "auto" : "none", }} > - + {/* VNC streaming */} + {browserSession.vnc_streaming_supported && !vncFailed && ( + setVncFailed(true)} + /> + )} + {isCdpMode && + browserSessionId && + (!browserSession.vnc_streaming_supported || vncFailed) && ( + + )}
(""); + const [streamFormat, setStreamFormat] = useState("png"); + const [viewportWidth, setViewportWidth] = useState(1280); + const [viewportHeight, setViewportHeight] = useState(720); + const [currentUrl, setCurrentUrl] = useState(""); + const credentialGetter = useCredentialGetter(); + + const socketRef = useRef(null); + + const inputWsUrl = interactive + ? `${newWssBaseUrl}/stream/cdp_input/browser_session/${browserSessionId}` + : null; + + const { + userIsControlling, + setUserIsControlling, + inputReady, + containerRef, + handlers, + } = useCdpInput({ + inputWsUrl, + interactive, + viewportWidth, + viewportHeight, + }); + + useEffect(() => { + async function run() { + const credentialParam = await getCredentialParam(credentialGetter); + + if (socketRef.current) { + socketRef.current.close(); + } + socketRef.current = new WebSocket( + `${newWssBaseUrl}/stream/browser_sessions/${browserSessionId}?${credentialParam}`, + ); + + socketRef.current.addEventListener("message", (event) => { + try { + const message: StreamMessage = JSON.parse(event.data); + if (message.screenshot) { + setStreamImgSrc(message.screenshot); + } + if (message.format) { + setStreamFormat(message.format); + } + if (message.viewport_width) { + setViewportWidth(message.viewport_width); + } + if (message.viewport_height) { + setViewportHeight(message.viewport_height); + } + if (message.url !== undefined) { + setCurrentUrl(message.url); + } + if ( + message.status === "completed" || + message.status === "failed" || + message.status === "timeout" + ) { + socketRef.current?.close(); + } + } catch (e) { + console.error("Failed to parse message", e); + } + }); + + socketRef.current.addEventListener("close", () => { + socketRef.current = null; + }); + } + run(); + + return () => { + if (socketRef.current) { + socketRef.current.close(); + socketRef.current = null; + } + }; + }, [credentialGetter, browserSessionId]); + + if (streamImgSrc.length > 0) { + return ( + + ); + } + + return ( +
+ Starting stream... +
+ ); +} + +export { BrowserSessionStream }; diff --git a/skyvern-frontend/src/routes/streaming/InteractiveStreamView.tsx b/skyvern-frontend/src/routes/streaming/InteractiveStreamView.tsx new file mode 100644 index 000000000..b99bf31eb --- /dev/null +++ b/skyvern-frontend/src/routes/streaming/InteractiveStreamView.tsx @@ -0,0 +1,104 @@ +import type { RefObject } from "react"; +import { GlobeIcon } from "@radix-ui/react-icons"; +import { ZoomableImage } from "@/components/ZoomableImage"; +import { Button } from "@/components/ui/button"; +import { cn } from "@/util/utils"; + +interface InteractiveStreamViewProps { + streamImgSrc: string; + streamFormat: string; + interactive: boolean; + userIsControlling: boolean; + setUserIsControlling: (v: boolean) => void; + inputReady: boolean; + containerRef: RefObject; + showControlButtons: boolean; + handlers: { + handleMouseDown: (e: React.MouseEvent) => void; + handleMouseUp: (e: React.MouseEvent) => void; + handleMouseMove: (e: React.MouseEvent) => void; + handleKeyDown: (e: React.KeyboardEvent) => void; + handleKeyUp: (e: React.KeyboardEvent) => void; + }; + currentUrl?: string; +} + +function UrlBar({ url }: { url: string }) { + return ( +
+ + {url} +
+ ); +} + +function InteractiveStreamView({ + streamImgSrc, + streamFormat, + interactive, + userIsControlling, + setUserIsControlling, + inputReady, + containerRef, + showControlButtons, + handlers, + currentUrl, +}: InteractiveStreamViewProps) { + const imgDataUrl = `data:image/${streamFormat};base64,${streamImgSrc}`; + + if (interactive) { + return ( +
+ {currentUrl && } + {showControlButtons && !userIsControlling && inputReady && ( +
+ +
+ )} + {showControlButtons && userIsControlling && ( + + )} + e.preventDefault()} + draggable={false} + /> +
+ ); + } + + return ( +
+ {currentUrl && } + +
+ ); +} + +export { InteractiveStreamView }; diff --git a/skyvern-frontend/src/routes/streaming/cdpInputUtils.ts b/skyvern-frontend/src/routes/streaming/cdpInputUtils.ts new file mode 100644 index 000000000..9937da288 --- /dev/null +++ b/skyvern-frontend/src/routes/streaming/cdpInputUtils.ts @@ -0,0 +1,73 @@ +export function mouseButtonName(button: number): string { + if (button === 2) return "right"; + if (button === 1) return "middle"; + return "left"; +} + +export function getModifiers( + e: Pick, +): number { + let m = 0; + if (e.altKey) m |= 1; + if (e.ctrlKey) m |= 2; + if (e.metaKey) m |= 4; + if (e.shiftKey) m |= 8; + return m; +} + +/** + * Map pixel coordinates from a rendered image back to viewport coordinates, + * accounting for object-contain letterboxing. + */ +export function mapCoordinates( + clientX: number, + clientY: number, + rect: DOMRect, + vpW: number, + vpH: number, +): { x: number; y: number } | null { + const containerAspect = rect.width / rect.height; + const imageAspect = vpW / vpH; + + let renderedW: number, renderedH: number, offsetX: number, offsetY: number; + if (containerAspect > imageAspect) { + renderedH = rect.height; + renderedW = rect.height * imageAspect; + offsetX = (rect.width - renderedW) / 2; + offsetY = 0; + } else { + renderedW = rect.width; + renderedH = rect.width / imageAspect; + offsetX = 0; + offsetY = (rect.height - renderedH) / 2; + } + + const localX = clientX - rect.left - offsetX; + const localY = clientY - rect.top - offsetY; + + if (localX < 0 || localX > renderedW || localY < 0 || localY > renderedH) { + return null; + } + + return { + x: Math.round(localX * (vpW / renderedW)), + y: Math.round(localY * (vpH / renderedH)), + }; +} + +/** + * Convenience wrapper for React MouseEvent on an img element. 
+ */ +export function mapMouseCoordinates( + e: React.MouseEvent, + vpW: number, + vpH: number, +): { x: number; y: number } | null { + return mapCoordinates( + e.clientX, + e.clientY, + e.currentTarget.getBoundingClientRect(), + vpW, + vpH, + ); +} diff --git a/skyvern-frontend/src/routes/streaming/useCdpInput.ts b/skyvern-frontend/src/routes/streaming/useCdpInput.ts new file mode 100644 index 000000000..72f7119d7 --- /dev/null +++ b/skyvern-frontend/src/routes/streaming/useCdpInput.ts @@ -0,0 +1,359 @@ +import { useCallback, useEffect, useRef, useState } from "react"; +import { useCredentialGetter } from "@/hooks/useCredentialGetter"; +import { getCredentialParam } from "@/util/env"; +import { useClientIdStore } from "@/store/useClientIdStore"; +import { + mouseButtonName, + getModifiers, + mapCoordinates, + mapMouseCoordinates, +} from "./cdpInputUtils"; + +const RECONNECTABLE_CODES = new Set([1006, 1011, 4408, 4410]); + +interface UseCdpInputOptions { + inputWsUrl: string | null; + interactive: boolean; + viewportWidth: number; + viewportHeight: number; +} + +interface UseCdpInputReturn { + userIsControlling: boolean; + setUserIsControlling: (v: boolean) => void; + inputReady: boolean; + containerRef: React.RefObject; + handlers: { + handleMouseDown: (e: React.MouseEvent) => void; + handleMouseUp: (e: React.MouseEvent) => void; + handleMouseMove: (e: React.MouseEvent) => void; + handleKeyDown: (e: React.KeyboardEvent) => void; + handleKeyUp: (e: React.KeyboardEvent) => void; + }; +} + +export function useCdpInput({ + inputWsUrl, + interactive, + viewportWidth, + viewportHeight, +}: UseCdpInputOptions): UseCdpInputReturn { + const [userIsControlling, setUserIsControlling] = useState(false); + const [inputReady, setInputReady] = useState(false); + const credentialGetter = useCredentialGetter(); + const clientId = useClientIdStore((s) => s.clientId); + + const inputSocketRef = useRef(null); + const containerRef = useRef(null); + const lastMouseMoveRef = useRef(0); 
+ const userIsControllingRef = useRef(false); + const inputReconnectTimerRef = useRef | null>( + null, + ); + const inputReconnectAttemptsRef = useRef(0); + const inputStoppedRef = useRef(false); + const inputEventCountRef = useRef(0); + + useEffect(() => { + if (!interactive || !inputWsUrl) return; + + inputStoppedRef.current = false; + inputReconnectAttemptsRef.current = 0; + + function connectInputWs(credentialParam: string) { + if (inputStoppedRef.current) return; + if (inputSocketRef.current) { + inputSocketRef.current.close(); + } + const ws = new WebSocket( + `${inputWsUrl}?client_id=${clientId}&${credentialParam}`, + ); + inputSocketRef.current = ws; + + ws.addEventListener("open", () => { + if (inputSocketRef.current !== ws) return; + console.log("[cdp-input] WebSocket connected"); + if (userIsControllingRef.current) { + ws.send(JSON.stringify({ kind: "take-control" })); + } + }); + ws.addEventListener("error", (e) => { + console.error("[cdp-input] WebSocket error", e); + }); + ws.addEventListener("message", (event) => { + if (inputSocketRef.current !== ws) return; + try { + const msg = JSON.parse(event.data); + if (msg.kind === "ready") { + console.log( + "[cdp-input] Server ready, sending current control state", + ); + inputReconnectAttemptsRef.current = 0; + setInputReady(true); + if (userIsControllingRef.current) { + ws.send(JSON.stringify({ kind: "take-control" })); + } + } + } catch { + // ignore non-JSON messages + } + }); + ws.addEventListener("close", (event) => { + console.log("[cdp-input] WebSocket closed", event.code, event.reason); + if (inputSocketRef.current !== ws) return; + setInputReady(false); + userIsControllingRef.current = false; + setUserIsControlling(false); + inputSocketRef.current = null; + + if (!inputStoppedRef.current && RECONNECTABLE_CODES.has(event.code)) { + if (inputReconnectTimerRef.current) { + clearTimeout(inputReconnectTimerRef.current); + } + inputReconnectTimerRef.current = setTimeout(() => { + reconnectInputWs(); + 
}, 2000); + } + }); + } + + async function reconnectInputWs() { + if (inputStoppedRef.current) return; + if (inputReconnectAttemptsRef.current >= 5) { + console.log("[cdp-input] Max reconnect attempts reached, giving up"); + return; + } + inputReconnectAttemptsRef.current += 1; + console.log( + `[cdp-input] Reconnecting (attempt ${inputReconnectAttemptsRef.current}/5)`, + ); + try { + const credentialParam = await getCredentialParam(credentialGetter); + connectInputWs(credentialParam); + } catch (e) { + console.error("[cdp-input] Failed to get credentials for reconnect", e); + } + } + + getCredentialParam(credentialGetter).then((credentialParam) => { + connectInputWs(credentialParam); + }); + + return () => { + inputStoppedRef.current = true; + if (inputReconnectTimerRef.current) { + clearTimeout(inputReconnectTimerRef.current); + inputReconnectTimerRef.current = null; + } + if (inputSocketRef.current) { + inputSocketRef.current.close(); + inputSocketRef.current = null; + } + }; + }, [interactive, inputWsUrl, credentialGetter, clientId]); + + useEffect(() => { + userIsControllingRef.current = userIsControlling; + }, [userIsControlling]); + + useEffect(() => { + const ws = inputSocketRef.current; + const kind = userIsControlling ? 
"take-control" : "cede-control"; + if (!ws || ws.readyState !== WebSocket.OPEN) { + return; + } + console.log(`[cdp-input] Sending ${kind}`); + ws.send(JSON.stringify({ kind })); + if (userIsControlling) { + inputEventCountRef.current = 0; + } + }, [userIsControlling]); + + useEffect(() => { + if (userIsControlling) { + containerRef.current?.focus(); + } else { + containerRef.current?.blur(); + } + }, [userIsControlling]); + + // Wheel event listener (needs non-passive to preventDefault) + useEffect(() => { + if (!interactive || !userIsControlling) return; + const el = containerRef.current; + if (!el) return; + + const handler = (e: WheelEvent) => { + e.preventDefault(); + const ws = inputSocketRef.current; + if (!ws || ws.readyState !== WebSocket.OPEN) return; + + const img = el.querySelector("img"); + if (!img) return; + + const rect = img.getBoundingClientRect(); + const coords = mapCoordinates( + e.clientX, + e.clientY, + rect, + viewportWidth, + viewportHeight, + ); + if (!coords) return; + + ws.send( + JSON.stringify({ + type: "wheelEvent", + x: coords.x, + y: coords.y, + deltaX: Math.round(e.deltaX), + deltaY: Math.round(e.deltaY), + modifiers: getModifiers(e), + }), + ); + }; + + el.addEventListener("wheel", handler, { passive: false }); + return () => el.removeEventListener("wheel", handler); + }, [interactive, userIsControlling, viewportWidth, viewportHeight]); + + const sendInputEvent = useCallback((payload: Record) => { + const ws = inputSocketRef.current; + if (!ws || ws.readyState !== WebSocket.OPEN) { + if (inputEventCountRef.current < 3) { + console.log( + "[cdp-input] Event dropped (ws not open):", + payload.type, + payload.eventType, + ); + inputEventCountRef.current++; + } + return; + } + if (inputEventCountRef.current < 3) { + console.log("[cdp-input] Sending:", payload.type, payload.eventType); + inputEventCountRef.current++; + } + ws.send(JSON.stringify(payload)); + }, []); + + const handleMouseDown = useCallback( + (e: React.MouseEvent) => { 
+ if (!interactive || !userIsControlling) return; + const coords = mapMouseCoordinates(e, viewportWidth, viewportHeight); + if (!coords) return; + sendInputEvent({ + type: "mouseEvent", + eventType: "mousePressed", + x: coords.x, + y: coords.y, + button: mouseButtonName(e.button), + clickCount: 1, + modifiers: getModifiers(e), + }); + }, + [ + interactive, + userIsControlling, + viewportWidth, + viewportHeight, + sendInputEvent, + ], + ); + + const handleMouseUp = useCallback( + (e: React.MouseEvent) => { + if (!interactive || !userIsControlling) return; + const coords = mapMouseCoordinates(e, viewportWidth, viewportHeight); + if (!coords) return; + sendInputEvent({ + type: "mouseEvent", + eventType: "mouseReleased", + x: coords.x, + y: coords.y, + button: mouseButtonName(e.button), + clickCount: 1, + modifiers: getModifiers(e), + }); + }, + [ + interactive, + userIsControlling, + viewportWidth, + viewportHeight, + sendInputEvent, + ], + ); + + const handleMouseMove = useCallback( + (e: React.MouseEvent) => { + if (!interactive || !userIsControlling) return; + const now = Date.now(); + if (now - lastMouseMoveRef.current < 50) return; + lastMouseMoveRef.current = now; + const coords = mapMouseCoordinates(e, viewportWidth, viewportHeight); + if (!coords) return; + sendInputEvent({ + type: "mouseEvent", + eventType: "mouseMoved", + x: coords.x, + y: coords.y, + button: "none", + clickCount: 0, + modifiers: getModifiers(e), + }); + }, + [ + interactive, + userIsControlling, + viewportWidth, + viewportHeight, + sendInputEvent, + ], + ); + + const handleKeyDown = useCallback( + (e: React.KeyboardEvent) => { + if (!interactive || !userIsControlling) return; + e.preventDefault(); + sendInputEvent({ + type: "keyEvent", + eventType: "keyDown", + key: e.key, + code: e.code, + text: e.key.length === 1 ? 
e.key : "", + modifiers: getModifiers(e), + }); + }, + [interactive, userIsControlling, sendInputEvent], + ); + + const handleKeyUp = useCallback( + (e: React.KeyboardEvent) => { + if (!interactive || !userIsControlling) return; + e.preventDefault(); + sendInputEvent({ + type: "keyEvent", + eventType: "keyUp", + key: e.key, + code: e.code, + modifiers: getModifiers(e), + }); + }, + [interactive, userIsControlling, sendInputEvent], + ); + + return { + userIsControlling, + setUserIsControlling, + inputReady, + containerRef, + handlers: { + handleMouseDown, + handleMouseUp, + handleMouseMove, + handleKeyDown, + handleKeyUp, + }, + }; +} diff --git a/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx b/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx index 57349b525..3464e641c 100644 --- a/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx +++ b/skyvern-frontend/src/routes/tasks/detail/TaskActions.tsx @@ -7,13 +7,13 @@ import { toast } from "@/components/ui/use-toast"; import { ZoomableImage } from "@/components/ZoomableImage"; import { useCostCalculator } from "@/hooks/useCostCalculator"; import { useCredentialGetter } from "@/hooks/useCredentialGetter"; -import { getRuntimeApiKey } from "@/util/env"; +import { getCredentialParam } from "@/util/env"; import { keepPreviousData, useQuery, useQueryClient, } from "@tanstack/react-query"; -import { useEffect, useState } from "react"; +import { useEffect, useRef, useState } from "react"; import { statusIsFinalized, statusIsNotFinalized, @@ -33,16 +33,17 @@ type StreamMessage = { task_id: string; status: string; screenshot?: string; + format?: string; }; -let socket: WebSocket | null = null; - const wssBaseUrl = import.meta.env.VITE_WSS_BASE_URL; function TaskActions() { const taskId = useFirstParam("taskId", "runId"); const credentialGetter = useCredentialGetter(); const [streamImgSrc, setStreamImgSrc] = useState(""); + const [streamFormat, setStreamFormat] = useState("png"); + const socketRef = 
useRef(null); const [selectedAction, setSelectedAction] = useState< number | "stream" | null >(null); @@ -81,34 +82,30 @@ function TaskActions() { } async function run() { - // Create WebSocket connection. - let credential = null; - if (credentialGetter) { - const token = await credentialGetter(); - credential = `?token=Bearer ${token}`; - } else { - const apiKey = getRuntimeApiKey(); - credential = apiKey ? `?apikey=${apiKey}` : ""; + const credentialParam = await getCredentialParam(credentialGetter); + + if (socketRef.current) { + socketRef.current.close(); } - if (socket) { - socket.close(); - } - socket = new WebSocket( - `${wssBaseUrl}/stream/tasks/${taskId}${credential}`, + socketRef.current = new WebSocket( + `${wssBaseUrl}/stream/tasks/${taskId}?${credentialParam}`, ); - // Listen for messages - socket.addEventListener("message", (event) => { + + socketRef.current.addEventListener("message", (event) => { try { const message: StreamMessage = JSON.parse(event.data); if (message.screenshot) { setStreamImgSrc(message.screenshot); } + if (message.format) { + setStreamFormat(message.format); + } if ( message.status === "completed" || message.status === "failed" || message.status === "terminated" ) { - socket?.close(); + socketRef.current?.close(); queryClient.invalidateQueries({ queryKey: ["tasks"], }); @@ -134,16 +131,16 @@ function TaskActions() { } }); - socket.addEventListener("close", () => { - socket = null; + socketRef.current.addEventListener("close", () => { + socketRef.current = null; }); } run(); return () => { - if (socket) { - socket.close(); - socket = null; + if (socketRef.current) { + socketRef.current.close(); + socketRef.current = null; } }; }, [ @@ -257,7 +254,9 @@ function TaskActions() { if (task?.status === Status.Running && streamImgSrc.length > 0) { return (
- +
); } diff --git a/skyvern-frontend/src/routes/workflows/editor/Workspace.tsx b/skyvern-frontend/src/routes/workflows/editor/Workspace.tsx index 7d58d43e1..20fd917b6 100644 --- a/skyvern-frontend/src/routes/workflows/editor/Workspace.tsx +++ b/skyvern-frontend/src/routes/workflows/editor/Workspace.tsx @@ -35,7 +35,8 @@ import { useMountEffect } from "@/hooks/useMountEffect"; import { useBrowserSessionRateLimit } from "../hooks/useBrowserSessionRateLimit"; import { useDebugSessionQuery } from "../hooks/useDebugSessionQuery"; import { useBlockScriptsQuery } from "@/routes/workflows/hooks/useBlockScriptsQuery"; -import { WorkflowRunStream } from "@/routes/workflows/workflowRun/WorkflowRunStream"; +import { BrowserSessionStream } from "@/routes/browserSessions/BrowserSessionStream"; +import { browserStreamingMode } from "@/util/env"; import { useCacheKeyValuesQuery } from "../hooks/useCacheKeyValuesQuery"; import { useBlockScriptStore } from "@/store/BlockScriptStore"; import { useRecordingStore } from "@/store/useRecordingStore"; @@ -1760,14 +1761,22 @@ function Workspace({
)} - {/* Screenshot browser} */} + {/* CDP screencast: only in local mode when VNC is not supported */} {activeDebugSession && - !activeDebugSession.vnc_streaming_supported && ( + !activeDebugSession.vnc_streaming_supported && + browserStreamingMode === "cdp" && (
-
-
- -
+
+
setIsCopilotOpen((prev) => !prev)} /> +
+ Live Browser +
+
+ {!recordingStore.isRecording && showPowerButton && ( + cycle()} /> + )} + {!recordingStore.isRecording && ( + reload()} + /> + )} +
)} + {/* Fallback: non-local without VNC (edge case) */} + {activeDebugSession && + !activeDebugSession.vnc_streaming_supported && + browserStreamingMode !== "cdp" && ( +
+ Browser streaming unavailable +
+ )} + {/* timeline */}
{ if (workflowRunId) { queryClient.invalidateQueries({ @@ -63,6 +66,15 @@ function WorkflowRunOverview() { } }, [queryClient, workflowPermanentId, workflowRunId]); + const handleVncClose = useCallback(() => { + setVncFailed(true); + invalidateQueries(); + }, [invalidateQueries]); + + useEffect(() => { + setVncFailed(false); + }, [browserSessionId]); + if (workflowRunIsLoading || workflowRunTimelineIsLoading) { return ( @@ -89,19 +101,10 @@ function WorkflowRunOverview() { finallyBlockLabel, ); - const browserSessionId = workflowRun.browser_session_id; - const isPaused = workflowRun && workflowRun.status === WorkflowRunStatus.Paused; - const showStreamingBrowser = - (!workflowRunIsFinalized && - browserSessionId && - isWorkflowRunBlock(selection) && - selection.block_type === "human_interaction") || - selection === "stream"; - - const shouldShowBrowserStream = !!( + const wantsVncStream = !!( browserSessionId && !workflowRunIsFinalized && (selection === "stream" || @@ -109,31 +112,42 @@ function WorkflowRunOverview() { selection.block_type === "human_interaction")) ); + const shouldShowBrowserStream = wantsVncStream && !vncFailed; + const shouldShowScreencastFallback = wantsVncStream && vncFailed; + + const isStreamActive = + shouldShowBrowserStream || + shouldShowScreencastFallback || + selection === "stream"; + return ( {shouldShowBrowserStream && ( )} - {!shouldShowBrowserStream && selection === "stream" && ( - - )} - {selection !== "stream" && - !showStreamingBrowser && - isAction(selection) && ( - )} - {isWorkflowRunBlock(selection) && !showStreamingBrowser && ( + {!isStreamActive && isAction(selection) && ( + + )} + {isWorkflowRunBlock(selection) && !isStreamActive && ( (""); + const [streamFormat, setStreamFormat] = useState("png"); + const [viewportWidth, setViewportWidth] = useState(1280); + const [viewportHeight, setViewportHeight] = useState(720); const showStream = alwaysShowStream || (workflowRun && statusIsNotFinalized(workflowRun)); const 
credentialGetter = useCredentialGetter(); @@ -35,40 +46,62 @@ function WorkflowRunStream(props?: Props) { const workflowPermanentId = workflow?.workflow_permanent_id; const queryClient = useQueryClient(); + const socketRef = useRef(null); + + const inputWsUrl = + interactive && workflowRunId + ? `${wssBaseUrl}/stream/cdp_input/workflow_run/${workflowRunId}` + : null; + + const { + userIsControlling, + setUserIsControlling, + inputReady, + containerRef, + handlers, + } = useCdpInput({ + inputWsUrl, + interactive, + viewportWidth, + viewportHeight, + }); + useEffect(() => { if (!showStream) { return; } async function run() { - // Create WebSocket connection. - let credential = null; - if (credentialGetter) { - const token = await credentialGetter(); - credential = `?token=Bearer ${token}`; - } else { - const apiKey = getRuntimeApiKey(); - credential = apiKey ? `?apikey=${apiKey}` : ""; + const credentialParam = await getCredentialParam(credentialGetter); + + if (socketRef.current) { + socketRef.current.close(); } - if (socket) { - socket.close(); - } - socket = new WebSocket( - `${wssBaseUrl}/stream/workflow_runs/${workflowRunId}${credential}`, + socketRef.current = new WebSocket( + `${wssBaseUrl}/stream/workflow_runs/${workflowRunId}?${credentialParam}`, ); - // Listen for messages - socket.addEventListener("message", (event) => { + + socketRef.current.addEventListener("message", (event) => { try { const message: StreamMessage = JSON.parse(event.data); if (message.screenshot) { setStreamImgSrc(message.screenshot); } + if (message.format) { + setStreamFormat(message.format); + } + if (message.viewport_width) { + setViewportWidth(message.viewport_width); + } + if (message.viewport_height) { + setViewportHeight(message.viewport_height); + } if ( message.status === "completed" || message.status === "failed" || message.status === "terminated" ) { - socket?.close(); + socketRef.current?.close(); queryClient.invalidateQueries({ queryKey: ["workflowRuns"], }); @@ -109,16 
+142,16 @@ function WorkflowRunStream(props?: Props) { } }); - socket.addEventListener("close", () => { - socket = null; + socketRef.current.addEventListener("close", () => { + socketRef.current = null; }); } run(); return () => { - if (socket) { - socket.close(); - socket = null; + if (socketRef.current) { + socketRef.current.close(); + socketRef.current = null; } }; }, [ @@ -129,6 +162,10 @@ function WorkflowRunStream(props?: Props) { workflowPermanentId, ]); + const isRunningOrPaused = + workflowRun?.status === Status.Running || + workflowRun?.status === Status.Paused; + if (workflowRun?.status === Status.Created) { return (
@@ -146,7 +183,7 @@ function WorkflowRunStream(props?: Props) { ); } - if (workflowRun?.status === Status.Running && streamImgSrc.length === 0) { + if (isRunningOrPaused && streamImgSrc.length === 0) { return (
Starting the stream... @@ -154,29 +191,26 @@ function WorkflowRunStream(props?: Props) { ); } - if (workflowRun?.status === Status.Running && streamImgSrc.length > 0) { + const hasStream = + (isRunningOrPaused || alwaysShowStream) && streamImgSrc.length > 0; + + if (hasStream) { return ( -
- -
+ ); } if (alwaysShowStream) { - if (streamImgSrc?.length > 0) { - return ( -
- -
- ); - } - return (
Waiting for stream... diff --git a/skyvern-frontend/src/util/env.ts b/skyvern-frontend/src/util/env.ts index dbcfb1da0..e27d18d1a 100644 --- a/skyvern-frontend/src/util/env.ts +++ b/skyvern-frontend/src/util/env.ts @@ -10,6 +10,9 @@ if (!environment) { console.warn("environment environment variable was not set"); } +const browserStreamingMode = + (import.meta.env.VITE_BROWSER_STREAMING_MODE as string) ?? "vnc"; + const buildTimeApiKey: string | null = typeof import.meta.env.VITE_SKYVERN_API_KEY === "string" ? import.meta.env.VITE_SKYVERN_API_KEY @@ -94,6 +97,19 @@ function clearRuntimeApiKey(): void { } } +async function getCredentialParam( + credentialGetter: (() => Promise) | null, +): Promise { + if (credentialGetter) { + const token = await credentialGetter(); + if (token) { + return `token=Bearer ${token}`; + } + } + const apiKey = getRuntimeApiKey(); + return apiKey ? `apikey=${apiKey}` : ""; +} + const useNewRunsUrl = true as const; const enable2faNotifications = @@ -103,11 +119,13 @@ export { apiBaseUrl, runsApiBaseUrl, environment, + browserStreamingMode, artifactApiBaseUrl, apiPathPrefix, lsKeys, wssBaseUrl, newWssBaseUrl, + getCredentialParam, getRuntimeApiKey, persistRuntimeApiKey, clearRuntimeApiKey, diff --git a/skyvern/config.py b/skyvern/config.py index 9cfd35a23..254097088 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -78,6 +78,7 @@ class Settings(BaseSettings): TASK_RESPONSE_ACTION_SCREENSHOT_COUNT: int = 3 ENV: str = "local" + BROWSER_STREAMING_MODE: str = "vnc" EXECUTE_ALL_STEPS: bool = True JSON_LOGGING: bool = False LOG_RAW_API_REQUESTS: bool = True diff --git a/skyvern/forge/sdk/db/agent_db.py b/skyvern/forge/sdk/db/agent_db.py index 42e9c88a7..630a6c3d6 100644 --- a/skyvern/forge/sdk/db/agent_db.py +++ b/skyvern/forge/sdk/db/agent_db.py @@ -5847,6 +5847,7 @@ class AgentDB(BaseAlchemyDB): timeout_minutes: int | None = None, organization_id: str | None = None, completed_at: datetime | None = None, + started_at: datetime | None = 
None, ) -> PersistentBrowserSession: try: async with self.Session() as session: @@ -5867,6 +5868,8 @@ class AgentDB(BaseAlchemyDB): persistent_browser_session.timeout_minutes = timeout_minutes if completed_at is not None: persistent_browser_session.completed_at = completed_at + if started_at: + persistent_browser_session.started_at = started_at await session.commit() await session.refresh(persistent_browser_session) diff --git a/skyvern/forge/sdk/routes/__init__.py b/skyvern/forge/sdk/routes/__init__.py index a24204fca..4c06f74a3 100644 --- a/skyvern/forge/sdk/routes/__init__.py +++ b/skyvern/forge/sdk/routes/__init__.py @@ -10,6 +10,7 @@ from skyvern.forge.sdk.routes import scripts # noqa: F401 from skyvern.forge.sdk.routes import sdk # noqa: F401 from skyvern.forge.sdk.routes import webhooks # noqa: F401 from skyvern.forge.sdk.routes import workflow_copilot # noqa: F401 +from skyvern.forge.sdk.routes.streaming import cdp_input # noqa: F401 from skyvern.forge.sdk.routes.streaming import messages # noqa: F401 from skyvern.forge.sdk.routes.streaming import screenshot # noqa: F401 from skyvern.forge.sdk.routes.streaming import vnc # noqa: F401 diff --git a/skyvern/forge/sdk/routes/debug_sessions.py b/skyvern/forge/sdk/routes/debug_sessions.py index 2a5027839..ee1dd9f73 100644 --- a/skyvern/forge/sdk/routes/debug_sessions.py +++ b/skyvern/forge/sdk/routes/debug_sessions.py @@ -70,6 +70,19 @@ async def get_or_create_debug_session_by_user_and_workflow_permanent_id( workflow_permanent_id=workflow_permanent_id, ) + # Skip renewal for sessions that haven't started yet (browser still launching) + session = await app.DATABASE.get_persistent_browser_session( + debug_session.browser_session_id, + current_org.organization_id, + ) + if session and session.started_at is None and session.completed_at is None: + created_at_utc = ( + session.created_at.replace(tzinfo=timezone.utc) if session.created_at.tzinfo is None else session.created_at + ) + age_seconds = 
(datetime.now(timezone.utc) - created_at_utc).total_seconds() + if age_seconds < 120: + return debug_session + try: await app.PERSISTENT_SESSIONS_MANAGER.renew_or_close_session( debug_session.browser_session_id, diff --git a/skyvern/forge/sdk/routes/streaming/auth.py b/skyvern/forge/sdk/routes/streaming/auth.py index ed5a6b966..c469c6456 100644 --- a/skyvern/forge/sdk/routes/streaming/auth.py +++ b/skyvern/forge/sdk/routes/streaming/auth.py @@ -2,6 +2,8 @@ Streaming auth. """ +import typing as t + import structlog from fastapi import WebSocket from websockets.exceptions import ConnectionClosedOK @@ -13,6 +15,13 @@ from skyvern.forge.sdk.services.org_auth_service import get_current_org LOG = structlog.get_logger() +def require_client_id(client_id: str | None, **log_kwargs: t.Any) -> bool: + if client_id: + return True + LOG.error("No client_id provided", **log_kwargs) + return False + + class Constants: MISSING_API_KEY = "" @@ -35,7 +44,7 @@ async def get_x_api_key(organization_id: str) -> str: return x_api_key -async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> str | None: +async def auth(apikey: str | None, token: str | None, websocket: WebSocket, **log_kwargs: t.Any) -> str | None: """ Accepts the websocket connection. 
@@ -49,7 +58,7 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s await websocket.close(code=1002) return None except ConnectionClosedOK: - LOG.info("WebSocket connection closed cleanly.") + LOG.info("WebSocket connection closed cleanly.", **log_kwargs) return None try: @@ -60,11 +69,11 @@ async def auth(apikey: str | None, token: str | None, websocket: WebSocket) -> s await websocket.close(code=1002) return None except Exception: - LOG.exception("Error occurred while retrieving organization information.") + LOG.exception("Error occurred while retrieving organization information.", **log_kwargs) try: await websocket.close(code=1002) except ConnectionClosedOK: - LOG.info("WebSocket connection closed due to invalid credentials.") + LOG.info("WebSocket connection closed due to invalid credentials.", **log_kwargs) return None return organization_id diff --git a/skyvern/forge/sdk/routes/streaming/cdp_input.py b/skyvern/forge/sdk/routes/streaming/cdp_input.py new file mode 100644 index 000000000..b755d59e3 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/cdp_input.py @@ -0,0 +1,406 @@ +""" +CDP input channel for interactive browser control via Chrome DevTools Protocol. 
+""" + +import asyncio +import dataclasses +import json +import time +import typing as t + +import structlog +from fastapi import WebSocket, WebSocketDisconnect +from playwright.async_api import CDPSession +from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK + +from skyvern.forge import app +from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router +from skyvern.forge.sdk.routes.streaming.auth import auth, require_client_id +from skyvern.forge.sdk.routes.streaming.registries import ( + add_cdp_input_channel, + del_cdp_input_channel, + stream_ref_dec, + stream_ref_inc, +) +from skyvern.forge.sdk.routes.streaming.screencast import wait_for_browser_state +from skyvern.forge.sdk.schemas.persistent_browser_sessions import is_final_status +from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus + +LOG = structlog.get_logger() + +_VALID_MOUSE_TYPES = {"mousePressed", "mouseReleased", "mouseMoved"} +_VALID_MOUSE_BUTTONS = {"left", "middle", "right", "none"} +_VALID_KEY_TYPES = {"keyDown", "keyUp"} +_MAX_COORD = 10000 +_MAX_DELTA = 10000 +_MAX_KEY_LEN = 32 +_MAX_CODE_LEN = 32 +_MODIFIER_MASK = 0xF + + +@dataclasses.dataclass +class CdpInputChannel: + client_id: str + organization_id: str + websocket: WebSocket + interactor: t.Literal["agent", "user"] = "agent" + + def __post_init__(self) -> None: + add_cdp_input_channel(self) + + async def close(self) -> None: + del_cdp_input_channel(self.client_id) + + +def _validated_modifiers(msg: dict) -> int: + modifiers = msg.get("modifiers", 0) + if not isinstance(modifiers, int): + return 0 + return modifiers & _MODIFIER_MASK + + +def _validated_coords(msg: dict) -> tuple[int, int] | None: + x = msg.get("x") + y = msg.get("y") + if not isinstance(x, (int, float)) or not isinstance(y, (int, float)): + return None + return ( + max(0, min(int(x), _MAX_COORD)), + max(0, min(int(y), _MAX_COORD)), + ) + + +def _validate_mouse_event(msg: dict) -> dict | None: + event_type = 
msg.get("eventType") + if event_type not in _VALID_MOUSE_TYPES: + return None + + coords = _validated_coords(msg) + if coords is None: + return None + x, y = coords + + button = msg.get("button", "none") + if button not in _VALID_MOUSE_BUTTONS: + button = "none" + + click_count = msg.get("clickCount", 0) + if not isinstance(click_count, int): + click_count = 0 + click_count = max(0, min(click_count, 3)) + + return { + "type": event_type, + "x": x, + "y": y, + "button": button, + "clickCount": click_count, + "modifiers": _validated_modifiers(msg), + } + + +def _validate_key_event(msg: dict) -> dict | None: + event_type = msg.get("eventType") + if event_type not in _VALID_KEY_TYPES: + return None + + key = msg.get("key", "") + if not isinstance(key, str) or len(key) > _MAX_KEY_LEN: + return None + + code = msg.get("code", "") + if not isinstance(code, str) or len(code) > _MAX_CODE_LEN: + return None + + result: dict[str, t.Any] = { + "type": event_type, + "key": key, + "code": code, + "modifiers": _validated_modifiers(msg), + } + + # Only include text for printable single characters on keyDown + text = msg.get("text", "") + if isinstance(text, str) and len(text) == 1 and text.isprintable() and event_type == "keyDown": + result["text"] = text + + return result + + +def _validate_wheel_event(msg: dict) -> dict | None: + coords = _validated_coords(msg) + if coords is None: + return None + x, y = coords + + delta_x = msg.get("deltaX", 0) + delta_y = msg.get("deltaY", 0) + if not isinstance(delta_x, (int, float)): + delta_x = 0 + if not isinstance(delta_y, (int, float)): + delta_y = 0 + delta_x = max(-_MAX_DELTA, min(int(delta_x), _MAX_DELTA)) + delta_y = max(-_MAX_DELTA, min(int(delta_y), _MAX_DELTA)) + + return { + "type": "mouseWheel", + "x": x, + "y": y, + "deltaX": delta_x, + "deltaY": delta_y, + "modifiers": _validated_modifiers(msg), + } + + +async def _close_ws_safely(websocket: WebSocket, code: int, reason: str = "") -> None: + try: + await 
websocket.close(code=code, reason=reason) + except Exception: + pass + + +_EVENT_DISPATCH_MAP: dict[str, tuple[t.Callable[[dict], dict | None], str]] = { + "mouseEvent": (_validate_mouse_event, "Input.dispatchMouseEvent"), + "keyEvent": (_validate_key_event, "Input.dispatchKeyEvent"), + "wheelEvent": (_validate_wheel_event, "Input.dispatchMouseEvent"), +} + + +async def _dispatch_event( + cdp_session: CDPSession, + kind: str, + msg: dict, + log_id_key: str, + log_id_value: str, +) -> None: + entry = _EVENT_DISPATCH_MAP.get(kind) + if entry is None: + return + validator, cdp_method = entry + validated = validator(msg) + if validated: + await cdp_session.send(cdp_method, validated) + else: + LOG.warning( + "CDP input: validation failed", + **{log_id_key: log_id_value}, + kind=kind, + raw_event_type=msg.get("eventType"), + ) + + +async def _run_input_loop( + websocket: WebSocket, + channel: CdpInputChannel, + cdp_session: CDPSession, + log_id_key: str, + log_id_value: str, +) -> None: + dropped_log_count = 0 + while True: + try: + raw = await websocket.receive_text() + except WebSocketDisconnect: + break + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + LOG.warning("CDP input: malformed JSON", **{log_id_key: log_id_value}) + continue + + kind = msg.get("kind") or msg.get("type") + + if kind == "take-control": + channel.interactor = "user" + LOG.info("CDP input: take-control received", **{log_id_key: log_id_value}, client_id=channel.client_id) + continue + if kind == "cede-control": + channel.interactor = "agent" + LOG.info("CDP input: cede-control received", **{log_id_key: log_id_value}, client_id=channel.client_id) + continue + + if channel.interactor != "user": + if dropped_log_count < 5: + LOG.info( + "CDP input: event dropped", + interactor=channel.interactor, + **{log_id_key: log_id_value}, + kind=kind, + ) + dropped_log_count += 1 + continue + + try: + await _dispatch_event(cdp_session, kind, msg, log_id_key, log_id_value) + except Exception: + 
LOG.warning( + "CDP input: failed to dispatch event; closing input channel", + **{log_id_key: log_id_value}, + kind=kind, + exc_info=True, + ) + await websocket.close(code=4411, reason="dispatch_failed") + break + + +@legacy_base_router.websocket("/stream/cdp_input/workflow_run/{workflow_run_id}") +async def cdp_input_stream( + websocket: WebSocket, + workflow_run_id: str, + client_id: str | None = None, + apikey: str | None = None, + token: str | None = None, +) -> None: + organization_id = await auth(apikey=apikey, token=token, websocket=websocket, workflow_run_id=workflow_run_id) + if organization_id is None: + return + + if not require_client_id(client_id, workflow_run_id=workflow_run_id): + await _close_ws_safely(websocket, 1002) + return + assert client_id is not None + + channel = CdpInputChannel( + client_id=client_id, + organization_id=organization_id, + websocket=websocket, + ) + + cdp_session: CDPSession | None = None + try: + deadline = time.monotonic() + 120 + while True: + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + if not workflow_run or workflow_run.organization_id != organization_id: + LOG.info("CDP input: workflow run not found", workflow_run_id=workflow_run_id) + await websocket.close(code=4404, reason="workflow_run_not_found") + return + if workflow_run.status == WorkflowRunStatus.running or workflow_run.status.is_final(): + break + if workflow_run.status == WorkflowRunStatus.paused: + break + if time.monotonic() >= deadline: + LOG.warning("CDP input: timed out waiting for running status", workflow_run_id=workflow_run_id) + await websocket.close(code=4408, reason="wait_timeout") + return + await asyncio.sleep(1) + + browser_state = await wait_for_browser_state(workflow_run_id, "workflow_run") + if browser_state is None: + LOG.warning("CDP input: timed out waiting for browser state", workflow_run_id=workflow_run_id) + await websocket.close(code=4408, 
reason="browser_state_timeout") + return + + page = await browser_state.get_working_page() + if page is None: + LOG.warning("CDP input: no working page", workflow_run_id=workflow_run_id) + await websocket.close(code=4410, reason="no_working_page") + return + + cdp_session = await page.context.new_cdp_session(page) + stream_ref_inc(workflow_run_id) + + LOG.info("CDP input channel ready", workflow_run_id=workflow_run_id, client_id=client_id) + await websocket.send_json({"kind": "ready"}) + + await _run_input_loop(websocket, channel, cdp_session, "workflow_run_id", workflow_run_id) + + except ConnectionClosedOK: + LOG.info("CDP input: WS closed cleanly", workflow_run_id=workflow_run_id) + except ConnectionClosedError: + LOG.warning("CDP input: WS connection error", workflow_run_id=workflow_run_id) + except WebSocketDisconnect: + LOG.info("CDP input: WS disconnected", workflow_run_id=workflow_run_id) + except Exception: + LOG.warning("CDP input: unexpected error", workflow_run_id=workflow_run_id, exc_info=True) + finally: + if cdp_session is not None: + await stream_ref_dec(workflow_run_id) + try: + await cdp_session.detach() + except Exception: + pass + await channel.close() + LOG.info("CDP input channel closed", workflow_run_id=workflow_run_id, client_id=client_id) + + +@base_router.websocket("/stream/cdp_input/browser_session/{browser_session_id}") +async def cdp_input_browser_session_stream( + websocket: WebSocket, + browser_session_id: str, + client_id: str | None = None, + apikey: str | None = None, + token: str | None = None, +) -> None: + organization_id = await auth(apikey=apikey, token=token, websocket=websocket, browser_session_id=browser_session_id) + if organization_id is None: + return + + if not require_client_id(client_id, browser_session_id=browser_session_id): + await _close_ws_safely(websocket, 1002) + return + assert client_id is not None + + channel = CdpInputChannel( + client_id=client_id, + organization_id=organization_id, + websocket=websocket, 
+ ) + + cdp_session: CDPSession | None = None + try: + session = await app.PERSISTENT_SESSIONS_MANAGER.get_session( + session_id=browser_session_id, + organization_id=organization_id, + ) + if not session: + LOG.info("CDP input: browser session not found", browser_session_id=browser_session_id) + await websocket.close(code=4404, reason="browser_session_not_found") + return + if is_final_status(session.status): + LOG.info("CDP input: browser session already finalized", browser_session_id=browser_session_id) + await websocket.close(code=4404, reason="browser_session_finalized") + return + + browser_state = await wait_for_browser_state(browser_session_id, "browser_session") + if browser_state is None: + LOG.warning("CDP input: timed out waiting for browser state", browser_session_id=browser_session_id) + await websocket.close(code=4408, reason="browser_state_timeout") + return + + page = await browser_state.get_working_page() + if page is None: + LOG.warning("CDP input: no working page", browser_session_id=browser_session_id) + await websocket.close(code=4410, reason="no_working_page") + return + + cdp_session = await page.context.new_cdp_session(page) + # stream_ref_inc/dec is intentionally omitted for browser sessions. + # Browser state lives in PersistentSessionsManager._browser_sessions, + # not BrowserManager.pages, so there is no entry to protect from eviction. 
+ + LOG.info("CDP input channel ready", browser_session_id=browser_session_id, client_id=client_id) + await websocket.send_json({"kind": "ready"}) + + await _run_input_loop(websocket, channel, cdp_session, "browser_session_id", browser_session_id) + + except ConnectionClosedOK: + LOG.info("CDP input: WS closed cleanly", browser_session_id=browser_session_id) + except ConnectionClosedError: + LOG.warning("CDP input: WS connection error", browser_session_id=browser_session_id) + except WebSocketDisconnect: + LOG.info("CDP input: WS disconnected", browser_session_id=browser_session_id) + except Exception: + LOG.warning("CDP input: unexpected error", browser_session_id=browser_session_id, exc_info=True) + finally: + if cdp_session is not None: + try: + await cdp_session.detach() + except Exception: + pass + await channel.close() + LOG.info("CDP input channel closed", browser_session_id=browser_session_id, client_id=client_id) diff --git a/skyvern/forge/sdk/routes/streaming/registries.py b/skyvern/forge/sdk/routes/streaming/registries.py index 12143a893..91e879588 100644 --- a/skyvern/forge/sdk/routes/streaming/registries.py +++ b/skyvern/forge/sdk/routes/streaming/registries.py @@ -4,7 +4,7 @@ Contains registries for coordinating active WS connections (aka "channels", see NOTE: in AWS we had to turn on what amounts to sticky sessions for frontend apps, so that an individual frontend app instance is guaranteed to always connect to -the same backend api instance. This is beccause the two registries here are +the same backend api instance. This is because the two registries here are tied together via a `client_id` string. 
The tale-of-the-tape is this: @@ -12,6 +12,11 @@ The tale-of-the-tape is this: - one dedicated to streaming VNC's RFB protocol - the other dedicated to messaging (JSON) - both of these channels are stateful and need to coordinate with one another + +Additionally, this module manages: + - CDP input channels for interactive browser control + - Stream reference counts that defer browser state cleanup while active + CDP streams hold references to a workflow run's browser """ from __future__ import annotations @@ -21,6 +26,7 @@ import typing as t import structlog if t.TYPE_CHECKING: + from skyvern.forge.sdk.routes.streaming.cdp_input import CdpInputChannel from skyvern.forge.sdk.routes.streaming.channels.message import MessageChannel from skyvern.forge.sdk.routes.streaming.channels.vnc import VncChannel @@ -90,3 +96,55 @@ def del_message_channel(client_id: str, *, expected: MessageChannel | None = Non return del message_channels[client_id] + + +# Stream reference counts per workflow_run_id. +_stream_refcounts: dict[str, int] = {} +_deferred_close_params: dict[str, bool] = {} + + +def stream_ref_inc(workflow_run_id: str) -> None: + _stream_refcounts[workflow_run_id] = _stream_refcounts.get(workflow_run_id, 0) + 1 + + +async def stream_ref_dec(workflow_run_id: str) -> None: + count = _stream_refcounts.get(workflow_run_id, 0) - 1 + if count <= 0: + _stream_refcounts.pop(workflow_run_id, None) + close_on_completion = _deferred_close_params.pop(workflow_run_id, None) + if close_on_completion is not None: + from skyvern.forge import app + + browser_state = app.BROWSER_MANAGER.pages.get(workflow_run_id) + if browser_state is not None: + try: + await browser_state.close(close_browser_on_completion=close_on_completion) + except Exception: + LOG.warning( + "stream_ref_dec: error closing deferred browser state", + workflow_run_id=workflow_run_id, + exc_info=True, + ) + app.BROWSER_MANAGER.evict_page(workflow_run_id) + else: + _stream_refcounts[workflow_run_id] = count + + +def 
stream_ref_active(workflow_run_id: str) -> bool: + return _stream_refcounts.get(workflow_run_id, 0) > 0 + + +def set_deferred_close_params(workflow_run_id: str, close_browser_on_completion: bool) -> None: + _deferred_close_params[workflow_run_id] = close_browser_on_completion + + +# a registry for CDP input channels, keyed by `client_id` +cdp_input_channels: dict[str, CdpInputChannel] = {} + + +def add_cdp_input_channel(channel: CdpInputChannel) -> None: + cdp_input_channels[channel.client_id] = channel + + +def del_cdp_input_channel(client_id: str) -> None: + cdp_input_channels.pop(client_id, None) diff --git a/skyvern/forge/sdk/routes/streaming/screencast.py b/skyvern/forge/sdk/routes/streaming/screencast.py new file mode 100644 index 000000000..b99135b51 --- /dev/null +++ b/skyvern/forge/sdk/routes/streaming/screencast.py @@ -0,0 +1,202 @@ +""" +CDP screencast loop for local-mode browser streaming. + +Uses Chrome's Page.startScreencast() to stream JPEG frames from the browser +over a WebSocket connection. 
+""" + +import asyncio +from collections.abc import Awaitable, Callable + +import structlog +from fastapi import WebSocket +from playwright.async_api import CDPSession + +from skyvern.forge import app +from skyvern.webeye.browser_state import BrowserState + +LOG = structlog.get_logger() + +DEFAULT_WIDTH = 1280 +DEFAULT_HEIGHT = 720 + + +async def wait_for_browser_state( + entity_id: str, + entity_type: str, + workflow_run_id: str | None = None, + timeout: float = 120, + poll_interval: float = 1.0, +) -> BrowserState | None: + elapsed = 0.0 + while elapsed < timeout: + browser_state = await _resolve_browser_state(entity_id, entity_type, workflow_run_id) + + if browser_state is not None: + page = await browser_state.get_working_page() + if page is not None: + return browser_state + + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + return None + + +async def _resolve_browser_state( + entity_id: str, + entity_type: str, + workflow_run_id: str | None = None, +) -> BrowserState | None: + if entity_type == "workflow_run": + return app.BROWSER_MANAGER.get_for_workflow_run(entity_id) + if entity_type == "task": + return app.BROWSER_MANAGER.get_for_task(entity_id, workflow_run_id) + if entity_type == "browser_session": + return await app.PERSISTENT_SESSIONS_MANAGER.get_browser_state(entity_id) + return None + + +async def start_screencast_loop( + websocket: WebSocket, + browser_state: BrowserState, + entity_id: str, + entity_type: str, + check_finalized: Callable[[], Awaitable[bool]], +) -> None: + id_key = f"{entity_type}_id" + cdp_session: CDPSession | None = None + frame_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=2) + viewport_info: dict[str, int] = {"width": DEFAULT_WIDTH, "height": DEFAULT_HEIGHT} + + async def _ack_frame(session_id: int) -> None: + if cdp_session is None: + return + try: + await cdp_session.send("Page.screencastFrameAck", {"sessionId": session_id}) + except Exception: + pass + + def _update_viewport_from_metadata(metadata: 
dict) -> None: + device_width = metadata.get("deviceWidth") + device_height = metadata.get("deviceHeight") + if isinstance(device_width, (int, float)) and device_width > 0: + viewport_info["width"] = int(device_width) + if isinstance(device_height, (int, float)) and device_height > 0: + viewport_info["height"] = int(device_height) + + async def _on_frame(params: dict) -> None: + data = params.get("data", "") + session_id = params.get("sessionId", 0) + metadata = params.get("metadata", {}) + if metadata: + _update_viewport_from_metadata(metadata) + asyncio.create_task(_ack_frame(session_id)) + if not data: + return + # Drop oldest frame if queue is full to keep latency low + if frame_queue.full(): + try: + frame_queue.get_nowait() + except asyncio.QueueEmpty: + pass + try: + frame_queue.put_nowait(data) + except asyncio.QueueFull: + pass + + async def _frame_forwarding_loop() -> None: + while True: + data = await frame_queue.get() + current_url = "" + try: + page = await browser_state.get_working_page() + if page is not None: + current_url = page.url + except Exception: + pass + try: + await websocket.send_json( + { + id_key: entity_id, + "status": "running", + "screenshot": data, + "format": "jpeg", + "viewport_width": viewport_info["width"], + "viewport_height": viewport_info["height"], + "url": current_url, + } + ) + except Exception: + break + + async def _completion_polling_loop() -> None: + while True: + await asyncio.sleep(2) + try: + if await check_finalized(): + return + except Exception: + LOG.warning( + "Error checking finalization status", + entity_id=entity_id, + entity_type=entity_type, + exc_info=True, + ) + + try: + page = await browser_state.get_working_page() + if page is None: + raise RuntimeError("No working page available for screencast") + + cdp_session = await page.context.new_cdp_session(page) + cdp_session.on("Page.screencastFrame", _on_frame) + await cdp_session.send( + "Page.startScreencast", + { + "format": "jpeg", + "quality": 60, + 
"maxWidth": DEFAULT_WIDTH, + "maxHeight": DEFAULT_HEIGHT, + }, + ) + LOG.info("CDP screencast started", entity_id=entity_id, entity_type=entity_type) + + forward_task = asyncio.create_task(_frame_forwarding_loop()) + poll_task = asyncio.create_task(_completion_polling_loop()) + + done, pending = await asyncio.wait( + [forward_task, poll_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + # Re-raise forwarding errors (e.g. WebSocket disconnect) + for task in done: + exc = task.exception() + if task is forward_task and exc is not None: + raise exc + + except Exception: + LOG.info( + "Screencast loop ended", + entity_id=entity_id, + entity_type=entity_type, + exc_info=True, + ) + finally: + if cdp_session is not None: + try: + await cdp_session.send("Page.stopScreencast", {}) + except Exception: + pass + try: + await cdp_session.detach() + except Exception: + pass + LOG.info("CDP screencast cleaned up", entity_id=entity_id, entity_type=entity_type) diff --git a/skyvern/forge/sdk/routes/streaming/screenshot.py b/skyvern/forge/sdk/routes/streaming/screenshot.py index 1d46ad632..5a06bd37d 100644 --- a/skyvern/forge/sdk/routes/streaming/screenshot.py +++ b/skyvern/forge/sdk/routes/streaming/screenshot.py @@ -4,6 +4,7 @@ Provides WS endpoints for streaming screenshots. Screenshot streaming is created on the basis of one of these database entities: - task (run) - workflow run + - browser session Screenshot streaming is used for a run that is invoked without a browser session. Otherwise, VNC streaming is used. @@ -11,6 +12,8 @@ Otherwise, VNC streaming is used. 
import asyncio import base64 +import time +from collections.abc import Awaitable, Callable from datetime import datetime import structlog @@ -18,14 +21,18 @@ from fastapi import HTTPException, WebSocket, WebSocketDisconnect from pydantic import ValidationError from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK +from skyvern.config import settings from skyvern.forge import app -from skyvern.forge.sdk.routes.routers import legacy_base_router +from skyvern.forge.sdk.routes.routers import base_router, legacy_base_router +from skyvern.forge.sdk.routes.streaming.screencast import start_screencast_loop, wait_for_browser_state +from skyvern.forge.sdk.schemas.persistent_browser_sessions import is_final_status from skyvern.forge.sdk.schemas.tasks import TaskStatus from skyvern.forge.sdk.services.org_auth_service import get_current_org from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus LOG = structlog.get_logger() STREAMING_TIMEOUT = 300 +WAIT_FOR_RUNNING_TIMEOUT = 120 @legacy_base_router.websocket("/stream/tasks/{task_id}") @@ -56,6 +63,11 @@ async def task_stream( return LOG.info("Started task streaming", task_id=task_id, organization_id=organization_id) + + if settings.BROWSER_STREAMING_MODE == "cdp": + await _local_screencast_for_task(websocket, task_id, organization_id) + return + # timestamp last time when streaming activity happens last_activity_timestamp = datetime.utcnow() @@ -173,6 +185,11 @@ async def workflow_run_streaming( workflow_run_id=workflow_run_id, organization_id=organization_id, ) + + if settings.BROWSER_STREAMING_MODE == "cdp": + await _local_screencast_for_workflow_run(websocket, workflow_run_id, organization_id) + return + # timestamp last time when streaming activity happens last_activity_timestamp = datetime.utcnow() @@ -229,7 +246,7 @@ async def workflow_run_streaming( ) return - if workflow_run.status == WorkflowRunStatus.running: + if workflow_run.status in (WorkflowRunStatus.running, 
WorkflowRunStatus.paused): file_name = f"{workflow_run_id}.png" screenshot = await app.STORAGE.get_streaming_file(organization_id, file_name) if screenshot: @@ -280,3 +297,242 @@ async def workflow_run_streaming( organization_id=organization_id, ) return + + +@base_router.websocket("/stream/browser_sessions/{browser_session_id}") +async def browser_session_streaming( + websocket: WebSocket, + browser_session_id: str, + apikey: str | None = None, + token: str | None = None, +) -> None: + try: + await websocket.accept() + if not token and not apikey: + await websocket.send_text("No valid credential provided") + await websocket.close() + return + except ConnectionClosedOK: + LOG.info("BrowserSession Streaming: ConnectionClosedOK error. Streaming won't start") + return + + try: + organization = await get_current_org(x_api_key=apikey, authorization=token) + organization_id = organization.organization_id + except Exception: + LOG.exception("Error while getting organization", browser_session_id=browser_session_id) + try: + await websocket.send_text("Invalid credential provided") + except ConnectionClosedOK: + LOG.info("BrowserSession Streaming: ConnectionClosedOK error while sending invalid credential message") + return + + LOG.info( + "BrowserSession Streaming: Started", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + + if settings.BROWSER_STREAMING_MODE == "cdp": + await _local_screencast_for_browser_session(websocket, browser_session_id, organization_id) + return + + await websocket.close(code=4001, reason="use-vnc-streaming") + return + + +async def _send_status(websocket: WebSocket, id_key: str, entity_id: str, status: str) -> None: + await websocket.send_json({id_key: entity_id, "status": status}) + + +async def _run_local_screencast( + websocket: WebSocket, + entity_id: str, + entity_type: str, + wait_for_running: Callable[[], Awaitable[str | None]], + check_finalized: Callable[[], Awaitable[bool]], + get_current_status: 
Callable[[], Awaitable[str | None]], + get_workflow_run_id: Callable[[], str | None] | None = None, +) -> None: + id_key = f"{entity_type}_id" + try: + early_exit_status = await wait_for_running() + if early_exit_status is not None: + await _send_status(websocket, id_key, entity_id, early_exit_status) + return + + workflow_run_id = get_workflow_run_id() if get_workflow_run_id else None + browser_state = await wait_for_browser_state( + entity_id, + entity_type, + workflow_run_id=workflow_run_id, + ) + if browser_state is None: + LOG.warning("Timed out waiting for browser state", **{id_key: entity_id}) + await _send_status(websocket, id_key, entity_id, "timeout") + return + + await start_screencast_loop( + websocket=websocket, + browser_state=browser_state, + entity_id=entity_id, + entity_type=entity_type, + check_finalized=check_finalized, + ) + + final_status = await get_current_status() + if final_status is not None: + try: + await _send_status(websocket, id_key, entity_id, final_status) + except Exception: + LOG.debug("Could not send final status (WebSocket likely closed)", **{id_key: entity_id}) + + except (WebSocketDisconnect, ConnectionClosedOK): + LOG.info("WebSocket closed during local screencast", **{id_key: entity_id}) + except ConnectionClosedError: + LOG.warning("WebSocket connection error during local screencast", **{id_key: entity_id}) + except Exception: + LOG.warning("Error in local screencast", **{id_key: entity_id}, exc_info=True) + + +async def _local_screencast_for_workflow_run( + websocket: WebSocket, + workflow_run_id: str, + organization_id: str, +) -> None: + async def wait_for_running() -> str | None: + deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT + while True: + workflow_run = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + if not workflow_run or workflow_run.organization_id != organization_id: + return "not_found" + if workflow_run.status.is_final(): + return 
workflow_run.status + if workflow_run.status in (WorkflowRunStatus.running, WorkflowRunStatus.paused): + return None + if time.monotonic() >= deadline: + LOG.warning( + "Timed out waiting for running status", + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return "timeout" + await asyncio.sleep(1) + + async def check_finalized() -> bool: + wr = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return wr is None or wr.status.is_final() + + async def get_current_status() -> str | None: + wr = await app.DATABASE.get_workflow_run( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + return wr.status if wr else None + + await _run_local_screencast( + websocket=websocket, + entity_id=workflow_run_id, + entity_type="workflow_run", + wait_for_running=wait_for_running, + check_finalized=check_finalized, + get_current_status=get_current_status, + ) + + +async def _local_screencast_for_task( + websocket: WebSocket, + task_id: str, + organization_id: str, +) -> None: + task_workflow_run_id: str | None = None + + async def wait_for_running() -> str | None: + nonlocal task_workflow_run_id + deadline = time.monotonic() + WAIT_FOR_RUNNING_TIMEOUT + while True: + task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) + if not task: + return "not_found" + if task.status.is_final(): + return task.status + if task.status == TaskStatus.running: + task_workflow_run_id = task.workflow_run_id + return None + if time.monotonic() >= deadline: + LOG.warning( + "Timed out waiting for running status", + task_id=task_id, + organization_id=organization_id, + ) + return "timeout" + await asyncio.sleep(1) + + async def check_finalized() -> bool: + task = await app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) + return task is None or task.status.is_final() + + async def get_current_status() -> str | None: + task = await 
app.DATABASE.get_task(task_id=task_id, organization_id=organization_id) + return task.status if task else None + + await _run_local_screencast( + websocket=websocket, + entity_id=task_id, + entity_type="task", + wait_for_running=wait_for_running, + check_finalized=check_finalized, + get_current_status=get_current_status, + get_workflow_run_id=lambda: task_workflow_run_id, + ) + + +async def _local_screencast_for_browser_session( + websocket: WebSocket, + browser_session_id: str, + organization_id: str, +) -> None: + async def wait_for_running() -> str | None: + session = await app.PERSISTENT_SESSIONS_MANAGER.get_session( + session_id=browser_session_id, + organization_id=organization_id, + ) + if not session: + LOG.warning( + "Browser session not found for organization", + browser_session_id=browser_session_id, + organization_id=organization_id, + ) + return "not_found" + if is_final_status(session.status): + return session.status + return None + + async def check_finalized() -> bool: + s = await app.PERSISTENT_SESSIONS_MANAGER.get_session( + session_id=browser_session_id, + organization_id=organization_id, + ) + return s is None or is_final_status(s.status) + + async def get_current_status() -> str | None: + s = await app.PERSISTENT_SESSIONS_MANAGER.get_session( + session_id=browser_session_id, + organization_id=organization_id, + ) + return s.status if s else None + + await _run_local_screencast( + websocket=websocket, + entity_id=browser_session_id, + entity_type="browser_session", + wait_for_running=wait_for_running, + check_finalized=check_finalized, + get_current_status=get_current_status, + ) diff --git a/skyvern/forge/sdk/routes/streaming/vnc.py b/skyvern/forge/sdk/routes/streaming/vnc.py index 512beb820..ab2e81bb6 100644 --- a/skyvern/forge/sdk/routes/streaming/vnc.py +++ b/skyvern/forge/sdk/routes/streaming/vnc.py @@ -14,7 +14,7 @@ from fastapi import WebSocket from websockets.exceptions import ConnectionClosedOK from skyvern.forge.sdk.routes.routers 
import base_router, legacy_base_router -from skyvern.forge.sdk.routes.streaming.auth import auth +from skyvern.forge.sdk.routes.streaming.auth import auth, require_client_id from skyvern.forge.sdk.routes.streaming.channels.vnc import ( Loops, VncChannel, @@ -70,14 +70,14 @@ async def stream( token: str | None = None, workflow_run_id: str | None = None, ) -> None: - if not client_id: - LOG.error( - "Client ID not provided for vnc stream.", - browser_session_id=browser_session_id, - task_id=task_id, - workflow_run_id=workflow_run_id, - ) + if not require_client_id( + client_id, + browser_session_id=browser_session_id, + task_id=task_id, + workflow_run_id=workflow_run_id, + ): return + assert client_id is not None LOG.debug( "Starting vnc stream.", diff --git a/skyvern/webeye/browser_manager.py b/skyvern/webeye/browser_manager.py index 2fca1f85f..03c420646 100644 --- a/skyvern/webeye/browser_manager.py +++ b/skyvern/webeye/browser_manager.py @@ -48,6 +48,8 @@ class BrowserManager(Protocol): browser_session_id: str | None = None, ) -> BrowserState: ... + def evict_page(self, page_id: str) -> None: ... + def get_for_task(self, task_id: str, workflow_run_id: str | None = None) -> BrowserState | None: ... 
def get_for_workflow_run( diff --git a/skyvern/webeye/default_persistent_sessions_manager.py b/skyvern/webeye/default_persistent_sessions_manager.py index 8cd187735..d4f661f83 100644 --- a/skyvern/webeye/default_persistent_sessions_manager.py +++ b/skyvern/webeye/default_persistent_sessions_manager.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from datetime import datetime, timedelta, timezone from math import floor @@ -7,7 +8,6 @@ from pathlib import Path import structlog from playwright._impl._errors import TargetClosedError -from playwright.async_api import async_playwright from skyvern.config import settings from skyvern.exceptions import BrowserSessionNotRenewable, MissingBrowserAddressError @@ -22,11 +22,10 @@ from skyvern.forge.sdk.schemas.persistent_browser_sessions import ( is_final_status, ) from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput -from skyvern.webeye.browser_factory import BrowserContextFactory from skyvern.webeye.browser_state import BrowserState -from skyvern.webeye.cdp_ports import _allocate_cdp_port, _release_cdp_port +from skyvern.webeye.cdp_ports import _release_cdp_port from skyvern.webeye.persistent_sessions_manager import PersistentSessionsManager -from skyvern.webeye.real_browser_state import RealBrowserState +from skyvern.webeye.real_browser_manager import RealBrowserManager LOG = structlog.get_logger() @@ -182,7 +181,8 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): """Default (OSS) implementation of PersistentSessionsManager protocol.""" instance: DefaultPersistentSessionsManager | None = None - _browser_sessions: dict[str, BrowserSession] = {} + _browser_sessions: dict[str, BrowserSession] = dict() + _background_tasks: set[asyncio.Task[None]] = set() database: AgentDB def __new__(cls, database: AgentDB) -> DefaultPersistentSessionsManager: @@ -198,71 +198,6 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): def 
watch_session_pool(self) -> None: """No-op in OSS: browsers run in-process, no external pool to monitor.""" - async def _launch_browser_for_session( - self, - session_id: str, - organization_id: str, - ) -> None: - """Launch a browser process and register it as a persistent session.""" - cdp_port = _allocate_cdp_port() - LOG.info("Launching browser for persistent session", session_id=session_id, cdp_port=cdp_port) - - pw = None - browser_state = None - try: - pw = await async_playwright().start() - browser_context, browser_artifacts, browser_cleanup = await BrowserContextFactory.create_browser_context( - pw, - organization_id=organization_id, - cdp_port=cdp_port, - ) - - browser_state = RealBrowserState( - pw=pw, - browser_context=browser_context, - page=None, - browser_artifacts=browser_artifacts, - browser_cleanup=browser_cleanup, - ) - await browser_state.get_or_create_page(organization_id=organization_id) - - self._browser_sessions[session_id] = BrowserSession( - browser_state=browser_state, - cdp_port=cdp_port, - ) - - browser_address = f"http://127.0.0.1:{cdp_port}" - await self.database.set_persistent_browser_session_browser_address( - browser_session_id=session_id, - browser_address=browser_address, - ip_address="127.0.0.1", - ecs_task_arn=None, - organization_id=organization_id, - ) - await self.database.update_persistent_browser_session( - session_id, - organization_id=organization_id, - status=PersistentBrowserSessionStatus.running, - ) - - LOG.info("Browser launched for persistent session", session_id=session_id, browser_address=browser_address) - except BaseException: - _release_cdp_port(cdp_port) - # Close whichever resource was successfully created. - # browser_state.close() stops playwright internally, so only fall - # back to pw.stop() when no browser_state was created. 
- if browser_state is not None: - try: - await browser_state.close() - except Exception: - LOG.warning("Failed to close browser_state during cleanup", exc_info=True) - elif pw is not None: - try: - await pw.stop() - except Exception: - LOG.warning("Failed to stop playwright during cleanup", exc_info=True) - raise - async def begin_session( self, *, @@ -356,26 +291,78 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): browser_profile_id=browser_profile_id, ) + # In CDP streaming mode, launch the browser immediately for standalone sessions + # so the screencast/CDP input endpoints can connect. + if settings.BROWSER_STREAMING_MODE == "cdp" and runnable_id is None: + session_id = session.persistent_browser_session_id + task = asyncio.create_task( + self._launch_browser_for_session(session_id, organization_id, proxy_location, url) + ) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + return session + + async def _launch_browser_for_session( + self, + session_id: str, + organization_id: str, + proxy_location: ProxyLocationInput | None = None, + url: str | None = None, + ) -> None: try: - await self._launch_browser_for_session( - session_id=session.persistent_browser_session_id, + browser_state = await RealBrowserManager._create_browser_state( + proxy_location=proxy_location, + url=url, organization_id=organization_id, ) - # Re-fetch to get updated browser_address/ip_address/started_at - updated = await self.database.get_persistent_browser_session( - session_id=session.persistent_browser_session_id, - organization_id=organization_id, - ) - if updated: - return updated - except Exception: - LOG.exception( - "Failed to launch browser for session, session will have no browser", - session_id=session.persistent_browser_session_id, + await browser_state.get_or_create_page( + url=url or "about:blank", + proxy_location=proxy_location, organization_id=organization_id, ) - return session + session = await 
self.get_session(session_id, organization_id) + if session is None or is_final_status(session.status) or session.completed_at is not None: + LOG.info( + "Session closed during browser launch, discarding browser", + browser_session_id=session_id, + ) + await browser_state.close() + return + + if session_id in self._browser_sessions: + LOG.info( + "Session already has browser state, discarding duplicate", + browser_session_id=session_id, + ) + await browser_state.close() + return + + self._browser_sessions[session_id] = BrowserSession(browser_state=browser_state) + + result = await self.update_status(session_id, organization_id, PersistentBrowserSessionStatus.running) + if result is None: + self._browser_sessions.pop(session_id, None) + await browser_state.close() + return + # Set started_at so renewal knows the browser is live + await self.database.update_persistent_browser_session( + session_id, + organization_id=organization_id, + started_at=datetime.now(timezone.utc), + ) + LOG.info( + "Browser launched for standalone session", + browser_session_id=session_id, + organization_id=organization_id, + ) + except Exception: + LOG.exception( + "Failed to launch browser for standalone session", + browser_session_id=session_id, + organization_id=organization_id, + ) async def occupy_browser_session( self, @@ -396,7 +383,21 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): try: return await renew_session(self.database, session_id, organization_id) except BrowserSessionNotRenewable: - await self.close_session(organization_id, session_id) + session = await self.get_session(session_id, organization_id) + # Don't close sessions that haven't started yet (browser still launching) + # unless they're stuck (older than 120s) + if session is not None and session.started_at is None and session.completed_at is None: + created_at_utc = ( + session.created_at.replace(tzinfo=timezone.utc) + if session.created_at.tzinfo is None + else session.created_at + ) + 
age_seconds = (datetime.now(timezone.utc) - created_at_utc).total_seconds() + if age_seconds < 120: + raise + # Close the session when it doesn't exist, has started, or is stuck — an already-completed session is left untouched + if session is None or session.completed_at is None: + await self.close_session(organization_id, session_id) raise async def update_status( @@ -488,7 +489,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): ) await self.database.close_persistent_browser_session(browser_session_id, organization_id) - if settings.ENV == "local": + if settings.BROWSER_STREAMING_MODE == "cdp": await self.database.archive_browser_session_address(browser_session_id, organization_id) async def close_all_sessions(self, organization_id: str) -> None: @@ -499,7 +500,7 @@ class DefaultPersistentSessionsManager(PersistentSessionsManager): async def cleanup_stale_sessions(self) -> None: """Close sessions left active by a previous process.""" - if settings.ENV != "local": + if settings.BROWSER_STREAMING_MODE != "cdp": return stale_sessions = await self.database.get_uncompleted_persistent_browser_sessions() for db_session in stale_sessions: diff --git a/skyvern/webeye/real_browser_manager.py b/skyvern/webeye/real_browser_manager.py index e15b9bfe3..224dc16c1 100644 --- a/skyvern/webeye/real_browser_manager.py +++ b/skyvern/webeye/real_browser_manager.py @@ -62,6 +62,9 @@ class RealBrowserManager(BrowserManager): browser_cleanup=browser_cleanup, ) + def evict_page(self, page_id: str) -> None: + self.pages.pop(page_id, None) + def get_for_task(self, task_id: str, workflow_run_id: str | None = None) -> BrowserState | None: if task_id in self.pages: return self.pages[task_id] @@ -389,7 +392,12 @@ class RealBrowserManager(BrowserManager): organization_id: str | None = None, ) -> BrowserState | None: LOG.info("Cleaning up for workflow run") - browser_state_to_close = self.pages.pop(workflow_run_id, None) + browser_state_to_close = self.pages.get(workflow_run_id) + + from 
skyvern.forge.sdk.routes.streaming.registries import set_deferred_close_params, stream_ref_active + + streams_active = stream_ref_active(workflow_run_id) + if browser_state_to_close: # If another workflow run still references this browser state (e.g. a # parent whose in-memory browser was shared via use_parent_browser_session), @@ -414,10 +422,21 @@ class RealBrowserManager(BrowserManager): await browser_state_to_close.browser_context.tracing.stop(path=trace_path) LOG.info("Stopped tracing", trace_path=trace_path) - await browser_state_to_close.close(close_browser_on_completion=effective_close) + if streams_active: + # Defer close until the last stream disconnects + LOG.info( + "Deferring browser close — active CDP streams", + workflow_run_id=workflow_run_id, + ) + set_deferred_close_params(workflow_run_id, close_browser_on_completion) + else: + await browser_state_to_close.close(close_browser_on_completion=effective_close) + + if not streams_active: + self.pages.pop(workflow_run_id, None) for task_id in task_ids: task_browser_state = self.pages.pop(task_id, None) - if task_browser_state is None: + if task_browser_state is None or streams_active: continue # Same shared-state check for task-level entries shared = any(bs is task_browser_state for bs in self.pages.values()) diff --git a/tests/unit_tests/_stub_streaming.py b/tests/unit_tests/_stub_streaming.py new file mode 100644 index 000000000..583df92cf --- /dev/null +++ b/tests/unit_tests/_stub_streaming.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import sys +from types import ModuleType +from typing import Sequence +from unittest.mock import MagicMock + +# Importing a streaming route module triggers skyvern.forge.sdk.routes.__init__, +# which eagerly imports sibling route modules with heavy side effects (e.g. AWS clients). +# Stub those modules so streaming-only tests stay fast and deterministic. 
+_BASE_STUB_MODULES: list[str] = [ + "skyvern.forge.sdk.api.aws", + "skyvern.forge.sdk.routes.agent_protocol", + "skyvern.forge.sdk.routes.browser_profiles", + "skyvern.forge.sdk.routes.browser_sessions", + "skyvern.forge.sdk.routes.credentials", + "skyvern.forge.sdk.routes.debug_sessions", + "skyvern.forge.sdk.routes.prompts", + "skyvern.forge.sdk.routes.pylon", + "skyvern.forge.sdk.routes.run_blocks", + "skyvern.forge.sdk.routes.scripts", + "skyvern.forge.sdk.routes.sdk", + "skyvern.forge.sdk.routes.webhooks", + "skyvern.forge.sdk.routes.workflow_copilot", + "skyvern.forge.sdk.routes.streaming.cdp_input", + "skyvern.forge.sdk.routes.streaming.messages", + "skyvern.forge.sdk.routes.streaming.notifications", + "skyvern.forge.sdk.routes.streaming.vnc", +] + + +def import_with_stubs(module_path: str, extra_stubs: Sequence[str] = ()) -> ModuleType: + """Import a streaming module after temporarily stubbing heavy dependencies.""" + all_stubs = list(_BASE_STUB_MODULES) + list(extra_stubs) + installed: dict[str, MagicMock] = {} + for mod in all_stubs: + if mod not in sys.modules: + installed[mod] = sys.modules[mod] = MagicMock() + + try: + __import__(module_path) + return sys.modules[module_path] + finally: + for mod in installed: + del sys.modules[mod] diff --git a/tests/unit_tests/test_streaming_screencast.py b/tests/unit_tests/test_streaming_screencast.py new file mode 100644 index 000000000..3856ee7a4 --- /dev/null +++ b/tests/unit_tests/test_streaming_screencast.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import pytest + +from tests.unit_tests._stub_streaming import import_with_stubs + +screencast = import_with_stubs( + "skyvern.forge.sdk.routes.streaming.screencast", + extra_stubs=["skyvern.forge.sdk.routes.streaming.screenshot"], +) + + +def _make_app(browser_manager=None, persistent_sessions_manager=None): + """Build a fake app namespace to replace screencast.app (an 
AppHolder proxy).""" + return SimpleNamespace( + BROWSER_MANAGER=browser_manager or SimpleNamespace(), + PERSISTENT_SESSIONS_MANAGER=persistent_sessions_manager or SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_resolve_browser_state_for_workflow_run(monkeypatch: pytest.MonkeyPatch) -> None: + expected_state = object() + fake_app = _make_app( + browser_manager=SimpleNamespace(get_for_workflow_run=Mock(return_value=expected_state), get_for_task=Mock()), + persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()), + ) + monkeypatch.setattr(screencast, "app", fake_app) + + result = await screencast._resolve_browser_state("wr_123", "workflow_run") + + assert result is expected_state + fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_called_once_with("wr_123") + fake_app.BROWSER_MANAGER.get_for_task.assert_not_called() + fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_resolve_browser_state_for_task(monkeypatch: pytest.MonkeyPatch) -> None: + expected_state = object() + fake_app = _make_app( + browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock(return_value=expected_state)), + persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()), + ) + monkeypatch.setattr(screencast, "app", fake_app) + + result = await screencast._resolve_browser_state("task_123", "task", workflow_run_id="wr_123") + + assert result is expected_state + fake_app.BROWSER_MANAGER.get_for_task.assert_called_once_with("task_123", "wr_123") + fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called() + fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_resolve_browser_state_for_browser_session(monkeypatch: pytest.MonkeyPatch) -> None: + expected_state = object() + fake_app = _make_app( + browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock()), + 
persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock(return_value=expected_state)), + ) + monkeypatch.setattr(screencast, "app", fake_app) + + result = await screencast._resolve_browser_state("bs_123", "browser_session") + + assert result is expected_state + fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_awaited_once_with("bs_123") + fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called() + fake_app.BROWSER_MANAGER.get_for_task.assert_not_called() + + +@pytest.mark.asyncio +async def test_resolve_browser_state_unknown_entity_type(monkeypatch: pytest.MonkeyPatch) -> None: + fake_app = _make_app( + browser_manager=SimpleNamespace(get_for_workflow_run=Mock(), get_for_task=Mock()), + persistent_sessions_manager=SimpleNamespace(get_browser_state=AsyncMock()), + ) + monkeypatch.setattr(screencast, "app", fake_app) + + result = await screencast._resolve_browser_state("id_123", "unknown") + + assert result is None + fake_app.BROWSER_MANAGER.get_for_workflow_run.assert_not_called() + fake_app.BROWSER_MANAGER.get_for_task.assert_not_called() + fake_app.PERSISTENT_SESSIONS_MANAGER.get_browser_state.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_wait_for_browser_state_returns_when_working_page_is_available( + monkeypatch: pytest.MonkeyPatch, +) -> None: + browser_state = SimpleNamespace(get_working_page=AsyncMock(return_value=object())) + resolve_mock = AsyncMock(return_value=browser_state) + sleep_mock = AsyncMock() + monkeypatch.setattr(screencast, "_resolve_browser_state", resolve_mock) + monkeypatch.setattr(screencast.asyncio, "sleep", sleep_mock) + + result = await screencast.wait_for_browser_state("wr_123", "workflow_run", timeout=1, poll_interval=0.1) + + assert result is browser_state + resolve_mock.assert_awaited_once_with("wr_123", "workflow_run", None) + browser_state.get_working_page.assert_awaited_once() + sleep_mock.assert_not_awaited() + + +@pytest.mark.asyncio +async def 
test_wait_for_browser_state_returns_none_on_timeout(monkeypatch: pytest.MonkeyPatch) -> None: + browser_state = SimpleNamespace(get_working_page=AsyncMock(return_value=None)) + resolve_mock = AsyncMock(return_value=browser_state) + sleep_mock = AsyncMock() + monkeypatch.setattr(screencast, "_resolve_browser_state", resolve_mock) + monkeypatch.setattr(screencast.asyncio, "sleep", sleep_mock) + + result = await screencast.wait_for_browser_state( + "bs_123", + "browser_session", + timeout=0.3, + poll_interval=0.1, + ) + + assert result is None + assert resolve_mock.await_count == 3 + assert browser_state.get_working_page.await_count == 3 + assert sleep_mock.await_count == 3 diff --git a/tests/unit_tests/test_streaming_screenshot_local.py b/tests/unit_tests/test_streaming_screenshot_local.py new file mode 100644 index 000000000..a646fb788 --- /dev/null +++ b/tests/unit_tests/test_streaming_screenshot_local.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + +from tests.unit_tests._stub_streaming import import_with_stubs + +screenshot = import_with_stubs("skyvern.forge.sdk.routes.streaming.screenshot") + + +@pytest.mark.asyncio +async def test_run_local_screencast_happy_path(monkeypatch: pytest.MonkeyPatch) -> None: + websocket = object() + browser_state = object() + wait_for_running = AsyncMock(return_value=None) + check_finalized = AsyncMock(return_value=False) + get_current_status = AsyncMock(return_value="completed") + wait_for_browser_state_mock = AsyncMock(return_value=browser_state) + start_screencast_loop_mock = AsyncMock() + send_status_mock = AsyncMock() + monkeypatch.setattr(screenshot, "wait_for_browser_state", wait_for_browser_state_mock) + monkeypatch.setattr(screenshot, "start_screencast_loop", start_screencast_loop_mock) + monkeypatch.setattr(screenshot, "_send_status", send_status_mock) + + await screenshot._run_local_screencast( + websocket=websocket, + entity_id="task_123", + 
entity_type="task", + wait_for_running=wait_for_running, + check_finalized=check_finalized, + get_current_status=get_current_status, + get_workflow_run_id=lambda: "wr_123", + ) + + wait_for_running.assert_awaited_once() + wait_for_browser_state_mock.assert_awaited_once_with("task_123", "task", workflow_run_id="wr_123") + start_screencast_loop_mock.assert_awaited_once_with( + websocket=websocket, + browser_state=browser_state, + entity_id="task_123", + entity_type="task", + check_finalized=check_finalized, + ) + get_current_status.assert_awaited_once() + send_status_mock.assert_awaited_once_with(websocket, "task_id", "task_123", "completed") + + +@pytest.mark.asyncio +async def test_run_local_screencast_timeout_when_browser_state_not_available( + monkeypatch: pytest.MonkeyPatch, +) -> None: + websocket = object() + wait_for_running = AsyncMock(return_value=None) + check_finalized = AsyncMock(return_value=False) + get_current_status = AsyncMock(return_value="completed") + wait_for_browser_state_mock = AsyncMock(return_value=None) + start_screencast_loop_mock = AsyncMock() + send_status_mock = AsyncMock() + monkeypatch.setattr(screenshot, "wait_for_browser_state", wait_for_browser_state_mock) + monkeypatch.setattr(screenshot, "start_screencast_loop", start_screencast_loop_mock) + monkeypatch.setattr(screenshot, "_send_status", send_status_mock) + + await screenshot._run_local_screencast( + websocket=websocket, + entity_id="bs_123", + entity_type="browser_session", + wait_for_running=wait_for_running, + check_finalized=check_finalized, + get_current_status=get_current_status, + ) + + wait_for_running.assert_awaited_once() + wait_for_browser_state_mock.assert_awaited_once_with( + "bs_123", + "browser_session", + workflow_run_id=None, + ) + start_screencast_loop_mock.assert_not_awaited() + get_current_status.assert_not_awaited() + send_status_mock.assert_awaited_once_with(websocket, "browser_session_id", "bs_123", "timeout")