anthropic CUA (#2231)

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Authored by Shuchang Zheng on 2025-04-28 09:49:44 +08:00; committed by GitHub
parent 5582998490
commit 0a0228b341
18 changed files with 378 additions and 45 deletions

poetry.lock (generated; 27 changes)

@@ -270,6 +270,31 @@ files = [
     {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
 ]
 
+[[package]]
+name = "anthropic"
+version = "0.50.0"
+description = "The official Python library for the anthropic API"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+    {file = "anthropic-0.50.0-py3-none-any.whl", hash = "sha256:defbd79327ca2fa61fd7b9eb2f1627dfb1f69c25d49288c52e167ddb84574f80"},
+    {file = "anthropic-0.50.0.tar.gz", hash = "sha256:42175ec04ce4ff2fa37cd436710206aadff546ee99d70d974699f59b49adc66f"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.25.0,<1"
+jiter = ">=0.4.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.10,<5"
+
+[package.extras]
+bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
+vertex = ["google-auth[requests] (>=2,<3)"]
+
 [[package]]
 name = "anyio"
 version = "4.9.0"

@@ -6804,4 +6829,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = "^3.11,<3.12"
-content-hash = "b8883bdb02803bdb77dfe2de47aca0a28b509720513bf2d7fc2ee001bedf05fb"
+content-hash = "926815050df2b2d2fbdb96ac5084cb0e19a628a04d29cbc78ea63936e11b213c"


@@ -55,6 +55,7 @@ pypdf = "^5.1.0"
 fastmcp = "^0.4.1"
 psutil = ">=7.0.0"
 tiktoken = ">=0.9.0"
+anthropic = "^0.50.0"
 
 [tool.poetry.group.dev.dependencies]
 isort = "^5.13.2"


@@ -81,7 +81,7 @@ class SkyvernClient:
         run_id: str,
     ) -> RunResponse:
         run_obj = await self.client.agent.get_run(run_id=run_id)
-        if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua]:
+        if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
             return TaskRunResponse.model_validate(run_obj.dict())
         elif run_obj.run_type == RunType.workflow_run:
             return WorkflowRunResponse.model_validate(run_obj.dict())


@@ -119,6 +119,7 @@ class Settings(BaseSettings):
     # LLM PROVIDER SPECIFIC
     ENABLE_OPENAI: bool = False
     ENABLE_ANTHROPIC: bool = False
+    ENABLE_BEDROCK_ANTHROPIC: bool = False
     ENABLE_AZURE: bool = False
     ENABLE_AZURE_GPT4O_MINI: bool = False
     ENABLE_AZURE_O3_MINI: bool = False


@@ -672,3 +672,8 @@ class SkyvernContextWindowExceededError(SkyvernException):
     def __init__(self) -> None:
         message = "Context window exceeded. Please contact support@skyvern.com for help."
         super().__init__(message)
+
+
+class LLMCallerNotFoundError(SkyvernException):
+    def __init__(self, uid: str) -> None:
+        super().__init__(f"LLM caller for {uid} is not found")


@@ -58,6 +58,7 @@ from skyvern.forge.sdk.api.files import (
     rename_file,
     wait_for_download_finished,
 )
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers

@@ -70,7 +71,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
 from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
 from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
 from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
-from skyvern.schemas.runs import RunEngine, RunType
+from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine
 from skyvern.utils.prompt_engine import load_prompt_with_elements
 from skyvern.webeye.actions.actions import (
     Action,

@@ -88,7 +89,7 @@ from skyvern.webeye.actions.actions import (
 from skyvern.webeye.actions.caching import retrieve_action_plan
 from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
 from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
-from skyvern.webeye.actions.parse_actions import parse_actions, parse_cua_actions
+from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
 from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
 from skyvern.webeye.browser_factory import BrowserState
 from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website

@@ -253,6 +254,7 @@ class ForgeAgent:
         complete_verification: bool = True,
         engine: RunEngine = RunEngine.skyvern_v1,
         cua_response: OpenAIResponse | None = None,
+        llm_caller: LLMCaller | None = None,
     ) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
         workflow_run: WorkflowRun | None = None
         if task.workflow_run_id:

@@ -378,6 +380,13 @@
             if page := await browser_state.get_working_page():
                 await self.register_async_operations(organization, task, page)
+            llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
+            if engine == RunEngine.anthropic_cua and not llm_caller:
+                # llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
+                llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
+                if not llm_caller:
+                    llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
+                    LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
             step, detailed_output = await self.agent_step(
                 task,
                 step,

@@ -387,6 +396,7 @@
                 complete_verification=complete_verification,
                 engine=engine,
                 cua_response=cua_response,
+                llm_caller=llm_caller,
             )
             await app.AGENT_FUNCTION.post_step_execution(task, step)
             task = await self.update_task_errors_from_detailed_output(task, detailed_output)

@@ -778,6 +788,7 @@
         task_block: BaseTaskBlock | None = None,
         complete_verification: bool = True,
         cua_response: OpenAIResponse | None = None,
+        llm_caller: LLMCaller | None = None,
     ) -> tuple[Step, DetailedAgentStepOutput]:
         detailed_agent_step_output = DetailedAgentStepOutput(
             scraped_page=None,

@@ -821,8 +832,17 @@
                 step=step,
                 scraped_page=scraped_page,
                 previous_response=cua_response,
+                engine=engine,
             )
             detailed_agent_step_output.cua_response = new_cua_response
+        elif engine == RunEngine.anthropic_cua:
+            assert llm_caller is not None
+            actions = await self._generate_anthropic_actions(
+                task=task,
+                step=step,
+                scraped_page=scraped_page,
+                llm_caller=llm_caller,
+            )
         else:
             using_cached_action_plan = False
             if not task.navigation_goal and not isinstance(task_block, ValidationBlock):

@@ -834,7 +854,7 @@
             ):
                 using_cached_action_plan = True
             else:
-                if engine != RunEngine.openai_cua:
+                if engine in CUA_ENGINES:
                     self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
                 json_response = await app.LLM_API_HANDLER(
                     prompt=extract_action_prompt,

@@ -1219,7 +1239,8 @@
         step: Step,
         scraped_page: ScrapedPage,
         previous_response: OpenAIResponse | None = None,
-    ) -> tuple[list[Action], OpenAIResponse]:
+        engine: RunEngine = RunEngine.openai_cua,
+    ) -> tuple[list[Action], OpenAIResponse | None]:
         if not previous_response:
             # this is the first step
             first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create(

@@ -1377,6 +1398,48 @@
         return await parse_cua_actions(task, step, current_response), current_response
 
+    async def _generate_anthropic_actions(
+        self,
+        task: Task,
+        step: Step,
+        scraped_page: ScrapedPage,
+        llm_caller: LLMCaller,
+    ) -> list[Action]:
+        if llm_caller.current_tool_results:
+            llm_caller.message_history.append({"role": "user", "content": llm_caller.current_tool_results})
+            llm_caller.clear_tool_results()
+        tools = [
+            {
+                "type": "computer_20250124",
+                "name": "computer",
+                "display_height_px": settings.BROWSER_HEIGHT,
+                "display_width_px": settings.BROWSER_WIDTH,
+            }
+        ]
+        if not llm_caller.message_history:
+            llm_response = await llm_caller.call(
+                prompt=task.navigation_goal,
+                screenshots=scraped_page.screenshots,
+                use_message_history=True,
+                tools=tools,
+                raw_response=True,
+                betas=["computer-use-2025-01-24"],
+            )
+        else:
+            llm_response = await llm_caller.call(
+                screenshots=scraped_page.screenshots,
+                use_message_history=True,
+                tools=tools,
+                raw_response=True,
+                betas=["computer-use-2025-01-24"],
+            )
+        LOG.info("Anthropic response", llm_response=llm_response)
+        assistant_content = llm_response["content"]
+        llm_caller.message_history.append({"role": "assistant", "content": assistant_content})
+        actions = await parse_anthropic_actions(task, step, assistant_content)
+        return actions
+
     @staticmethod
     async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
         LOG.info(

@@ -1387,7 +1450,7 @@
         )
         run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
         scroll = True
-        if run_obj and run_obj.task_run_type == RunType.openai_cua:
+        if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
             scroll = False
         scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)

@@ -1454,7 +1517,7 @@
             raise BrowserStateMissingPage()
 
         fullpage_screenshot = True
-        if engine == RunEngine.openai_cua:
+        if engine in CUA_ENGINES:
             fullpage_screenshot = False
         try:

@@ -1580,7 +1643,7 @@
         max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
         draw_boxes = True
         scroll = True
-        if engine == RunEngine.openai_cua:
+        if engine in CUA_ENGINES:
             max_screenshot_number = 1
             draw_boxes = False
             scroll = False

@@ -1602,7 +1665,7 @@
         engine: RunEngine,
     ) -> tuple[ScrapedPage, str]:
         # start the async tasks while running scrape_website
-        if engine != RunEngine.openai_cua:
+        if engine not in CUA_ENGINES:
             self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
 
         # Scrape the web page and get the screenshot and the elements

@@ -1653,7 +1716,7 @@
             element_tree_format = ElementTreeFormat.HTML
             element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
 
         extract_action_prompt = ""
-        if engine != RunEngine.openai_cua:
+        if engine not in CUA_ENGINES:
             extract_action_prompt = await self._build_extract_action_prompt(
                 task,
                 step,

@@ -2371,7 +2434,7 @@
         run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
         scroll = True
-        if run_obj and run_obj.task_run_type == RunType.openai_cua:
+        if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
             scroll = False
         screenshots: list[bytes] = []
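
Taken together, these hunks give each Anthropic CUA task one persistent LLMCaller whose message history threads through the steps: tool results from the previous step are flushed back as a user turn, the model is called with a fresh screenshot, and the assistant content is appended verbatim. Below is a self-contained sketch of that loop, with a stubbed model call standing in for LLMCaller.call(..., raw_response=True); the stub and its canned tool_use block are illustrative only, not part of this diff.

    from typing import Any

    class MiniLLMCaller:
        # Mirrors the two pieces of LLMCaller state this diff adds.
        def __init__(self) -> None:
            self.message_history: list[dict[str, Any]] = []
            self.current_tool_results: list[dict[str, Any]] = []

    def fake_model(messages: list[dict[str, Any]]) -> dict[str, Any]:
        # Stand-in for the raw Anthropic response; real content blocks carry
        # the same {"type": "tool_use", "id", "name", "input"} shape.
        return {"content": [{"type": "tool_use", "id": "toolu_1", "name": "computer",
                             "input": {"action": "left_click", "coordinate": [10, 20]}}]}

    def run_one_step(caller: MiniLLMCaller, goal: str | None) -> None:
        # 1. Flush the previous step's tool results back as a user turn.
        if caller.current_tool_results:
            caller.message_history.append({"role": "user", "content": caller.current_tool_results})
            caller.current_tool_results = []
        # 2. Only the first step carries the navigation goal as a prompt.
        if not caller.message_history and goal:
            caller.message_history.append({"role": "user", "content": goal})
        response = fake_model(caller.message_history)
        # 3. Keep the assistant content verbatim so tool_use ids stay pairable.
        caller.message_history.append({"role": "assistant", "content": response["content"]})
        # 4. Action execution later records one tool_result per tool_call_id.
        caller.current_tool_results.append({"type": "tool_result", "tool_use_id": "toolu_1",
                                            "content": {"result": "Tool executed successfully"}})

    caller = MiniLLMCaller()
    run_one_step(caller, "Find the checkout button")
    run_one_step(caller, None)
    assert len(caller.message_history) == 4  # goal, assistant, tool results, assistant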


@@ -1,5 +1,6 @@
 from typing import Awaitable, Callable
 
+from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
 from fastapi import FastAPI
 from openai import AsyncAzureOpenAI, AsyncOpenAI

@@ -41,6 +42,9 @@ if SettingsManager.get_settings().ENABLE_AZURE_CUA:
         azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
         azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
     )
+ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY)
+if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
+    ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
 
 SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
     SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY


@@ -6,7 +6,9 @@ from typing import Any
 import litellm
 import structlog
+from anthropic.types.message import Message as AnthropicMessage
 from jinja2 import Template
+from litellm.utils import CustomStreamWrapper, ModelResponse
 
 from skyvern.config import settings
 from skyvern.exceptions import SkyvernContextWindowExceededError

@@ -456,11 +458,18 @@
         self.llm_config = LLMConfigRegistry.get_config(llm_key)
         self.base_parameters = base_parameters
         self.message_history: list[dict[str, Any]] = []
+        self.current_tool_results: list[dict[str, Any]] = []
+
+    def add_tool_result(self, tool_result: dict[str, Any]) -> None:
+        self.current_tool_results.append(tool_result)
+
+    def clear_tool_results(self) -> None:
+        self.current_tool_results = []
 
     async def call(
         self,
-        prompt: str,
-        prompt_name: str,
+        prompt: str | None = None,
+        prompt_name: str | None = None,
         step: Step | None = None,
         task_v2: TaskV2 | None = None,
         thought: Thought | None = None,

@@ -469,6 +478,8 @@
         parameters: dict[str, Any] | None = None,
         tools: list | None = None,
         use_message_history: bool = False,
+        raw_response: bool = False,
+        **extra_parameters: Any,
     ) -> dict[str, Any]:
         start_time = time.perf_counter()
         active_parameters = self.base_parameters or {}

@@ -476,6 +487,8 @@
             parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
 
         active_parameters.update(parameters)
+        if extra_parameters:
+            active_parameters.update(extra_parameters)
 
         if self.llm_config.litellm_params:  # type: ignore
             active_parameters.update(self.llm_config.litellm_params)  # type: ignore

@@ -491,7 +504,7 @@
         )
 
         await app.ARTIFACT_MANAGER.create_llm_artifact(
-            data=prompt.encode("utf-8"),
+            data=prompt.encode("utf-8") if prompt else b"",
             artifact_type=ArtifactType.LLM_PROMPT,
             screenshots=screenshots,
             step=step,

@@ -525,8 +538,7 @@
         )
         t_llm_request = time.perf_counter()
         try:
-            response = await litellm.acompletion(
-                model=self.llm_config.model_name,
+            response = await self._dispatch_llm_call(
                 messages=messages,
                 tools=tools,
                 timeout=settings.LLM_CONFIG_TIMEOUT,

@@ -603,6 +615,21 @@
             cached_token_count=cached_tokens if cached_tokens > 0 else None,
             thought_cost=llm_cost,
         )
+
+        # Track LLM API handler duration
+        duration_seconds = time.perf_counter() - start_time
+        LOG.info(
+            "LLM API handler duration metrics",
+            llm_key=self.llm_key,
+            prompt_name=prompt_name,
+            model=self.llm_config.model_name,
+            duration_seconds=duration_seconds,
+            step_id=step.step_id if step else None,
+            thought_id=thought.observer_thought_id if thought else None,
+            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
+        )
+        if raw_response:
+            return response.model_dump()
+
         parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
         await app.ARTIFACT_MANAGER.create_llm_artifact(
             data=json.dumps(parsed_response, indent=2).encode("utf-8"),

@@ -626,17 +653,53 @@
             ai_suggestion=ai_suggestion,
         )
-        # Track LLM API handler duration
-        duration_seconds = time.perf_counter() - start_time
-        LOG.info(
-            "LLM API handler duration metrics",
-            llm_key=self.llm_key,
-            prompt_name=prompt_name,
-            model=self.llm_config.model_name,
-            duration_seconds=duration_seconds,
-            step_id=step.step_id if step else None,
-            thought_id=thought.observer_thought_id if thought else None,
-            organization_id=step.organization_id if step else (thought.organization_id if thought else None),
-        )
         return parsed_response
+
+    async def _dispatch_llm_call(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
+        if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
+            return await self._call_anthropic(messages, tools, timeout)
+        return await litellm.acompletion(
+            model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
+        )
+
+    async def _call_anthropic(
+        self,
+        messages: list[dict[str, Any]],
+        tools: list | None = None,
+        timeout: float = settings.LLM_CONFIG_TIMEOUT,
+        **active_parameters: dict[str, Any],
+    ) -> AnthropicMessage:
+        max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
+        model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")
+        return await app.ANTHROPIC_CLIENT.messages.create(
+            max_tokens=max_tokens,
+            messages=messages,
+            model=model_name,
+            tools=tools,
+            timeout=timeout,
+            betas=active_parameters.get("betas", None),
+        )
+
+
+class LLMCallerManager:
+    _llm_callers: dict[str, LLMCaller] = {}
+
+    @classmethod
+    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
+        return cls._llm_callers.get(uid)
+
+    @classmethod
+    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
+        cls._llm_callers[uid] = llm_caller
+
+    @classmethod
+    def clear_llm_caller(cls, uid: str) -> None:
+        if uid in cls._llm_callers:
+            del cls._llm_callers[uid]
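
LLMCallerManager is a plain process-local registry keyed by an arbitrary uid (the task id at the call sites in this diff). A minimal usage sketch, assuming the ANTHROPIC_CLAUDE3.5_SONNET config is registered; the try/finally teardown is an assumption, since this commit defines clear_llm_caller but never calls it:

    from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager

    task_id = "tsk_123"  # hypothetical task id
    llm_caller = LLMCallerManager.get_llm_caller(task_id)
    if llm_caller is None:
        llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
        LLMCallerManager.set_llm_caller(task_id, llm_caller)
    try:
        pass  # steps run; ActionHandler re-fetches the caller by task_id per action
    finally:
        LLMCallerManager.clear_llm_caller(task_id)  # assumed cleanup so callers don't leak across tasks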


@@ -47,19 +47,22 @@ async def llm_messages_builder(
 async def llm_messages_builder_with_history(
-    prompt: str,
+    prompt: str | None = None,
     screenshots: list[bytes] | None = None,
     message_history: list[dict[str, Any]] | None = None,
 ) -> list[dict[str, Any]]:
     messages: list[dict[str, Any]] = []
     if message_history:
         messages = copy.deepcopy(message_history)
-    current_user_messages: list[dict[str, Any]] = [
-        {
-            "type": "text",
-            "text": prompt,
-        }
-    ]
+
+    current_user_messages: list[dict[str, Any]] = []
+    if prompt:
+        current_user_messages.append(
+            {
+                "type": "text",
+                "text": prompt,
+            }
+        )
     if screenshots:
         for screenshot in screenshots:


@@ -96,6 +96,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
         engine = RunEngine.skyvern_v1
         if run_obj and run_obj.task_run_type == RunType.openai_cua:
             engine = RunEngine.openai_cua
+        elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
+            engine = RunEngine.anthropic_cua
 
         context: SkyvernContext = skyvern_context.ensure_context()
         context.task_id = task.task_id


@@ -62,6 +62,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
 )
 from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
 from skyvern.schemas.runs import (
+    CUA_ENGINES,
     RunEngine,
     RunResponse,
     RunType,

@@ -1466,7 +1467,7 @@ async def run_task(
     analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
     await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
 
-    if run_request.engine in [RunEngine.skyvern_v1, RunEngine.openai_cua]:
+    if run_request.engine in CUA_ENGINES:
         # create task v1
         # if there's no url, call task generation first to generate the url, data schema if any
         url = run_request.url

@@ -1480,7 +1481,7 @@
             )
             url = url or task_generation.url
             navigation_goal = task_generation.navigation_goal or run_request.prompt
-            if run_request.engine == RunEngine.openai_cua:
+            if run_request.engine in CUA_ENGINES:
                 navigation_goal = run_request.prompt
             navigation_payload = task_generation.navigation_payload
             data_extraction_goal = task_generation.data_extraction_goal

@@ -1511,6 +1512,8 @@
         run_type = RunType.task_v1
         if run_request.engine == RunEngine.openai_cua:
             run_type = RunType.openai_cua
+        elif run_request.engine == RunEngine.anthropic_cua:
+            run_type = RunType.anthropic_cua
 
         # build the task run response
         return TaskRunResponse(
             run_id=task_v1_response.task_id,

@@ -1586,8 +1589,6 @@
                 publish_workflow=run_request.publish_workflow,
             ),
         )
-    if run_request.engine == RunEngine.openai_cua:
-        pass
 
     raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")


@@ -93,12 +93,18 @@ class RunType(StrEnum):
     task_v2 = "task_v2"
     workflow_run = "workflow_run"
     openai_cua = "openai_cua"
+    anthropic_cua = "anthropic_cua"
 
 
 class RunEngine(StrEnum):
     skyvern_v1 = "skyvern-1.0"
     skyvern_v2 = "skyvern-2.0"
     openai_cua = "openai-cua"
+    anthropic_cua = "anthropic-cua"
+
+
+CUA_ENGINES = [RunEngine.openai_cua, RunEngine.anthropic_cua]
+CUA_RUN_TYPES = [RunType.openai_cua, RunType.anthropic_cua]
 
 
 class RunStatus(StrEnum):

@@ -217,8 +223,8 @@ class BaseRunResponse(BaseModel):
 class TaskRunResponse(BaseRunResponse):
-    run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua] = Field(
-        description="Types of a task run - task_v1, task_v2, openai_cua"
+    run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua] = Field(
+        description="Types of a task run - task_v1, task_v2, openai_cua, anthropic_cua"
     )
     run_request: TaskRunRequest | None = Field(
         default=None, description="The original request parameters used to start this task run"


@@ -13,7 +13,11 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
     if not run:
         return None
-    if run.task_run_type == RunType.task_v1 or run.task_run_type == RunType.openai_cua:
+    if (
+        run.task_run_type == RunType.task_v1
+        or run.task_run_type == RunType.openai_cua
+        or run.task_run_type == RunType.anthropic_cua
+    ):
         # fetch task v1 from db and transform to task run response
         task_v1 = await app.DATABASE.get_task(run.run_id, organization_id=organization_id)
         if not task_v1:

@@ -21,6 +25,8 @@
         run_engine = RunEngine.skyvern_v1
         if run.task_run_type == RunType.openai_cua:
             run_engine = RunEngine.openai_cua
+        elif run.task_run_type == RunType.anthropic_cua:
+            run_engine = RunEngine.anthropic_cua
         return TaskRunResponse(
             run_id=run.run_id,
             run_type=run.task_run_type,

@@ -136,7 +142,7 @@ async def cancel_run(run_id: str, organization_id: str | None = None, api_key: s
             detail=f"Run not found {run_id}",
         )
 
-    if run.task_run_type in [RunType.task_v1, RunType.openai_cua]:
+    if run.task_run_type in [RunType.task_v1, RunType.openai_cua, RunType.anthropic_cua]:
         await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
     elif run.task_run_type == RunType.task_v2:
         await cancel_task_v2(run_id, organization_id=organization_id)


@@ -87,6 +87,8 @@ async def run_task(
     run_type = RunType.task_v1
     if engine == RunEngine.openai_cua:
         run_type = RunType.openai_cua
+    elif engine == RunEngine.anthropic_cua:
+        run_type = RunType.anthropic_cua
 
     await app.DATABASE.create_task_run(
         task_run_type=run_type,
         organization_id=organization.organization_id,


@@ -113,6 +113,7 @@ class Action(BaseModel):
     element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None
     skyvern_element_hash: str | None = None
     skyvern_element_data: dict[str, Any] | None = None
+    tool_call_id: str | None = None
 
     # DecisiveAction (CompleteAction, TerminateAction) fields
     errors: list[UserDefinedError] | None = None


@@ -59,6 +59,7 @@ from skyvern.forge.sdk.api.files import (
     list_files_in_directory,
     wait_for_download_finished,
 )
+from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCallerManager
 from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
 from skyvern.forge.sdk.core import skyvern_context
 from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post

@@ -363,9 +364,26 @@
             handler = ActionHandler._handled_action_types[action.action_type]
             results = await handler(action, page, scraped_page, task, step)
             actions_result.extend(results)
+
+            llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
             if not results or not isinstance(actions_result[-1], ActionSuccess):
+                if llm_caller and action.tool_call_id:
+                    # add failure message to the llm caller
+                    tool_call_result = {
+                        "type": "tool_result",
+                        "tool_use_id": action.tool_call_id,
+                        "content": {"result": "Tool execution failed"},
+                    }
+                    llm_caller.add_tool_result(tool_call_result)
                 return actions_result
 
+            if llm_caller and action.tool_call_id:
+                tool_call_result = {
+                    "type": "tool_result",
+                    "tool_use_id": action.tool_call_id,
+                    "content": {"result": "Tool executed successfully"},
+                }
+                llm_caller.add_tool_result(tool_call_result)
+
             # do the teardown
             teardown = ActionHandler._teardown_action_types.get(action.action_type)
             if teardown:

@@ -1532,7 +1550,7 @@
 ) -> list[ActionResult]:
     updated_keys = []
     for key in action.keys:
-        if key.lower() == "enter":
+        if key.lower() in ("enter", "return"):
             updated_keys.append("Enter")
         elif key.lower() == "space":
             updated_keys.append(" ")
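
One detail worth noting in the success/failure branches above: a tool_result is recorded on both paths because the Anthropic Messages API expects every tool_use block in an assistant turn to be answered by a matching tool_result in the next user turn. A small self-contained check of that pairing, using the block shapes from this diff (the id is made up):

    # Shapes copied from this diff; a real run fills ids from the model output.
    assistant_turn = [
        {"type": "tool_use", "id": "toolu_1", "name": "computer",
         "input": {"action": "left_click", "coordinate": [100, 200]}},
    ]
    tool_results = [
        {"type": "tool_result", "tool_use_id": "toolu_1",
         "content": {"result": "Tool execution failed"}},
    ]
    answered = {result["tool_use_id"] for result in tool_results}
    unanswered = [b for b in assistant_turn if b["type"] == "tool_use" and b["id"] not in answered]
    assert not unanswered  # every tool_use must get a tool_result, even on failure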


@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict
 
 import structlog

@@ -448,3 +449,133 @@
             action.action_order = 0
             return [action]
     return actions
+
+
+async def parse_anthropic_actions(
+    task: Task,
+    step: Step,
+    assistant_content: list[dict[str, Any]],
+) -> list[Action]:
+    tool_calls = [block for block in assistant_content if block["type"] == "tool_use"]
+    idx = 0
+    actions: list[Action] = []
+    while idx < len(tool_calls):
+        tool_call = tool_calls[idx]
+        tool_call_id = tool_call["id"]
+        parsed_args = _parse_anthropic_computer_args(tool_call)
+        if not parsed_args:
+            idx += 1
+            continue
+        action = parsed_args["action"]
+        if action == "mouse_move":
+            x, y = parsed_args["coordinate"]
+            actions.append(
+                MoveAction(
+                    x=x,
+                    y=y,
+                    organization_id=task.organization_id,
+                    workflow_run_id=task.workflow_run_id,
+                    task_id=task.task_id,
+                    step_id=step.step_id,
+                    step_order=step.order,
+                    action_order=idx,
+                    tool_call_id=tool_call_id,
+                )
+            )
+            idx += 1
+        elif action == "left_click":
+            if idx - 1 >= 0:
+                prev_tool_call = tool_calls[idx - 1]
+                prev_parsed_args = _parse_anthropic_computer_args(prev_tool_call)
+                if prev_parsed_args and prev_parsed_args["action"] == "mouse_move":
+                    coordinate = prev_parsed_args["coordinate"]
+                else:
+                    coordinate = parsed_args.get("coordinate")
+            else:
+                coordinate = parsed_args.get("coordinate")
+            idx += 1
+            if not coordinate:
+                LOG.warning(
+                    "Left click action has no coordinate and it doesn't have mouse_move before it",
+                    tool_call=tool_call,
+                )
+                continue
+            x, y = coordinate
+            actions.append(
+                ClickAction(
+                    element_id="",
+                    x=x,
+                    y=y,
+                    button="left",
+                    organization_id=task.organization_id,
+                    workflow_run_id=task.workflow_run_id,
+                    task_id=task.task_id,
+                    step_id=step.step_id,
+                    step_order=step.order,
+                    action_order=idx - 1,
+                    tool_call_id=tool_call_id,
+                )
+            )
+        elif action == "type":
+            text = parsed_args.get("text")
+            idx += 1
+            if not text:
+                LOG.warning(
+                    "Type action has no text",
+                    tool_call=tool_call,
+                )
+                continue
+            actions.append(
+                InputTextAction(
+                    element_id="",
+                    text=text,
+                    organization_id=task.organization_id,
+                    workflow_run_id=task.workflow_run_id,
+                    task_id=task.task_id,
+                    step_id=step.step_id,
+                    step_order=step.order,
+                    action_order=idx,
+                    tool_call_id=tool_call_id,
+                )
+            )
+        elif action == "key":
+            text = parsed_args.get("text")
+            idx += 1
+            if not text:
+                LOG.warning(
+                    "Key action has no text",
+                    tool_call=tool_call,
+                )
+                continue
+            actions.append(
+                KeypressAction(
+                    element_id="",
+                    keys=[text],
+                    organization_id=task.organization_id,
+                    workflow_run_id=task.workflow_run_id,
+                    task_id=task.task_id,
+                    step_id=step.step_id,
+                    step_order=step.order,
+                    action_order=idx,
+                    tool_call_id=tool_call_id,
+                )
+            )
+        else:
+            LOG.error(
+                "Unsupported action",
+                tool_call=tool_call,
+            )
+            idx += 1
+    return actions
+
+
+def _parse_anthropic_computer_args(tool_call: dict[str, Any]) -> dict[str, Any] | None:
+    tool_call_type = tool_call["type"]
+    if tool_call_type != "function":
+        return None
+    tool_call_name = tool_call["function"]["name"]
+    if tool_call_name != "computer":
+        return None
+    tool_call_arguments = tool_call["function"]["arguments"]
+    return json.loads(tool_call_arguments)
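
For reference: _parse_anthropic_computer_args only decodes blocks carrying an OpenAI-style "function" wrapper with JSON-string arguments, while the outer loop in parse_anthropic_actions selects blocks of type "tool_use". The sketch below (with a made-up id) exercises just the argument-decoding path the helper implements:

    import json

    # The shape _parse_anthropic_computer_args accepts; the id is hypothetical
    # and is what the parser carries onto Action.tool_call_id.
    tool_call = {
        "type": "function",
        "id": "toolu_xyz",
        "function": {
            "name": "computer",
            "arguments": json.dumps({"action": "left_click", "coordinate": [640, 360]}),
        },
    }

    args = json.loads(tool_call["function"]["arguments"])
    assert args["action"] == "left_click"
    x, y = args["coordinate"]
    assert (x, y) == (640, 360)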


@@ -18,6 +18,7 @@ class ActionResult(BaseModel):
     interacted_with_sibling: bool | None = None
     interacted_with_parent: bool | None = None
     skip_remaining_actions: bool | None = None
+    tool_call_result: dict[str, Any] | None = None
 
     def __str__(self) -> str:
         results = [f"ActionResult(success={self.success}"]