Skyvern/skyvern/forge/sdk/routes/streaming/channels/message.py
Andrew Neilson 328bce3cdd
feat: add browser session support to OSS Docker deployment (#4891)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
2026-03-03 23:58:26 -08:00

613 lines
19 KiB
Python

"""
A channel for streaming whole messages between our frontend and our API server.
This channel can access a persistent browser instance through the execution channel.
What this channel looks like:
[Skyvern App] <--> [API Server]
Channel data:
JSON over WebSockets. Semantics are fire and forget. Req-resp is built on
top of that using message types.
"""
import asyncio
import dataclasses
import enum
import typing as t
import structlog
from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState
from websockets.exceptions import ConnectionClosedError
from skyvern.forge.sdk.routes.streaming.channels.execution import execution_channel
from skyvern.forge.sdk.routes.streaming.channels.exfiltration import ExfiltratedEvent, ExfiltrationChannel
from skyvern.forge.sdk.routes.streaming.registries import (
add_message_channel,
del_message_channel,
get_vnc_channel,
)
from skyvern.forge.sdk.routes.streaming.verify import (
loop_verify_browser_session,
loop_verify_workflow_run,
verify_browser_session,
verify_workflow_run,
)
from skyvern.forge.sdk.schemas.persistent_browser_sessions import AddressablePersistentBrowserSession
from skyvern.forge.sdk.utils.aio import collect
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
LOG = structlog.get_logger()
Loops = list[asyncio.Task] # aka "queue-less actors"; or "programs"
class MessageKind(enum.StrEnum):
ASK_FOR_CLIPBOARD_RESPONSE = "ask-for-clipboard-response"
BEGIN_EXFILTRATION = "begin-exfiltration"
BROWSER_TABS = "browser-tabs"
CEDE_CONTROL = "cede-control"
END_EXFILTRATION = "end-exfiltration"
EXFILTRATED_EVENT = "exfiltrated-event"
TAKE_CONTROL = "take-control"
class ExfiltratedEventSource(enum.StrEnum):
CONSOLE = "console"
CDP = "cdp"
NOT_SPECIFIED = "[not-specified]"
@dataclasses.dataclass
class TabInfo:
id: str
title: str
url: str
# --
active: bool = False
favicon: str | None = None
isReady: bool = True
pageNumber: int | None = None
MessageKinds = t.Literal[
MessageKind.ASK_FOR_CLIPBOARD_RESPONSE,
MessageKind.BEGIN_EXFILTRATION,
MessageKind.BROWSER_TABS,
MessageKind.CEDE_CONTROL,
MessageKind.END_EXFILTRATION,
MessageKind.EXFILTRATED_EVENT,
MessageKind.TAKE_CONTROL,
]
@dataclasses.dataclass
class Message:
kind: MessageKinds
@dataclasses.dataclass
class MessageInBeginExfiltration(Message):
kind: t.Literal[MessageKind.BEGIN_EXFILTRATION] = MessageKind.BEGIN_EXFILTRATION
@dataclasses.dataclass
class MessageInEndExfiltration(Message):
kind: t.Literal[MessageKind.END_EXFILTRATION] = MessageKind.END_EXFILTRATION
@dataclasses.dataclass
class MessageInTakeControl(Message):
kind: t.Literal[MessageKind.TAKE_CONTROL] = MessageKind.TAKE_CONTROL
@dataclasses.dataclass
class MessageInCedeControl(Message):
kind: t.Literal[MessageKind.CEDE_CONTROL] = MessageKind.CEDE_CONTROL
@dataclasses.dataclass
class MessageInAskForClipboardResponse(Message):
kind: t.Literal[MessageKind.ASK_FOR_CLIPBOARD_RESPONSE] = MessageKind.ASK_FOR_CLIPBOARD_RESPONSE
text: str = ""
@dataclasses.dataclass
class MessageOutExfiltratedEvent(Message):
kind: t.Literal[MessageKind.EXFILTRATED_EVENT] = MessageKind.EXFILTRATED_EVENT
event_name: str = "[not-specified]"
# TODO(jdo): improve typing for params
params: dict = dataclasses.field(default_factory=dict)
source: ExfiltratedEventSource = ExfiltratedEventSource.NOT_SPECIFIED
timestamp: float = dataclasses.field(default_factory=lambda: 0.0) # seconds since epoch
@dataclasses.dataclass
class MessageOutTabInfo(Message):
kind: t.Literal[MessageKind.BROWSER_TABS] = MessageKind.BROWSER_TABS
tabs: list[TabInfo] = dataclasses.field(default_factory=list)
MessageIn = (
MessageInAskForClipboardResponse
| MessageInBeginExfiltration
| MessageInCedeControl
| MessageInEndExfiltration
| MessageInTakeControl
)
MessageOut = MessageOutExfiltratedEvent | MessageOutTabInfo
ChannelMessage = MessageIn | MessageOut
def reify_channel_message(data: dict) -> ChannelMessage:
kind = data.get("kind", None)
match kind:
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
text = data.get("text") or ""
return MessageInAskForClipboardResponse(text=text)
case MessageKind.BEGIN_EXFILTRATION:
return MessageInBeginExfiltration()
case MessageKind.CEDE_CONTROL:
return MessageInCedeControl()
case MessageKind.END_EXFILTRATION:
return MessageInEndExfiltration()
case MessageKind.TAKE_CONTROL:
return MessageInTakeControl()
case _:
raise ValueError(f"Unknown message kind: '{kind}'")
def message_to_dict(message: MessageOut) -> dict:
"""
Convert message to dict with enums as their values.
"""
def convert_value(obj: t.Any) -> t.Any:
if isinstance(obj, enum.Enum):
return obj.value
return obj
return dataclasses.asdict(message, dict_factory=lambda x: {k: convert_value(v) for k, v in x})
@dataclasses.dataclass
class MessageChannel:
"""
A message channel for streaming JSON messages between our frontend and our API server.
"""
client_id: str
organization_id: str
websocket: WebSocket
# --
out_queue: asyncio.Queue[MessageOut] = dataclasses.field(default_factory=asyncio.Queue) # warn: unbounded
browser_session: AddressablePersistentBrowserSession | None = None
workflow_run: WorkflowRun | None = None
def __post_init__(self) -> None:
add_message_channel(self)
@property
def class_name(self) -> str:
return self.__class__.__name__
@property
def identity(self) -> dict[str, str]:
base = {"organization_id": self.organization_id}
if self.browser_session:
return base | {"browser_session_id": self.browser_session.persistent_browser_session_id}
if self.workflow_run:
return base | {"workflow_run_id": self.workflow_run.workflow_run_id}
return base
async def close(self, code: int = 1000, reason: str | None = None) -> "MessageChannel":
LOG.info(f"{self.class_name} closing message stream.", reason=reason, code=code, **self.identity)
self.browser_session = None
self.workflow_run = None
try:
await self.websocket.close(code=code, reason=reason)
except Exception:
pass
del_message_channel(self.client_id, expected=self)
return self
@property
def is_open(self) -> bool:
if self.websocket.client_state != WebSocketState.CONNECTED:
return False
return True
async def drain(self) -> list[dict | MessageOut]:
datums: list[dict | MessageOut] = []
result = await asyncio.gather(
self.receive_from_out_queue(),
self.receive_from_user(),
)
# NOTE(jdo): mypy seems to be unable to infer this, whereas pylance has
# no issue; added explicit type hints here to help mypy out.
out_queue: list[MessageOut] = result[0]
in_queue: list[dict] = result[1]
for out_message in out_queue:
datums.append(out_message)
for in_message in in_queue:
if isinstance(in_message, dict):
datums.append(in_message)
else:
LOG.error(
f"{self.class_name} drain dropping user message: unexpected result type: {type(in_message)}",
message=in_message,
**self.identity,
)
if datums:
LOG.debug(f"{self.class_name} Drained {len(datums)} messages from message channel.", **self.identity)
return datums
async def receive_from_user(self) -> list[dict]:
datums: list[dict] = []
while True:
try:
data = await asyncio.wait_for(self.websocket.receive_json(), timeout=0.001)
datums.append(data)
except asyncio.TimeoutError:
break
except RuntimeError as ex:
if "not connected" in str(ex).lower():
break
except WebSocketDisconnect:
LOG.warning(f"{self.class_name} Disconnected while receiving message from channel", **self.identity)
break
except Exception:
LOG.exception(f"{self.class_name} Failed to receive message from message channel", **self.identity)
break
return datums
async def receive_from_out_queue(self) -> list[MessageOut]:
datums: list[MessageOut] = []
while True:
try:
data = await asyncio.wait_for(self.out_queue.get(), timeout=0.001)
datums.append(data)
except asyncio.TimeoutError:
break
except asyncio.QueueEmpty:
break
return datums
def receive_from_out_queue_nowait(self) -> list[MessageOut]:
datums: list[MessageOut] = []
while True:
try:
data = self.out_queue.get_nowait()
datums.append(data)
except asyncio.QueueEmpty:
break
return datums
# async def send(self, *, messages: list[dict]) -> t.Self:
async def send(self, *, messages: list[MessageOut]) -> t.Self:
for message in messages:
await self.out_queue.put(message)
return self
def send_nowait(self, *, messages: list[MessageOut]) -> t.Self:
for message in messages:
self.out_queue.put_nowait(message)
return self
async def ask_for_clipboard(self) -> None:
LOG.info(f"{self.class_name} Sending ask-for-clipboard to message channel", **self.identity)
try:
await self.websocket.send_json(
{
"kind": "ask-for-clipboard",
}
)
except Exception:
LOG.exception(f"{self.class_name} Failed to send ask-for-clipboard to message channel", **self.identity)
async def send_copied_text(self, copied_text: str) -> None:
LOG.info(f"{self.class_name} Sending copied text to message channel", **self.identity)
try:
await self.websocket.send_json(
{
"kind": "copied-text",
"text": copied_text,
}
)
except Exception:
LOG.exception(f"{self.class_name} Failed to send copied text to message channel", **self.identity)
async def loop_stream_messages(message_channel: MessageChannel) -> None:
"""
Stream messages and their results back and forth.
Loops until the websocket is closed.
"""
class_name = message_channel.class_name
exfiltration_channel: ExfiltrationChannel | None = None
async def send(message: MessageOut) -> None:
if message_channel.websocket.client_state != WebSocketState.CONNECTED:
return
data = message_to_dict(message)
try:
await message_channel.websocket.send_json(data)
except WebSocketDisconnect:
pass
except Exception:
LOG.exception("MessageChannel: failed to send data.")
async def handle_data(data: dict | MessageOut) -> None:
nonlocal class_name
nonlocal exfiltration_channel
message: ChannelMessage
if isinstance(data, MessageOut):
message = data
elif isinstance(data, dict):
try:
message = reify_channel_message(data)
except ValueError:
LOG.error(f"MessageChannel: cannot reify channel message from data: {data}", **message_channel.identity)
return
else:
LOG.error(
f"{class_name} cannot handle data: expected dict or MessageOut, got {type(data)}",
**message_channel.identity,
)
return
match message.kind:
case MessageKind.ASK_FOR_CLIPBOARD_RESPONSE:
vnc_channel = get_vnc_channel(message_channel.client_id)
if not vnc_channel:
LOG.error(
f"{class_name} no vnc channel found for message channel.",
message=message,
**message_channel.identity,
)
return
text = message.text
async with execution_channel(vnc_channel) as execute:
await execute.paste_text(text)
case MessageKind.BEGIN_EXFILTRATION:
if exfiltration_channel is not None:
LOG.error(
"MessageChannel: cannot begin exfiltration: already active.", message_channel=message_channel
)
return
vnc_channel = get_vnc_channel(message_channel.client_id)
if not vnc_channel:
LOG.error(
f"{class_name} no vnc channel client found for message channel - cannot exfiltrate.",
message=message,
**message_channel.identity,
)
return
def on_event(events: list[ExfiltratedEvent]) -> None:
for event in events:
message_out_exfiltrated_event = MessageOutExfiltratedEvent(
kind=t.cast(t.Literal[MessageKind.EXFILTRATED_EVENT], event.kind),
event_name=event.event_name,
params=event.params,
source=t.cast(ExfiltratedEventSource, event.source or ExfiltratedEventSource.NOT_SPECIFIED),
timestamp=event.timestamp,
)
message_channel.send_nowait(messages=[message_out_exfiltrated_event])
exfiltration_channel = await ExfiltrationChannel(
on_event=on_event,
vnc_channel=vnc_channel,
).start()
case MessageKind.BROWSER_TABS:
await send(message)
case MessageKind.CEDE_CONTROL:
vnc_channel = get_vnc_channel(message_channel.client_id)
if not vnc_channel:
LOG.error(
f"{class_name} no vnc channel client found for message channel.",
message=message,
**message_channel.identity,
)
return
vnc_channel.interactor = "agent"
case MessageKind.END_EXFILTRATION:
if exfiltration_channel is None:
return
await exfiltration_channel.stop()
exfiltration_channel = None
case MessageKind.EXFILTRATED_EVENT:
await send(message)
# case MessageKind.GET_TAB_INFO:
# """
# TODO(jdo): implement - this is an on-demand request for tab info, which is
# required when connecting to an existing browser session.
# """
case MessageKind.TAKE_CONTROL:
LOG.info(f"{class_name} processing take-control message.", **message_channel.identity)
vnc_channel = get_vnc_channel(message_channel.client_id)
if not vnc_channel:
LOG.error(
f"{class_name} no vnc channel client found for message channel.",
message=message,
**message_channel.identity,
)
return
vnc_channel.interactor = "user"
case _:
t.assert_never(message.kind)
async def frontend_to_backend() -> None:
nonlocal class_name
LOG.info(f"{class_name} starting frontend-to-backend loop.", **message_channel.identity)
while message_channel.is_open:
try:
datums = await message_channel.drain()
for data in datums:
if not isinstance(data, (dict, MessageOut)):
LOG.error(
f"{class_name} cannot handle message: expected dict or MessageOut, got {type(data)}",
**message_channel.identity,
)
continue
await handle_data(data)
except WebSocketDisconnect:
LOG.info(f"{class_name} frontend disconnected.", **message_channel.identity)
raise
except ConnectionClosedError:
LOG.info(f"{class_name} frontend closed channel.", **message_channel.identity)
raise
except Exception:
LOG.exception(f"{class_name} An unexpected exception occurred.", **message_channel.identity)
raise
loops = [
asyncio.create_task(frontend_to_backend()),
]
try:
await collect(loops)
except Exception:
LOG.exception(f"{class_name} An exception occurred in loop message channel stream.", **message_channel.identity)
finally:
LOG.info(f"{class_name} Closing the message channel stream.", **message_channel.identity)
await message_channel.close(reason="loop-channel-closed")
async def get_message_channel_for_browser_session(
client_id: str,
browser_session_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
"""
Return a message channel for a browser session, with a list of loops to run concurrently.
"""
browser_session = await verify_browser_session(
browser_session_id=browser_session_id,
organization_id=organization_id,
)
if not browser_session:
return None
message_channel = MessageChannel(
client_id=client_id,
organization_id=organization_id,
browser_session=browser_session,
websocket=websocket,
)
loops = [
asyncio.create_task(loop_verify_browser_session(message_channel)),
asyncio.create_task(loop_stream_messages(message_channel)),
]
return message_channel, loops
async def get_message_channel_for_workflow_run(
client_id: str,
workflow_run_id: str,
organization_id: str,
websocket: WebSocket,
) -> tuple[MessageChannel, Loops] | None:
"""
Return a message channel for a workflow run, with a list of loops to run concurrently.
"""
LOG.info("Getting message channel for workflow run.", workflow_run_id=workflow_run_id)
workflow_run, browser_session = await verify_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
LOG.info(
"Message channel: no initial workflow run found.",
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
return None
if not browser_session:
return None
message_channel = MessageChannel(
client_id,
organization_id,
browser_session=browser_session,
websocket=websocket,
workflow_run=workflow_run,
)
LOG.info("Got message channel for workflow run.", message_channel=message_channel)
loops = [
asyncio.create_task(loop_verify_workflow_run(message_channel)),
asyncio.create_task(loop_stream_messages(message_channel)),
]
return message_channel, loops