diff --git a/skyvern/core/script_generations/generate_script.py b/skyvern/core/script_generations/generate_script.py index 8d5fb482..ab14134c 100644 --- a/skyvern/core/script_generations/generate_script.py +++ b/skyvern/core/script_generations/generate_script.py @@ -294,13 +294,25 @@ def _action_to_stmt(act: dict[str, Any], assign_to_output: bool = False) -> cst. args.append( cst.Arg( keyword=cst.Name("prompt"), - value=_value(act["data_extraction_goal"]), + value=_render_value(act["data_extraction_goal"]), whitespace_after_arg=cst.ParenthesizedWhitespace( indent=True, last_line=cst.SimpleWhitespace(INDENT), ), ) ) + if act.get("data_extraction_schema"): + args.append( + cst.Arg( + keyword=cst.Name("schema"), + value=_value(act["data_extraction_schema"]), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + comma=cst.Comma(), + ) + ) args.extend( [ @@ -565,6 +577,14 @@ def _build_extract_statement(block_title: str, block: dict[str, Any]) -> cst.Sim last_line=cst.SimpleWhitespace(INDENT), ), ), + cst.Arg( + keyword=cst.Name("schema"), + value=_value(block.get("data_schema", "")), + whitespace_after_arg=cst.ParenthesizedWhitespace( + indent=True, + last_line=cst.SimpleWhitespace(INDENT), + ), + ), cst.Arg( keyword=cst.Name("cache_key"), value=_value(block_title), diff --git a/skyvern/core/script_generations/skyvern_page.py b/skyvern/core/script_generations/skyvern_page.py index d30ef442..15a57b1f 100644 --- a/skyvern/core/script_generations/skyvern_page.py +++ b/skyvern/core/script_generations/skyvern_page.py @@ -20,7 +20,7 @@ from skyvern.forge.sdk.core import skyvern_context from skyvern.utils.prompt_engine import load_prompt_with_elements from skyvern.webeye.actions import handler_utils from skyvern.webeye.actions.action_types import ActionType -from skyvern.webeye.actions.actions import Action, ActionStatus, SelectOption +from skyvern.webeye.actions.actions import Action, ActionStatus, ExtractAction, SelectOption from skyvern.webeye.browser_factory import BrowserState from skyvern.webeye.scraper.scraper import ScrapedPage, scrape_website @@ -224,6 +224,25 @@ class SkyvernPage: response=response, created_by="script", ) + if action_type == ActionType.EXTRACT: + action = ExtractAction( + element_id="", + action_type=action_type, + status=status, + organization_id=context.organization_id, + workflow_run_id=context.workflow_run_id, + task_id=context.task_id, + step_id=context.step_id, + step_order=0, + action_order=0, + intention=intention, + reasoning=f"Auto-generated action for {action_type.value}", + data_extraction_goal=kwargs.get("prompt"), + data_extraction_schema=kwargs.get("schema"), + option=select_option, + response=response, + created_by="script", + ) created_action = await app.DATABASE.create_action(action) return created_action diff --git a/skyvern/core/script_generations/transform_workflow_run.py b/skyvern/core/script_generations/transform_workflow_run.py index ef0017a3..b10a6795 100644 --- a/skyvern/core/script_generations/transform_workflow_run.py +++ b/skyvern/core/script_generations/transform_workflow_run.py @@ -7,6 +7,7 @@ from skyvern.core.script_generations.constants import SCRIPT_TASK_BLOCKS from skyvern.forge import app from skyvern.schemas.workflows import BlockType from skyvern.services import workflow_service +from skyvern.webeye.actions.action_types import ActionType LOG = structlog.get_logger(__name__) @@ -100,6 +101,19 @@ async def transform_workflow_run_to_code_gen_input(workflow_run_id: str, organiz for action in actions: action_dump = action.model_dump() action_dump["xpath"] = action.get_xpath() + if ( + "data_extraction_goal" in final_dump + and final_dump["data_extraction_goal"] + and action.action_type == ActionType.EXTRACT + ): + # use the right data extraction goal for the extract action + action_dump["data_extraction_goal"] = final_dump["data_extraction_goal"] + if ( + "extracted_information_schema" in final_dump + and final_dump["extracted_information_schema"] + and action.action_type == ActionType.EXTRACT + ): + action_dump["data_extraction_schema"] = final_dump["extracted_information_schema"] action_dumps.append(action_dump) actions_by_task[run_block.task_id] = action_dumps else: diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index d65e5114..f756785c 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -3010,6 +3010,7 @@ class ForgeAgent: return ExtractAction( reasoning=data_extraction_summary_resp.get("summary", "Extracting information from the page"), data_extraction_goal=task.data_extraction_goal, + data_extraction_schema=task.extracted_information_schema, organization_id=task.organization_id, task_id=task.task_id, workflow_run_id=task.workflow_run_id, diff --git a/skyvern/services/script_service.py b/skyvern/services/script_service.py index 74c83028..1b3a4610 100644 --- a/skyvern/services/script_service.py +++ b/skyvern/services/script_service.py @@ -251,6 +251,7 @@ async def execute_script( async def _create_workflow_block_run_and_task( block_type: BlockType, prompt: str | None = None, + schema: dict[str, Any] | list | str | None = None, url: str | None = None, ) -> tuple[str | None, str | None, str | None]: """ @@ -287,7 +288,8 @@ async def _create_workflow_block_run_and_task( title=f"Script {block_type.value} task", navigation_goal=prompt, data_extraction_goal=prompt if block_type == BlockType.EXTRACTION else None, - navigation_payload={}, + extracted_information_schema=schema, + navigation_payload=None, status="running", organization_id=organization_id, workflow_run_id=workflow_run_id, @@ -899,6 +901,10 @@ async def _generate_block_code_from_task( continue action_dump = task_action.model_dump() action_dump["xpath"] = task_action.get_xpath() + is_data_extraction_goal = "data_extraction_goal" in block_data and "data_extraction_goal" in action_dump + if is_data_extraction_goal: + # use the raw data extraction goal which is potentially a template + action_dump["data_extraction_goal"] = block_data["data_extraction_goal"] actions_to_cache.append(action_dump) if not actions_to_cache: @@ -1157,6 +1163,7 @@ async def login( async def extract( prompt: str, + schema: dict[str, Any] | list | str | None = None, url: str | None = None, max_steps: int | None = None, cache_key: str | None = None, @@ -1165,6 +1172,7 @@ async def extract( workflow_run_block_id, task_id, step_id = await _create_workflow_block_run_and_task( block_type=BlockType.EXTRACTION, prompt=prompt, + schema=schema, url=url, ) # set the prompt in the RunContext