Add pyupgrade pre-commit hook + modernize python code (#2611)
Commit effd0c4911 (parent 272985f1bb)
18 changed files with 47 additions and 45 deletions
@@ -51,6 +51,14 @@ repos:
       - id: python-check-mock-methods
       - id: python-no-log-warn
       - id: python-use-type-annotations
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.20.0
+    hooks:
+      - id: pyupgrade
+        exclude: |
+          (?x)(
+            ^skyvern/client/.*
+          )
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.16.0
     hooks:
@@ -61,7 +69,7 @@ repos:
           - types-requests
           - types-cachetools
           - alembic
-          - "sqlalchemy[mypy]"
+          - 'sqlalchemy[mypy]'
           - types-PyYAML
           - types-aiofiles
         exclude: |
@@ -91,7 +99,7 @@ repos:
       # pass_filenames: false
       # always_run: true
   - repo: https://github.com/pre-commit/mirrors-prettier
-    rev: "v4.0.0-alpha.8" # Use the sha or tag you want to point at
+    rev: 'v4.0.0-alpha.8' # Use the sha or tag you want to point at
     hooks:
      - id: prettier
        types: [javascript]
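Note: the block added above wires pyupgrade into pre-commit at rev v3.20.0, skipping everything under skyvern/client/ via the exclude regex. The remaining hunks are the kind of rewrites that hook produces. A rough, hypothetical illustration (read_config_old/read_config_new are made-up names, not code from this repository):

    from typing import Optional

    def read_config_old(path: str, fallback: Optional[str] = None) -> str:
        # "r" is already the default mode for open()
        with open(path, "r") as f:
            return f.read() or (fallback or "")

    # What pyupgrade produces for the same function when targeting Python 3.10+
    # (or when annotations are evaluated lazily):
    def read_config_new(path: str, fallback: str | None = None) -> str:
        with open(path) as f:
            return f.read() or (fallback or "")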
@@ -21,13 +21,13 @@ class WorkflowRunResultRequest(BaseModel):


 def load_webvoyager_case_from_json(file_path: str, group_id: str = "") -> Iterator[WebVoyagerTestCase]:
-    with open("evaluation/datasets/webvoyager_reference_answer.json", "r") as answer_file:
+    with open("evaluation/datasets/webvoyager_reference_answer.json") as answer_file:
         webvoyager_answers: dict = json.load(answer_file)

     if not group_id:
         group_id = str(uuid4())

-    with open(file_path, "r", encoding="utf-8") as file:
+    with open(file_path, encoding="utf-8") as file:
         for line in file:
             test_case: dict[str, str] = json.loads(line)
             web_name, id = test_case["id"].split("--")
@@ -47,7 +47,7 @@ def load_webvoyager_case_from_json(file_path: str, group_id: str = "") -> Iterat


 def load_records_from_json(file_path: str) -> Iterator[WorkflowRunResultRequest]:
-    with open(file_path, "r", encoding="utf-8") as f:
+    with open(file_path, encoding="utf-8") as f:
         for line in f:
             item: dict[str, str] = json.loads(line)
             id = item["id"]
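Note: the only change in these loaders is dropping the redundant "r" argument. open() reads text by default, so open(path) is the same call as open(path, "r"), and keyword arguments such as encoding are unaffected. A minimal sketch (the path is a placeholder, not a file in the repo):

    import json

    # open(p, encoding="utf-8") is equivalent to open(p, "r", encoding="utf-8"):
    # mode defaults to "r" (read, text).
    with open("answers.json", encoding="utf-8") as f:
        answers: dict = json.load(f)
    print(len(answers))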
@@ -35,7 +35,7 @@ def main(
 ) -> None:
     client = SkyvernClient(base_url=base_url, credentials=cred)

-    with open(record_json_path, "r", encoding="utf-8") as file:
+    with open(record_json_path, encoding="utf-8") as file:
         with open(output_csv_path, newline="", mode="w", encoding="utf-8") as csv_file:
             writer = csv.DictWriter(csv_file, fieldnames=csv_headers)
             writer.writeheader()
@@ -79,7 +79,7 @@ async def run_eval(
 ) -> None:
     client = SkyvernClient(base_url=base_url, credentials=cred)

-    with open(record_json_path, "r", encoding="utf-8") as file:
+    with open(record_json_path, encoding="utf-8") as file:
         with open(output_csv_path, newline="", mode="w", encoding="utf-8") as csv_file:
             writer = csv.DictWriter(csv_file, fieldnames=csv_headers)
             writer.writeheader()
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 import importlib.metadata
 import platform
 from typing import Any, Dict
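Note: the deleted line is the old encoding declaration. Python 3 source files are UTF-8 by default (PEP 3120), so the cookie is redundant and pyupgrade strips it. Non-ASCII literals keep working without it, as in this trivial sketch:

    # No "# -*- coding: utf-8 -*-" header needed: Python 3 source is UTF-8 by default.
    greeting = "héllo, wörld"
    print(greeting.encode("utf-8"))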
@@ -144,7 +144,7 @@ def setup_claude_desktop_config(host_system: str, path_to_env: str) -> bool:
     claude_config: dict = {"mcpServers": {}}
     if os.path.exists(path_claude_config):
         try:
-            with open(path_claude_config, "r") as f:
+            with open(path_claude_config) as f:
                 claude_config = json.load(f)
             claude_config["mcpServers"].pop("Skyvern", None)
     claude_config["mcpServers"]["Skyvern"] = {
@@ -196,7 +196,7 @@ def setup_cursor_config(host_system: str, path_to_env: str) -> bool:
     cursor_config: dict = {"mcpServers": {}}
     if os.path.exists(path_cursor_config):
         try:
-            with open(path_cursor_config, "r") as f:
+            with open(path_cursor_config) as f:
                 cursor_config = json.load(f)
             cursor_config["mcpServers"].pop("Skyvern", None)
     cursor_config["mcpServers"]["Skyvern"] = {
@@ -245,7 +245,7 @@ def setup_windsurf_config(host_system: str, path_to_env: str) -> bool:
     windsurf_config: dict = {"mcpServers": {}}
     if os.path.exists(path_windsurf_config):
         try:
-            with open(path_windsurf_config, "r") as f:
+            with open(path_windsurf_config) as f:
                 windsurf_config = json.load(f)
             windsurf_config["mcpServers"].pop("Skyvern", None)
     windsurf_config["mcpServers"]["Skyvern"] = {
@@ -4,7 +4,6 @@ from __future__ import annotations

 import json
 import os
-from typing import Optional

 import typer
 from dotenv import load_dotenv
@@ -21,7 +20,7 @@ tasks_app = typer.Typer(help="Manage Skyvern tasks and operations.")
 @tasks_app.callback()
 def tasks_callback(
     ctx: typer.Context,
-    api_key: Optional[str] = typer.Option(
+    api_key: str | None = typer.Option(
         None,
         "--api-key",
         help="Skyvern API key",
@@ -32,7 +31,7 @@ def tasks_callback(
     ctx.obj = {"api_key": api_key}


-def _get_client(api_key: Optional[str] = None) -> Skyvern:
+def _get_client(api_key: str | None = None) -> Skyvern:
     """Instantiate a Skyvern SDK client using environment variables."""
     load_dotenv()
     load_dotenv(".env")
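Note: this module already has "from __future__ import annotations" (visible in the hunk header), so annotations are evaluated lazily and the PEP 604 spelling str | None is a drop-in replacement for Optional[str], which is why the typing import can be removed. A hypothetical sketch of the equivalence (greet_old/greet_new are made-up names):

    from __future__ import annotations

    from typing import Optional

    def greet_old(name: Optional[str] = None) -> str:  # pre-PEP 604 spelling
        return f"hello {name or 'world'}"

    def greet_new(name: str | None = None) -> str:  # equivalent modern spelling
        return f"hello {name or 'world'}"

    # On Python 3.10+ the union syntax also works outside annotations.
    print(greet_old(), greet_new("skyvern"))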
@@ -4,7 +4,6 @@ from __future__ import annotations

 import json
 import os
-from typing import Optional

 import typer
 from dotenv import load_dotenv
@@ -22,7 +21,7 @@ workflow_app = typer.Typer(help="Manage Skyvern workflows.")
 @workflow_app.callback()
 def workflow_callback(
     ctx: typer.Context,
-    api_key: Optional[str] = typer.Option(
+    api_key: str | None = typer.Option(
         None,
         "--api-key",
         help="Skyvern API key",
@@ -33,7 +32,7 @@ def workflow_callback(
     ctx.obj = {"api_key": api_key}


-def _get_client(api_key: Optional[str] = None) -> Skyvern:
+def _get_client(api_key: str | None = None) -> Skyvern:
     """Instantiate a Skyvern SDK client using environment variables."""
     load_dotenv()
     load_dotenv(".env")
@@ -46,8 +45,8 @@ def start_workflow(
     ctx: typer.Context,
     workflow_id: str = typer.Argument(..., help="Workflow permanent ID"),
     parameters: str = typer.Option("{}", "--parameters", "-p", help="JSON parameters for the workflow"),
-    title: Optional[str] = typer.Option(None, "--title", help="Title for the workflow run"),
-    max_steps: Optional[int] = typer.Option(None, "--max-steps", help="Override the workflow max steps"),
+    title: str | None = typer.Option(None, "--title", help="Title for the workflow run"),
+    max_steps: int | None = typer.Option(None, "--max-steps", help="Override the workflow max steps"),
 ) -> None:
     """Dispatch a workflow run."""
     try:
@@ -255,7 +255,7 @@ class AsyncAWSClient:
         return await client.deregister_task_definition(taskDefinition=task_definition)


-class S3Uri(object):
+class S3Uri:
     # From: https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
     """
     >>> s = S3Uri("s3://bucket/hello/world")
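Note: in Python 3 every class inherits from object implicitly, so "class S3Uri(object):" and "class S3Uri:" define identical classes; pyupgrade simply drops the redundant base. A minimal check (Demo is a made-up class name):

    class Demo:  # same as "class Demo(object):" in Python 3
        pass

    assert Demo.__bases__ == (object,)
    assert issubclass(Demo, object)
    print("implicit object base confirmed")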
@@ -38,7 +38,7 @@ class LocalStorage(BaseStorage):
         if not file_path.exists():
             return []
         try:
-            with open(file_path, "r") as f:
+            with open(file_path) as f:
                 return [line.strip() for line in f.readlines() if line.strip()]
         except Exception:
             return []
@@ -211,7 +211,7 @@ class Block(BaseModel, abc.ABC):
         return template.render(template_data)

     @classmethod
-    def get_subclasses(cls) -> tuple[type["Block"], ...]:
+    def get_subclasses(cls) -> tuple[type[Block], ...]:
         return tuple(cls.__subclasses__())

     @staticmethod
@@ -2123,7 +2123,7 @@ class FileParserBlock(Block):
     def validate_file_type(self, file_url_used: str, file_path: str) -> None:
         if self.file_type == FileType.CSV:
             try:
-                with open(file_path, "r") as file:
+                with open(file_path) as file:
                     csv.Sniffer().sniff(file.read(1024))
             except csv.Error as e:
                 raise InvalidFileType(file_url=file_url_used, file_type=self.file_type, error=str(e))
@@ -2172,7 +2172,7 @@ class FileParserBlock(Block):
         self.validate_file_type(self.file_url, file_path)
         # Parse the file into a list of dictionaries where each dictionary represents a row in the file
         parsed_data = []
-        with open(file_path, "r") as file:
+        with open(file_path) as file:
             if self.file_type == FileType.CSV:
                 reader = csv.DictReader(file)
                 for row in reader:
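Note: dropping the quotes around Block is safe only when the annotation is never evaluated at class-definition time, presumably because this module enables postponed annotation evaluation ("from __future__ import annotations"); pyupgrade only dequotes annotations in that situation. A self-contained sketch of the same pattern (Node is a made-up class):

    from __future__ import annotations  # annotations stay unevaluated

    class Node:
        @classmethod
        def subclasses(cls) -> tuple[type[Node], ...]:  # no quotes needed around Node
            return tuple(cls.__subclasses__())

    print(Node.subclasses())  # () in this sketch: no subclasses defined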
@@ -28,7 +28,7 @@ def detect_os() -> str:
     system = platform.system()
     if system == "Linux":
         try:
-            with open("/proc/version", "r") as f:
+            with open("/proc/version") as f:
                 version_info = f.read().lower()
                 if "microsoft" in version_info:
                     return "wsl"
@@ -25,5 +25,5 @@ def get_json_from_file(file_path: str) -> dict[str, str]:
     if not os.path.exists(file_path):
         return {}

-    with open(file_path, "r") as json_file:
+    with open(file_path) as json_file:
         return json.load(json_file)
|
@ -162,7 +162,7 @@ class BrowserContextFactory:
|
||||||
preference_template = f"{SKYVERN_DIR}/webeye/chromium_preferences.json"
|
preference_template = f"{SKYVERN_DIR}/webeye/chromium_preferences.json"
|
||||||
|
|
||||||
preference_file_content = ""
|
preference_file_content = ""
|
||||||
with open(preference_template, "r") as f:
|
with open(preference_template) as f:
|
||||||
preference_file_content = f.read()
|
preference_file_content = f.read()
|
||||||
preference_file_content = preference_file_content.replace("MASK_SAVEFILE_DEFAULT_DIRECTORY", download_dir)
|
preference_file_content = preference_file_content.replace("MASK_SAVEFILE_DEFAULT_DIRECTORY", download_dir)
|
||||||
preference_file_content = preference_file_content.replace("MASK_DOWNLOAD_DEFAULT_DIRECTORY", download_dir)
|
preference_file_content = preference_file_content.replace("MASK_DOWNLOAD_DEFAULT_DIRECTORY", download_dir)
|
||||||
|
@ -281,7 +281,7 @@ class BrowserContextFactory:
|
||||||
class VideoArtifact(BaseModel):
|
class VideoArtifact(BaseModel):
|
||||||
video_path: str | None = None
|
video_path: str | None = None
|
||||||
video_artifact_id: str | None = None
|
video_artifact_id: str | None = None
|
||||||
video_data: bytes = bytes()
|
video_data: bytes = b""
|
||||||
|
|
||||||
|
|
||||||
class BrowserArtifacts(BaseModel):
|
class BrowserArtifacts(BaseModel):
|
||||||
|
@ -385,7 +385,7 @@ def _is_port_in_use(port: int) -> bool:
|
||||||
try:
|
try:
|
||||||
s.bind(("localhost", port))
|
s.bind(("localhost", port))
|
||||||
return False
|
return False
|
||||||
except socket.error:
|
except OSError:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
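Note: two separate modernizations appear in these hunks. bytes() is replaced by the equivalent empty literal b"", and socket.error is replaced by OSError, of which it has been a plain alias since Python 3.3 (PEP 3151). Both facts can be checked directly:

    import socket

    # An empty bytes object, however it is spelled:
    assert bytes() == b""

    # socket.error is just another name for OSError on Python 3.3+:
    assert socket.error is OSError
    print("equivalences hold")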
@@ -1,7 +1,6 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple

 import structlog
 from playwright._impl._errors import TargetClosedError
@@ -22,7 +21,7 @@ class BrowserSession:

 class PersistentSessionsManager:
     instance: PersistentSessionsManager | None = None
-    _browser_sessions: Dict[str, BrowserSession] = dict()
+    _browser_sessions: dict[str, BrowserSession] = dict()
     database: AgentDB

     def __new__(cls, database: AgentDB) -> PersistentSessionsManager:
@@ -82,7 +81,7 @@ class PersistentSessionsManager:
             organization_id=organization_id,
         )

-    async def get_network_info(self, session_id: str) -> Tuple[Optional[int], Optional[str]]:
+    async def get_network_info(self, session_id: str) -> tuple[int | None, str | None]:
         """Returns cdp port and ip address of the browser session"""
         browser_session = self._browser_sessions.get(session_id)
         if browser_session:
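Note: with PEP 585 (Python 3.9) the builtin containers are subscriptable, so dict[str, BrowserSession] and tuple[int | None, str | None] replace their typing.Dict/Tuple/Optional counterparts, and the typing import disappears entirely. A hypothetical sketch of the pattern (sessions/lookup are made-up names):

    from __future__ import annotations

    # Old spelling would need: from typing import Dict, Optional, Tuple
    sessions: dict[str, int] = {}  # PEP 585 builtin generic

    def lookup(key: str) -> tuple[int | None, str | None]:  # PEP 604 unions
        value = sessions.get(key)
        return (value, key if value is not None else None)

    print(lookup("missing"))  # (None, None)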
@@ -75,7 +75,7 @@ def load_js_script() -> str:
     try:
         # TODO: Implement TS of domUtils.js and use the complied JS file instead of the raw JS file.
        # This will allow our code to be type safe.
-        with open(path, "r") as f:
+        with open(path) as f:
             return f.read()
     except FileNotFoundError as e:
         LOG.exception("Failed to load the JS script", path=path)
@@ -34,9 +34,7 @@ TEXT_INPUT_DELAY = 10 # 10ms between each character input
 TEXT_PRESS_MAX_LENGTH = 20


-async def resolve_locator(
-    scrape_page: ScrapedPage, page: Page, frame: str, css: str
-) -> typing.Tuple[Locator, Page | Frame]:
+async def resolve_locator(scrape_page: ScrapedPage, page: Page, frame: str, css: str) -> tuple[Locator, Page | Frame]:
     iframe_path: list[str] = []

     while frame != "main.frame":
@@ -335,10 +333,10 @@ class SkyvernElement:
     def get_frame_id(self) -> str:
         return self._frame_id

-    def get_attributes(self) -> typing.Dict:
+    def get_attributes(self) -> dict:
         return self._attributes

-    def get_options(self) -> typing.List[SkyvernOptionType]:
+    def get_options(self) -> list[SkyvernOptionType]:
         options = self.__static_element.get("options", None)
         if options is None:
             return []
@@ -2,7 +2,7 @@ from __future__ import annotations

 import asyncio
 import time
-from typing import Any, Dict, List
+from typing import Any

 import structlog
 from playwright._impl._errors import TimeoutError
@@ -21,7 +21,7 @@ def load_js_script() -> str:
     try:
         # TODO: Implement TS of domUtils.js and use the complied JS file instead of the raw JS file.
         # This will allow our code to be type safe.
-        with open(path, "r") as f:
+        with open(path) as f:
             return f.read()
     except FileNotFoundError as e:
         LOG.exception("Failed to load the JS script", path=path)
@@ -43,7 +43,7 @@ async def _current_viewpoint_screenshot_helper(
     await page.wait_for_load_state(timeout=settings.BROWSER_LOADING_TIMEOUT_MS)
     LOG.debug("Page is fully loaded, agent is about to take screenshots")
     start_time = time.time()
-    screenshot: bytes = bytes()
+    screenshot: bytes = b""
     if file_path:
         screenshot = await page.screenshot(
             path=file_path,
@@ -77,14 +77,14 @@ async def _scrolling_screenshots_helper(
     url: str | None = None,
     draw_boxes: bool = False,
     max_number: int = settings.MAX_NUM_SCREENSHOTS,
-) -> List[bytes]:
+) -> list[bytes]:
     skyvern_page = await SkyvernFrame.create_instance(frame=page)
     # page is the main frame and the index must be 0
     assert isinstance(skyvern_page.frame, Page)
     frame = "main.frame"
     frame_index = 0

-    screenshots: List[bytes] = []
+    screenshots: list[bytes] = []
     if await skyvern_page.is_window_scrollable():
         scroll_y_px_old = -30.0
         scroll_y_px = await skyvern_page.scroll_to_top(draw_boxes=draw_boxes, frame=frame, frame_index=frame_index)
@@ -161,7 +161,7 @@ class SkyvernFrame:
         draw_boxes: bool = False,
         max_number: int = settings.MAX_NUM_SCREENSHOTS,
         scroll: bool = True,
-    ) -> List[bytes]:
+    ) -> list[bytes]:
         if not scroll:
             return [await _current_viewpoint_screenshot_helper(page=page)]

@@ -199,7 +199,7 @@ class SkyvernFrame:
         js_script = "(element) => scrollToElementTop(element)"
         return await self.evaluate(frame=self.frame, expression=js_script, arg=element)

-    async def parse_element_from_html(self, frame: str, element: ElementHandle, interactable: bool) -> Dict:
+    async def parse_element_from_html(self, frame: str, element: ElementHandle, interactable: bool) -> dict:
         js_script = "async ([frame, element, interactable]) => await buildElementObject(frame, element, interactable)"
         return await self.evaluate(frame=self.frame, expression=js_script, arg=[frame, element, interactable])
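Note on usage (standard pre-commit workflow, not specific to this PR): once a contributor runs "pre-commit install", the new hook rewrites staged Python files on every commit; "pre-commit run pyupgrade --all-files" applies the same rewrites across the whole tree, which is how changes like the ones above are kept from reappearing.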