Add persist_browser_session flag to workflows (#777)

This commit is contained in:
Kerem Yilmaz 2024-09-06 12:01:56 -07:00 committed by GitHub
parent be1c8ba060
commit 95b2e53c46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 139 additions and 1 deletions

View file

@ -0,0 +1,33 @@
"""Add persist_browser_session flag to workflows
Revision ID: c50f0aa0ef24
Revises: 0de9150bc624
Create Date: 2024-09-06 18:42:42.677573+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c50f0aa0ef24"
down_revision: Union[str, None] = "0de9150bc624"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the non-nullable boolean column ``workflows.persist_browser_session``.

    The column is created as nullable, rows without a value are set to
    False, and only then is the column altered to NOT NULL.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    persist_flag_column = sa.Column("persist_browser_session", sa.Boolean(), nullable=True)
    op.add_column("workflows", persist_flag_column)
    # Backfill before tightening the constraint so the ALTER below does not
    # run against rows that still hold NULL.
    op.execute("UPDATE workflows SET persist_browser_session = False WHERE persist_browser_session IS NULL")
    op.alter_column("workflows", "persist_browser_session", nullable=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert this revision by dropping ``workflows.persist_browser_session``."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column(table_name="workflows", column_name="persist_browser_session")
    # ### end Alembic commands ###

View file

@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from skyvern import constants
from skyvern.constants import SKYVERN_DIR from skyvern.constants import SKYVERN_DIR
@ -47,6 +48,7 @@ class Settings(BaseSettings):
GENERATE_PRESIGNED_URLS: bool = False GENERATE_PRESIGNED_URLS: bool = False
AWS_S3_BUCKET_ARTIFACTS: str = "skyvern-artifacts" AWS_S3_BUCKET_ARTIFACTS: str = "skyvern-artifacts"
AWS_S3_BUCKET_SCREENSHOTS: str = "skyvern-screenshots" AWS_S3_BUCKET_SCREENSHOTS: str = "skyvern-screenshots"
AWS_S3_BUCKET_BROWSER_SESSIONS: str = "skyvern-browser-sessions"
# Supported storage types: local, s3 # Supported storage types: local, s3
SKYVERN_STORAGE_TYPE: str = "local" SKYVERN_STORAGE_TYPE: str = "local"
@ -71,6 +73,9 @@ class Settings(BaseSettings):
# streaming settings # streaming settings
STREAMING_FILE_BASE_PATH: str = "/tmp" STREAMING_FILE_BASE_PATH: str = "/tmp"
# Saved browser session settings
BROWSER_SESSION_BASE_PATH: str = f"{constants.REPO_ROOT_DIR}/browser_sessions"
##################### #####################
# Bitwarden Configs # # Bitwarden Configs #
##################### #####################

View file

@ -75,6 +75,11 @@ def zip_files(files_path: str, zip_file_path: str) -> str:
return zip_file_path return zip_file_path
def unzip_files(zip_file_path: str, output_dir: str) -> None:
    """Extract every member of the zip archive at ``zip_file_path`` into ``output_dir``.

    Args:
        zip_file_path: Path of the zip archive to read.
        output_dir: Directory that receives the extracted members.
    """
    with zipfile.ZipFile(zip_file_path, mode="r") as archive:
        archive.extractall(path=output_dir)
def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path: def get_path_for_workflow_download_directory(workflow_run_id: str) -> Path:
return Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/") return Path(f"{REPO_ROOT_DIR}/downloads/{workflow_run_id}/")

View file

@ -59,3 +59,11 @@ class BaseStorage(ABC):
@abstractmethod @abstractmethod
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None: async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
pass pass
@abstractmethod
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None:
    """Persist the browser session files found in ``directory``.

    Implementations store the session under the given organization and
    workflow permanent id.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.
        directory: Local directory holding the browser session files.
    """
    pass
@abstractmethod
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
    """Look up a stored browser session for a workflow.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.

    Returns:
        A string for the stored session, or None when there is none.
        NOTE(review): presumably a local directory path holding the session
        files - confirm against the concrete storage backends.
    """
    pass

View file

@ -1,3 +1,5 @@
import os
import shutil
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
@ -85,6 +87,38 @@ class LocalStorage(BaseStorage):
) )
return None return None
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None:
    """Copy a browser session directory into local browser-session storage.

    Every file under ``directory`` is copied, keeping its relative layout, to
    ``BROWSER_SESSION_BASE_PATH/<organization_id>/<workflow_permanent_id>``.
    Nothing is copied when ``directory`` already is that storage folder.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.
        directory: Local directory holding the browser session files.
    """
    stored_folder_path = (
        Path(SettingsManager.get_settings().BROWSER_SESSION_BASE_PATH) / organization_id / workflow_permanent_id
    )
    # Compare resolved paths rather than raw strings: a trailing slash, a
    # relative path or a symlink would otherwise slip past the guard and make
    # shutil.copy2 copy each file onto itself (SameFileError).
    if Path(directory).resolve() == stored_folder_path.resolve():
        return
    # NOTE(review): _create_directories_if_not_exists is not visible here;
    # presumably it creates the parent directories of the given path - confirm.
    self._create_directories_if_not_exists(stored_folder_path)

    LOG.info(
        "Storing browser session locally",
        organization_id=organization_id,
        workflow_permanent_id=workflow_permanent_id,
        directory=directory,
        browser_session_path=stored_folder_path,
    )

    # Copy all files from the directory to the stored folder
    for root, _, files in os.walk(directory):
        for file in files:
            source_file_path = Path(root) / file
            relative_path = source_file_path.relative_to(directory)
            target_file_path = stored_folder_path / relative_path
            self._create_directories_if_not_exists(target_file_path)
            shutil.copy2(source_file_path, target_file_path)
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
    """Return the local folder of a stored browser session, if one exists.

    The folder looked up is
    ``BROWSER_SESSION_BASE_PATH/<organization_id>/<workflow_permanent_id>``.

    Returns:
        The folder path as a string, or None when the folder does not exist.
    """
    base_path = Path(SettingsManager.get_settings().BROWSER_SESSION_BASE_PATH)
    session_folder = base_path / organization_id / workflow_permanent_id
    return str(session_folder) if session_folder.exists() else None
@staticmethod @staticmethod
def _parse_uri_to_path(uri: str) -> str: def _parse_uri_to_path(uri: str) -> str:
parsed_uri = urlparse(uri) parsed_uri = urlparse(uri)

View file

@ -1,7 +1,10 @@
import shutil
import tempfile
from datetime import datetime from datetime import datetime
from skyvern.config import settings from skyvern.config import settings
from skyvern.forge.sdk.api.aws import AsyncAWSClient from skyvern.forge.sdk.api.aws import AsyncAWSClient
from skyvern.forge.sdk.api.files import unzip_files
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage from skyvern.forge.sdk.artifact.storage.base import FILE_EXTENTSION_MAP, BaseStorage
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step
@ -40,3 +43,24 @@ class S3Storage(BaseStorage):
async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None: async def get_streaming_file(self, organization_id: str, file_name: str, use_default: bool = True) -> bytes | None:
path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}" path = f"s3://{settings.AWS_S3_BUCKET_SCREENSHOTS}/{settings.ENV}/{organization_id}/{file_name}"
return await self.async_client.download_file(path, log_exception=False) return await self.async_client.download_file(path, log_exception=False)
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str, directory: str) -> None:
    """Zip ``directory`` and upload the archive to the browser-sessions bucket.

    The archive is uploaded to
    ``s3://<AWS_S3_BUCKET_BROWSER_SESSIONS>/<ENV>/<organization_id>/<workflow_permanent_id>.zip``.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.
        directory: Local directory holding the browser session files.
    """
    browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
    # Build the archive inside a scratch directory so it is always removed,
    # even if the upload raises. Previously the generated ``<tmp>.zip`` was
    # left on disk and the NamedTemporaryFile handle was never closed.
    # NOTE(review): assumes upload_file_from_path has finished reading the
    # file by the time the await returns - confirm.
    with tempfile.TemporaryDirectory(prefix="skyvern_browser_session_zip_") as scratch_dir:
        # make_archive appends ".zip" to the base name and returns the full path.
        zip_file_path = shutil.make_archive(f"{scratch_dir}/browser_session", "zip", directory)
        await self.async_client.upload_file_from_path(browser_session_uri, zip_file_path)
async def retrieve_browser_session(self, organization_id: str, workflow_permanent_id: str) -> str | None:
    """Download a stored browser session archive and unpack it locally.

    The archive is read from
    ``s3://<AWS_S3_BUCKET_BROWSER_SESSIONS>/<ENV>/<organization_id>/<workflow_permanent_id>.zip``.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.

    Returns:
        Path of a fresh temporary directory holding the extracted session,
        or None when the download yields no data. The caller owns the
        returned directory.
    """
    browser_session_uri = f"s3://{settings.AWS_S3_BUCKET_BROWSER_SESSIONS}/{settings.ENV}/{organization_id}/{workflow_permanent_id}.zip"
    downloaded_zip_bytes = await self.async_client.download_file(browser_session_uri, log_exception=True)
    if not downloaded_zip_bytes:
        return None

    temp_dir = tempfile.mkdtemp(prefix="skyvern_browser_session_")
    try:
        # The zip lives in a scratch directory that is removed once extraction
        # is done, so the downloaded archive is not left behind on disk.
        with tempfile.TemporaryDirectory(prefix="skyvern_browser_session_zip_") as scratch_dir:
            temp_zip_file_path = f"{scratch_dir}/browser_session.zip"
            # Close the file before unzipping: extracting by path while the
            # bytes are still buffered can read a truncated archive.
            with open(temp_zip_file_path, "wb") as temp_zip_file:
                temp_zip_file.write(downloaded_zip_bytes)
            unzip_files(temp_zip_file_path, temp_dir)
    except Exception:
        # Do not leak a half-populated extraction directory on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir

View file

@ -819,6 +819,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None, proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None, webhook_callback_url: str | None = None,
totp_verification_url: str | None = None, totp_verification_url: str | None = None,
persist_browser_session: bool = False,
workflow_permanent_id: str | None = None, workflow_permanent_id: str | None = None,
version: int | None = None, version: int | None = None,
is_saved_task: bool = False, is_saved_task: bool = False,
@ -832,6 +833,7 @@ class AgentDB:
proxy_location=proxy_location, proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url, webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url, totp_verification_url=totp_verification_url,
persist_browser_session=persist_browser_session,
is_saved_task=is_saved_task, is_saved_task=is_saved_task,
) )
if workflow_permanent_id: if workflow_permanent_id:

View file

@ -180,6 +180,7 @@ class WorkflowModel(Base):
proxy_location = Column(Enum(ProxyLocation)) proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String) webhook_callback_url = Column(String)
totp_verification_url = Column(String) totp_verification_url = Column(String)
persist_browser_session = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column( modified_at = Column(

View file

@ -162,6 +162,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
workflow_permanent_id=workflow_model.workflow_permanent_id, workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url, webhook_callback_url=workflow_model.webhook_callback_url,
totp_verification_url=workflow_model.totp_verification_url, totp_verification_url=workflow_model.totp_verification_url,
persist_browser_session=workflow_model.persist_browser_session,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None), proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
version=workflow_model.version, version=workflow_model.version,
is_saved_task=workflow_model.is_saved_task, is_saved_task=workflow_model.is_saved_task,

View file

@ -51,6 +51,7 @@ class Workflow(BaseModel):
proxy_location: ProxyLocation | None = None proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None webhook_callback_url: str | None = None
totp_verification_url: str | None = None totp_verification_url: str | None = None
persist_browser_session: bool = False
created_at: datetime created_at: datetime
modified_at: datetime modified_at: datetime

View file

@ -225,5 +225,6 @@ class WorkflowCreateYAMLRequest(BaseModel):
proxy_location: ProxyLocation | None = None proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None webhook_callback_url: str | None = None
totp_verification_url: str | None = None totp_verification_url: str | None = None
persist_browser_session: bool = False
workflow_definition: WorkflowDefinitionYAML workflow_definition: WorkflowDefinitionYAML
is_saved_task: bool = False is_saved_task: bool = False

View file

@ -286,6 +286,7 @@ class WorkflowService:
proxy_location: ProxyLocation | None = None, proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None, webhook_callback_url: str | None = None,
totp_verification_url: str | None = None, totp_verification_url: str | None = None,
persist_browser_session: bool = False,
workflow_permanent_id: str | None = None, workflow_permanent_id: str | None = None,
version: int | None = None, version: int | None = None,
is_saved_task: bool = False, is_saved_task: bool = False,
@ -298,6 +299,7 @@ class WorkflowService:
proxy_location=proxy_location, proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url, webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url, totp_verification_url=totp_verification_url,
persist_browser_session=persist_browser_session,
workflow_permanent_id=workflow_permanent_id, workflow_permanent_id=workflow_permanent_id,
version=version, version=version,
is_saved_task=is_saved_task, is_saved_task=is_saved_task,
@ -657,6 +659,7 @@ class WorkflowService:
tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id) tasks = await self.get_tasks_by_workflow_run_id(workflow_run.workflow_run_id)
all_workflow_task_ids = [task.task_id for task in tasks] all_workflow_task_ids = [task.task_id for task in tasks]
browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run( browser_state = await app.BROWSER_MANAGER.cleanup_for_workflow_run(
workflow,
workflow_run.workflow_run_id, workflow_run.workflow_run_id,
all_workflow_task_ids, all_workflow_task_ids,
close_browser_on_completion, close_browser_on_completion,
@ -826,6 +829,7 @@ class WorkflowService:
proxy_location=request.proxy_location, proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url, webhook_callback_url=request.webhook_callback_url,
totp_verification_url=request.totp_verification_url, totp_verification_url=request.totp_verification_url,
persist_browser_session=request.persist_browser_session,
workflow_permanent_id=workflow_permanent_id, workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1, version=existing_version + 1,
is_saved_task=request.is_saved_task, is_saved_task=request.is_saved_task,
@ -839,6 +843,7 @@ class WorkflowService:
proxy_location=request.proxy_location, proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url, webhook_callback_url=request.webhook_callback_url,
totp_verification_url=request.totp_verification_url, totp_verification_url=request.totp_verification_url,
persist_browser_session=request.persist_browser_session,
is_saved_task=request.is_saved_task, is_saved_task=request.is_saved_task,
) )
# Create parameters from the request # Create parameters from the request

View file

@ -22,6 +22,7 @@ from skyvern.exceptions import (
UnknownBrowserType, UnknownBrowserType,
UnknownErrorWhileCreatingBrowserContext, UnknownErrorWhileCreatingBrowserContext,
) )
from skyvern.forge import app
from skyvern.forge.sdk.core.skyvern_context import current from skyvern.forge.sdk.core.skyvern_context import current
from skyvern.forge.sdk.schemas.tasks import ProxyLocation from skyvern.forge.sdk.schemas.tasks import ProxyLocation
from skyvern.forge.sdk.settings_manager import SettingsManager from skyvern.forge.sdk.settings_manager import SettingsManager
@ -89,11 +90,13 @@ class BrowserContextFactory:
video_artifacts: list[VideoArtifact] | None = None, video_artifacts: list[VideoArtifact] | None = None,
har_path: str | None = None, har_path: str | None = None,
traces_dir: str | None = None, traces_dir: str | None = None,
browser_session_dir: str | None = None,
) -> BrowserArtifacts: ) -> BrowserArtifacts:
return BrowserArtifacts( return BrowserArtifacts(
video_artifacts=video_artifacts or [], video_artifacts=video_artifacts or [],
har_path=har_path, har_path=har_path,
traces_dir=traces_dir, traces_dir=traces_dir,
browser_session_dir=browser_session_dir,
) )
@classmethod @classmethod
@ -137,6 +140,7 @@ class BrowserArtifacts(BaseModel):
video_artifacts: list[VideoArtifact] = [] video_artifacts: list[VideoArtifact] = []
har_path: str | None = None har_path: str | None = None
traces_dir: str | None = None traces_dir: str | None = None
browser_session_dir: str | None = None
async def _create_headless_chromium( async def _create_headless_chromium(
@ -386,3 +390,9 @@ class BrowserState:
async def take_screenshot(self, full_page: bool = False, file_path: str | None = None) -> bytes: async def take_screenshot(self, full_page: bool = False, file_path: str | None = None) -> bytes:
page = await self.__assert_page() page = await self.__assert_page()
return await SkyvernFrame.take_screenshot(page=page, full_page=full_page, file_path=file_path) return await SkyvernFrame.take_screenshot(page=page, full_page=full_page, file_path=file_path)
async def store_browser_session(self, organization_id: str, workflow_permanent_id: str) -> None:
    """Hand this browser's session directory to the configured storage backend.

    Does nothing when the browser artifacts carry no session directory.

    Args:
        organization_id: Organization that owns the session.
        workflow_permanent_id: Workflow the session belongs to.
    """
    session_dir = self.browser_artifacts.browser_session_dir
    if not session_dir:
        return
    await app.STORAGE.store_browser_session(organization_id, workflow_permanent_id, session_dir)

View file

@ -9,7 +9,7 @@ from playwright.async_api import async_playwright
from skyvern.constants import BROWSER_CLOSE_TIMEOUT from skyvern.constants import BROWSER_CLOSE_TIMEOUT
from skyvern.exceptions import MissingBrowserState from skyvern.exceptions import MissingBrowserState
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun
from skyvern.webeye.browser_factory import BrowserContextFactory, BrowserState, VideoArtifact from skyvern.webeye.browser_factory import BrowserContextFactory, BrowserState, VideoArtifact
LOG = structlog.get_logger() LOG = structlog.get_logger()
@ -182,6 +182,7 @@ class BrowserManager:
async def cleanup_for_workflow_run( async def cleanup_for_workflow_run(
self, self,
workflow: Workflow,
workflow_run_id: str, workflow_run_id: str,
task_ids: list[str], task_ids: list[str],
close_browser_on_completion: bool = True, close_browser_on_completion: bool = True,
@ -195,6 +196,13 @@ class BrowserManager:
await browser_state_to_close.browser_context.tracing.stop(path=trace_path) await browser_state_to_close.browser_context.tracing.stop(path=trace_path)
LOG.info("Stopped tracing", trace_path=trace_path) LOG.info("Stopped tracing", trace_path=trace_path)
if workflow.persist_browser_session:
await browser_state_to_close.store_browser_session(
organization_id=workflow.organization_id,
workflow_permanent_id=workflow.workflow_permanent_id,
)
LOG.info("Persisted browser session for workflow run", workflow_run_id=workflow_run_id)
await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion) await browser_state_to_close.close(close_browser_on_completion=close_browser_on_completion)
for task_id in task_ids: for task_id in task_ids:
self.pages.pop(task_id, None) self.pages.pop(task_id, None)