[Backend] Add City and State targeting for Massive geo-targeting (#4133)

This commit is contained in:
Marc Kelechava 2025-11-28 14:24:44 -08:00 committed by GitHub
parent 793d5d350d
commit b23fea86be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 213 additions and 47 deletions

View file

@ -115,7 +115,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunParameter,
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType
from skyvern.schemas.runs import GeoTarget, ProxyLocation, ProxyLocationInput, RunEngine, RunType
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile, ScriptStatus, WorkflowScript
from skyvern.schemas.steps import AgentStepOutput
from skyvern.schemas.workflows import BlockStatus, BlockType, WorkflowStatus
@ -123,6 +123,34 @@ from skyvern.webeye.actions.actions import Action
LOG = structlog.get_logger()
def _serialize_proxy_location(proxy_location: ProxyLocationInput) -> str | None:
    """
    Turn a proxy_location input into the string form stored in the database.

    - GeoTarget -> JSON text of its ``model_dump()``
    - dict -> JSON text of the dict itself
    - None -> None
    - anything else (the ProxyLocation enum) -> ``str()`` of the value

    The input and the serialized result are emitted to the debug log.
    """
    serialized: str | None
    if isinstance(proxy_location, GeoTarget):
        serialized = json.dumps(proxy_location.model_dump())
    elif isinstance(proxy_location, dict):
        serialized = json.dumps(proxy_location)
    elif proxy_location is not None:
        # ProxyLocation enum - keep its string value
        serialized = str(proxy_location)
    else:
        serialized = None

    LOG.debug(
        "Serializing proxy_location for DB",
        input_type=type(proxy_location).__name__,
        input_value=str(proxy_location),
        serialized_value=serialized,
    )
    return serialized
DB_CONNECT_ARGS: dict[str, Any] = {}
if "postgresql+psycopg" in settings.DATABASE_STRING:
@ -161,7 +189,7 @@ class AgentDB:
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
workflow_run_id: str | None = None,
order: int | None = None,
@ -194,7 +222,7 @@ class AgentDB:
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
organization_id=organization_id,
proxy_location=proxy_location,
proxy_location=_serialize_proxy_location(proxy_location),
extracted_information_schema=extracted_information_schema,
workflow_run_id=workflow_run_id,
order=order,
@ -1390,7 +1418,7 @@ class AgentDB:
workflow_definition: dict[str, Any],
organization_id: str | None = None,
description: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
webhook_callback_url: str | None = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
@ -1415,7 +1443,7 @@ class AgentDB:
title=title,
description=description,
workflow_definition=workflow_definition,
proxy_location=proxy_location,
proxy_location=_serialize_proxy_location(proxy_location),
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
@ -2259,7 +2287,7 @@ class AgentDB:
organization_id: str,
browser_session_id: str | None = None,
browser_profile_id: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
@ -2281,7 +2309,7 @@ class AgentDB:
organization_id=organization_id,
browser_session_id=browser_session_id,
browser_profile_id=browser_profile_id,
proxy_location=proxy_location,
proxy_location=_serialize_proxy_location(proxy_location),
status="created",
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
@ -3565,7 +3593,7 @@ class AgentDB:
prompt: str | None = None,
url: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
totp_identifier: str | None = None,
totp_verification_url: str | None = None,
webhook_callback_url: str | None = None,
@ -3584,7 +3612,7 @@ class AgentDB:
workflow_permanent_id=workflow_permanent_id,
prompt=prompt,
url=url,
proxy_location=proxy_location,
proxy_location=_serialize_proxy_location(proxy_location),
totp_identifier=totp_identifier,
totp_verification_url=totp_verification_url,
webhook_callback_url=webhook_callback_url,
@ -4190,7 +4218,7 @@ class AgentDB:
runnable_type: str | None = None,
runnable_id: str | None = None,
timeout_minutes: int | None = None,
proxy_location: ProxyLocation | None = ProxyLocation.RESIDENTIAL,
proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL,
) -> PersistentBrowserSession:
"""Create a new persistent browser session."""
try:
@ -4200,7 +4228,7 @@ class AgentDB:
runnable_type=runnable_type,
runnable_id=runnable_id,
timeout_minutes=timeout_minutes,
proxy_location=proxy_location,
proxy_location=_serialize_proxy_location(proxy_location),
)
session.add(browser_session)
await session.commit()

View file

@ -55,7 +55,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunStatus,
WorkflowStatus,
)
from skyvern.schemas.runs import ProxyLocation, ScriptRunResponse
from skyvern.schemas.runs import GeoTarget, ProxyLocation, ProxyLocationInput, ScriptRunResponse
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptFile
from skyvern.schemas.workflows import BlockStatus, BlockType
from skyvern.webeye.actions.actions import (
@ -85,6 +85,50 @@ from skyvern.webeye.actions.actions import (
LOG = structlog.get_logger()
def _deserialize_proxy_location(value: str | None) -> ProxyLocationInput:
    """
    Deserialize proxy_location from database storage.

    Handles:
    - None or empty string -> None
    - JSON object string (e.g., '{"country": "US", ...}') -> GeoTarget object
    - ProxyLocation enum string (e.g., "RESIDENTIAL") -> ProxyLocation enum
    - Anything else -> None, after logging a warning

    Args:
        value: Raw proxy_location value read from the database model.

    Returns:
        A GeoTarget, a ProxyLocation member, or None when the value is
        absent or cannot be interpreted.
    """
    # Treat an empty string like NULL instead of reporting it as a failed enum lookup.
    if not value:
        return None

    geo_target_error: str | None = None

    # Try to parse as JSON first (for GeoTarget)
    if value.startswith("{"):
        try:
            geo_target = GeoTarget.model_validate(json.loads(value))
        except ValueError as exc:
            # json.JSONDecodeError and pydantic's ValidationError are both
            # ValueError subclasses. Keep the reason so the final warning
            # explains why the GeoTarget path was rejected.
            geo_target_error = str(exc)
        else:
            # Debug level: this runs for every converted row, and the
            # serializing side logs at debug as well.
            LOG.debug(
                "Deserialized proxy_location as GeoTarget",
                db_value=value,
                result=str(geo_target),
            )
            return geo_target

    # Try as ProxyLocation enum
    try:
        location = ProxyLocation(value)
    except ValueError:
        # Unrecognized value: report it and fall back to "no proxy location".
        LOG.warning(
            "Failed to deserialize proxy_location",
            db_value=value,
            geo_target_error=geo_target_error,
        )
        return None

    LOG.debug(
        "Deserialized proxy_location as ProxyLocation enum",
        db_value=value,
        result=str(location),
    )
    return location
# Mapping of action types to their corresponding action classes
ACTION_TYPE_TO_CLASS = {
ActionType.CLICK: ClickAction,
@ -142,7 +186,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False, workflow_p
extracted_information=task_obj.extracted_information,
failure_reason=task_obj.failure_reason,
organization_id=task_obj.organization_id,
proxy_location=(ProxyLocation(task_obj.proxy_location) if task_obj.proxy_location else None),
proxy_location=_deserialize_proxy_location(task_obj.proxy_location),
extracted_information_schema=task_obj.extracted_information_schema,
extra_http_headers=task_obj.extra_http_headers,
workflow_run_id=task_obj.workflow_run_id,
@ -272,7 +316,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
totp_identifier=workflow_model.totp_identifier,
persist_browser_session=workflow_model.persist_browser_session,
model=workflow_model.model,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
proxy_location=_deserialize_proxy_location(workflow_model.proxy_location),
max_screenshot_scrolls=workflow_model.max_screenshot_scrolling_times,
version=workflow_model.version,
is_saved_task=workflow_model.is_saved_task,
@ -312,9 +356,7 @@ def convert_to_workflow_run(
browser_profile_id=workflow_run_model.browser_profile_id,
status=WorkflowRunStatus[workflow_run_model.status],
failure_reason=workflow_run_model.failure_reason,
proxy_location=(
ProxyLocation(workflow_run_model.proxy_location) if workflow_run_model.proxy_location else None
),
proxy_location=_deserialize_proxy_location(workflow_run_model.proxy_location),
webhook_callback_url=workflow_run_model.webhook_callback_url,
webhook_failure_reason=workflow_run_model.webhook_failure_reason,
totp_verification_url=workflow_run_model.totp_verification_url,

View file

@ -5,7 +5,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_validator
from skyvern.config import settings
from skyvern.schemas.runs import ProxyLocation
from skyvern.schemas.runs import ProxyLocationInput
from skyvern.utils.url_validators import validate_url
DEFAULT_WORKFLOW_TITLE = "New Workflow"
@ -40,7 +40,7 @@ class TaskV2(BaseModel):
output: dict[str, Any] | list | str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
webhook_callback_url: str | None = None
webhook_failure_reason: str | None = None
extracted_information_schema: dict | list | str | None = None
@ -149,7 +149,7 @@ class TaskV2Request(BaseModel):
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
publish_workflow: bool = False
extracted_information_schema: dict | list | str | None = None
error_code_mapping: dict[str, str] | None = None

View file

@ -18,7 +18,7 @@ from skyvern.exceptions import (
from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.schemas.files import FileInfo
from skyvern.schemas.docs.doc_strings import PROXY_LOCATION_DOC_STRING
from skyvern.schemas.runs import ProxyLocation
from skyvern.schemas.runs import ProxyLocationInput
from skyvern.utils.url_validators import validate_url
@ -69,7 +69,7 @@ class TaskBase(BaseModel):
}
],
)
proxy_location: ProxyLocation | None = Field(
proxy_location: ProxyLocationInput = Field(
default=None,
description=PROXY_LOCATION_DOC_STRING,
)

View file

@ -10,7 +10,7 @@ from skyvern.forge.sdk.schemas.task_v2 import TaskV2
from skyvern.forge.sdk.workflow.exceptions import WorkflowDefinitionHasDuplicateBlockLabels
from skyvern.forge.sdk.workflow.models.block import BlockTypeVar
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, OutputParameter
from skyvern.schemas.runs import ProxyLocation, ScriptRunResponse
from skyvern.schemas.runs import ProxyLocationInput, ScriptRunResponse
from skyvern.schemas.workflows import WorkflowStatus
from skyvern.utils.url_validators import validate_url
@ -18,7 +18,7 @@ from skyvern.utils.url_validators import validate_url
@deprecated("Use WorkflowRunRequest instead")
class WorkflowRequestBody(BaseModel):
data: dict[str, Any] | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
@ -77,7 +77,7 @@ class Workflow(BaseModel):
is_saved_task: bool
description: str | None = None
workflow_definition: WorkflowDefinition
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
@ -142,7 +142,7 @@ class WorkflowRun(BaseModel):
debug_session_id: str | None = None
status: WorkflowRunStatus
extra_http_headers: dict[str, str] | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
webhook_callback_url: str | None = None
webhook_failure_reason: str | None = None
totp_verification_url: str | None = None
@ -186,7 +186,7 @@ class WorkflowRunResponseBase(BaseModel):
workflow_run_id: str
status: WorkflowRunStatus
failure_reason: str | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocationInput = None
webhook_callback_url: str | None = None
webhook_failure_reason: str | None = None
totp_verification_url: str | None = None

View file

@ -107,7 +107,13 @@ from skyvern.forge.sdk.workflow.models.workflow import (
WorkflowRunResponseBase,
WorkflowRunStatus,
)
from skyvern.schemas.runs import ProxyLocation, RunStatus, RunType, WorkflowRunRequest, WorkflowRunResponse
from skyvern.schemas.runs import (
ProxyLocationInput,
RunStatus,
RunType,
WorkflowRunRequest,
WorkflowRunResponse,
)
from skyvern.schemas.scripts import Script, ScriptBlock, ScriptStatus, WorkflowScript
from skyvern.schemas.workflows import (
BLOCK_YAML_TYPES,
@ -528,7 +534,7 @@ class WorkflowService:
workflow: Workflow,
*,
browser_session_id: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
) -> PersistentBrowserSession | None:
if browser_session_id: # the user has supplied an id, so no need to create one
return None
@ -1125,7 +1131,7 @@ class WorkflowService:
title: str,
workflow_definition: WorkflowDefinition,
description: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
max_screenshot_scrolling_times: int | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
@ -1176,7 +1182,7 @@ class WorkflowService:
totp_identifier: str | None = None,
totp_verification_url: str | None = None,
webhook_callback_url: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
max_iterations: int | None = None,
@ -3246,7 +3252,7 @@ class WorkflowService:
self,
organization: Organization,
title: str,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
max_screenshot_scrolling_times: int | None = None,
extra_http_headers: dict[str, str] | None = None,
run_with: str | None = None,

View file

@ -146,6 +146,94 @@ class ProxyLocation(StrEnum):
return mapping.get(proxy_location, "US")
# Supported countries for GeoTarget - must match Massive's coverage.
# ISO 3166-1 alpha-2 codes, listed alphabetically.
SUPPORTED_GEO_COUNTRIES = frozenset(
    (
        "AR", "AU", "BR", "CA", "DE", "ES",
        "FR", "GB", "IE", "IN", "IT", "JP",
        "MX", "NL", "NZ", "TR", "US", "ZA",
    )
)
class GeoTarget(BaseModel):
    """
    Granular geographic targeting for proxy selection.

    Supports country, subdivision (state/region), and city level targeting.
    Uses ISO 3166-1 alpha-2 for countries, ISO 3166-2 for subdivisions,
    and GeoNames English names for cities.

    Examples:
    - {"country": "US"} - United States (same as RESIDENTIAL)
    - {"country": "US", "subdivision": "CA"} - California, US
    - {"country": "US", "subdivision": "NY", "city": "New York"} - New York City
    - {"country": "GB", "city": "London"} - London, UK
    """

    # Required two-letter country code; upper-cased and checked by validate_country.
    country: str = Field(
        description="ISO 3166-1 alpha-2 country code (e.g., 'US', 'GB', 'DE')",
        examples=["US", "GB", "DE", "FR"],
        min_length=2,
        max_length=2,
    )
    # Optional state/region code; normalized by validate_subdivision.
    subdivision: str | None = Field(
        default=None,
        description="ISO 3166-2 subdivision code without country prefix (e.g., 'CA' for California, 'NY' for New York)",
        examples=["CA", "NY", "TX", "ENG"],
        max_length=10,
    )
    # Optional city name; no validator is attached to this field.
    city: str | None = Field(
        default=None,
        description="City name in English from GeoNames (e.g., 'New York', 'Los Angeles', 'London')",
        examples=["New York", "Los Angeles", "London", "Berlin"],
        max_length=100,
    )

    @field_validator("country")
    @classmethod
    def validate_country(cls, v: str) -> str:
        """Upper-case the country code and reject codes outside SUPPORTED_GEO_COUNTRIES."""
        code = v.upper()
        if code in SUPPORTED_GEO_COUNTRIES:
            return code
        raise ValueError(
            f"Country '{code}' is not supported for geo targeting. "
            f"Supported countries: {sorted(SUPPORTED_GEO_COUNTRIES)}"
        )

    @field_validator("subdivision")
    @classmethod
    def validate_subdivision(cls, v: str | None) -> str | None:
        """Upper-case the subdivision code and drop anything up to the first hyphen."""
        if v is None:
            return None
        code = v.upper()
        # A hyphen means a prefix was included (e.g., "US-CA" -> "CA");
        # keep only what follows the first one.
        _prefix, hyphen, remainder = code.partition("-")
        return remainder if hyphen else code
# Type alias for proxy location that accepts either legacy enum or new GeoTarget.
# Members: a ProxyLocation enum value, a GeoTarget model, a plain dict, or None.
# NOTE(review): the dict member is presumably a GeoTarget-shaped mapping that has
# not been validated yet - confirm against callers.
ProxyLocationInput = ProxyLocation | GeoTarget | dict | None
def get_tzinfo_from_proxy(proxy_location: ProxyLocation) -> ZoneInfo | None:
if proxy_location == ProxyLocation.NONE:
return None
@ -277,9 +365,10 @@ class TaskRunRequest(BaseModel):
title: str | None = Field(
default=None, description="The title for the task", examples=["The title of my first skyvern task"]
)
proxy_location: ProxyLocation | None = Field(
proxy_location: ProxyLocation | GeoTarget | dict | None = Field(
default=ProxyLocation.RESIDENTIAL,
description=PROXY_LOCATION_DOC_STRING,
description=PROXY_LOCATION_DOC_STRING + " Can also be a GeoTarget object for granular city/state targeting: "
'{"country": "US", "subdivision": "CA", "city": "San Francisco"}',
)
data_extraction_schema: dict | list | str | None = Field(
default=None,
@ -365,9 +454,10 @@ class WorkflowRunRequest(BaseModel):
)
parameters: dict[str, Any] | None = Field(default=None, description="Parameters to pass to the workflow")
title: str | None = Field(default=None, description="The title for this workflow run")
proxy_location: ProxyLocation | None = Field(
proxy_location: ProxyLocation | GeoTarget | dict | None = Field(
default=ProxyLocation.RESIDENTIAL,
description=PROXY_LOCATION_DOC_STRING,
description=PROXY_LOCATION_DOC_STRING + " Can also be a GeoTarget object for granular city/state targeting: "
'{"country": "US", "subdivision": "CA", "city": "San Francisco"}',
)
webhook_url: str | None = Field(
default=None,

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, field_validator
from skyvern.config import settings
from skyvern.forge.sdk.workflow.models.parameter import OutputParameter, ParameterType, WorkflowParameterType
from skyvern.schemas.runs import ProxyLocation, RunEngine
from skyvern.schemas.runs import GeoTarget, ProxyLocation, RunEngine
class WorkflowStatus(StrEnum):
@ -551,7 +551,7 @@ class WorkflowDefinitionYAML(BaseModel):
class WorkflowCreateYAMLRequest(BaseModel):
title: str
description: str | None = None
proxy_location: ProxyLocation | None = None
proxy_location: ProxyLocation | GeoTarget | dict | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None

View file

@ -37,7 +37,7 @@ from skyvern.forge.sdk.workflow.models.block import (
)
from skyvern.forge.sdk.workflow.models.parameter import PARAMETER_TYPE, ContextParameter
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRequestBody, WorkflowRun, WorkflowRunStatus
from skyvern.schemas.runs import ProxyLocation, RunEngine, RunType, TaskRunRequest, TaskRunResponse
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput, RunEngine, RunType, TaskRunRequest, TaskRunResponse
from skyvern.schemas.workflows import (
BLOCK_YAML_TYPES,
PARAMETER_YAML_TYPES,
@ -150,7 +150,7 @@ async def initialize_task_v2(
organization: Organization,
user_prompt: str,
user_url: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
totp_identifier: str | None = None,
totp_verification_url: str | None = None,
webhook_callback_url: str | None = None,

View file

@ -41,7 +41,7 @@ from skyvern.exceptions import (
)
from skyvern.forge.sdk.api.files import get_download_dir, make_temp_directory
from skyvern.forge.sdk.core.skyvern_context import current, ensure_context
from skyvern.schemas.runs import ProxyLocation, get_tzinfo_from_proxy
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput, get_tzinfo_from_proxy
from skyvern.webeye.utils.page import ScreenshotMode, SkyvernFrame
LOG = structlog.get_logger()
@ -680,7 +680,7 @@ class BrowserState:
async def check_and_fix_state(
self,
url: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
script_id: str | None = None,
@ -886,7 +886,7 @@ class BrowserState:
async def get_or_create_page(
self,
url: str | None = None,
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
task_id: str | None = None,
workflow_run_id: str | None = None,
script_id: str | None = None,

View file

@ -9,7 +9,7 @@ from skyvern.exceptions import MissingBrowserState
from skyvern.forge import app
from skyvern.forge.sdk.schemas.tasks import Task
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRun
from skyvern.schemas.runs import ProxyLocation
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
from skyvern.webeye.browser_factory import BrowserContextFactory, BrowserState, VideoArtifact
LOG = structlog.get_logger()
@ -26,7 +26,7 @@ class BrowserManager:
@staticmethod
async def _create_browser_state(
proxy_location: ProxyLocation | None = None,
proxy_location: ProxyLocationInput = None,
url: str | None = None,
task_id: str | None = None,
workflow_run_id: str | None = None,

View file

@ -16,7 +16,7 @@ from skyvern.forge.sdk.schemas.persistent_browser_sessions import (
PersistentBrowserSessionStatus,
is_final_status,
)
from skyvern.schemas.runs import ProxyLocation
from skyvern.schemas.runs import ProxyLocation, ProxyLocationInput
from skyvern.webeye.browser_factory import BrowserState
LOG = structlog.get_logger()
@ -254,7 +254,7 @@ class PersistentSessionsManager:
runnable_id: str | None = None,
runnable_type: str | None = None,
timeout_minutes: int | None = None,
proxy_location: ProxyLocation | None = ProxyLocation.RESIDENTIAL,
proxy_location: ProxyLocationInput = ProxyLocation.RESIDENTIAL,
) -> PersistentBrowserSession:
"""Create a new browser session for an organization and return its ID with the browser state."""