mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-15 17:59:42 +00:00
anthropic CUA (#2231)
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
5582998490
commit
0a0228b341
18 changed files with 378 additions and 45 deletions
27
poetry.lock
generated
27
poetry.lock
generated
|
@ -270,6 +270,31 @@ files = [
|
||||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "anthropic"
|
||||||
|
version = "0.50.0"
|
||||||
|
description = "The official Python library for the anthropic API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "anthropic-0.50.0-py3-none-any.whl", hash = "sha256:defbd79327ca2fa61fd7b9eb2f1627dfb1f69c25d49288c52e167ddb84574f80"},
|
||||||
|
{file = "anthropic-0.50.0.tar.gz", hash = "sha256:42175ec04ce4ff2fa37cd436710206aadff546ee99d70d974699f59b49adc66f"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.5.0,<5"
|
||||||
|
distro = ">=1.7.0,<2"
|
||||||
|
httpx = ">=0.25.0,<1"
|
||||||
|
jiter = ">=0.4.0,<1"
|
||||||
|
pydantic = ">=1.9.0,<3"
|
||||||
|
sniffio = "*"
|
||||||
|
typing-extensions = ">=4.10,<5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"]
|
||||||
|
vertex = ["google-auth[requests] (>=2,<3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.9.0"
|
version = "4.9.0"
|
||||||
|
@ -6804,4 +6829,4 @@ type = ["pytest-mypy"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.11,<3.12"
|
python-versions = "^3.11,<3.12"
|
||||||
content-hash = "b8883bdb02803bdb77dfe2de47aca0a28b509720513bf2d7fc2ee001bedf05fb"
|
content-hash = "926815050df2b2d2fbdb96ac5084cb0e19a628a04d29cbc78ea63936e11b213c"
|
||||||
|
|
|
@ -55,6 +55,7 @@ pypdf = "^5.1.0"
|
||||||
fastmcp = "^0.4.1"
|
fastmcp = "^0.4.1"
|
||||||
psutil = ">=7.0.0"
|
psutil = ">=7.0.0"
|
||||||
tiktoken = ">=0.9.0"
|
tiktoken = ">=0.9.0"
|
||||||
|
anthropic = "^0.50.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
isort = "^5.13.2"
|
isort = "^5.13.2"
|
||||||
|
|
|
@ -81,7 +81,7 @@ class SkyvernClient:
|
||||||
run_id: str,
|
run_id: str,
|
||||||
) -> RunResponse:
|
) -> RunResponse:
|
||||||
run_obj = await self.client.agent.get_run(run_id=run_id)
|
run_obj = await self.client.agent.get_run(run_id=run_id)
|
||||||
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua]:
|
if run_obj.run_type in [RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua]:
|
||||||
return TaskRunResponse.model_validate(run_obj.dict())
|
return TaskRunResponse.model_validate(run_obj.dict())
|
||||||
elif run_obj.run_type == RunType.workflow_run:
|
elif run_obj.run_type == RunType.workflow_run:
|
||||||
return WorkflowRunResponse.model_validate(run_obj.dict())
|
return WorkflowRunResponse.model_validate(run_obj.dict())
|
||||||
|
|
|
@ -119,6 +119,7 @@ class Settings(BaseSettings):
|
||||||
# LLM PROVIDER SPECIFIC
|
# LLM PROVIDER SPECIFIC
|
||||||
ENABLE_OPENAI: bool = False
|
ENABLE_OPENAI: bool = False
|
||||||
ENABLE_ANTHROPIC: bool = False
|
ENABLE_ANTHROPIC: bool = False
|
||||||
|
ENABLE_BEDROCK_ANTHROPIC: bool = False
|
||||||
ENABLE_AZURE: bool = False
|
ENABLE_AZURE: bool = False
|
||||||
ENABLE_AZURE_GPT4O_MINI: bool = False
|
ENABLE_AZURE_GPT4O_MINI: bool = False
|
||||||
ENABLE_AZURE_O3_MINI: bool = False
|
ENABLE_AZURE_O3_MINI: bool = False
|
||||||
|
|
|
@ -672,3 +672,8 @@ class SkyvernContextWindowExceededError(SkyvernException):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
message = "Context window exceeded. Please contact support@skyvern.com for help."
|
message = "Context window exceeded. Please contact support@skyvern.com for help."
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCallerNotFoundError(SkyvernException):
    """Error reporting that the LLM caller for an identifier could not be found."""

    def __init__(self, uid: str) -> None:
        """Build the error for ``uid``, the identifier whose LLM caller is missing."""
        # Compose the message first, as the neighbouring exception classes do.
        message = f"LLM caller for {uid} is not found"
        super().__init__(message)
|
||||||
|
|
|
@ -58,6 +58,7 @@ from skyvern.forge.sdk.api.files import (
|
||||||
rename_file,
|
rename_file,
|
||||||
wait_for_download_finished,
|
wait_for_download_finished,
|
||||||
)
|
)
|
||||||
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCaller, LLMCallerManager
|
||||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
from skyvern.forge.sdk.core.security import generate_skyvern_webhook_headers
|
||||||
|
@ -70,7 +71,7 @@ from skyvern.forge.sdk.schemas.tasks import Task, TaskRequest, TaskResponse, Tas
|
||||||
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
from skyvern.forge.sdk.workflow.context_manager import WorkflowRunContext
|
||||||
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
|
from skyvern.forge.sdk.workflow.models.block import ActionBlock, BaseTaskBlock, ValidationBlock
|
||||||
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
|
from skyvern.forge.sdk.workflow.models.workflow import Workflow, WorkflowRun, WorkflowRunStatus
|
||||||
from skyvern.schemas.runs import RunEngine, RunType
|
from skyvern.schemas.runs import CUA_ENGINES, CUA_RUN_TYPES, RunEngine
|
||||||
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
from skyvern.utils.prompt_engine import load_prompt_with_elements
|
||||||
from skyvern.webeye.actions.actions import (
|
from skyvern.webeye.actions.actions import (
|
||||||
Action,
|
Action,
|
||||||
|
@ -88,7 +89,7 @@ from skyvern.webeye.actions.actions import (
|
||||||
from skyvern.webeye.actions.caching import retrieve_action_plan
|
from skyvern.webeye.actions.caching import retrieve_action_plan
|
||||||
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
|
from skyvern.webeye.actions.handler import ActionHandler, poll_verification_code
|
||||||
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
|
from skyvern.webeye.actions.models import AgentStepOutput, DetailedAgentStepOutput
|
||||||
from skyvern.webeye.actions.parse_actions import parse_actions, parse_cua_actions
|
from skyvern.webeye.actions.parse_actions import parse_actions, parse_anthropic_actions, parse_cua_actions
|
||||||
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
|
from skyvern.webeye.actions.responses import ActionResult, ActionSuccess
|
||||||
from skyvern.webeye.browser_factory import BrowserState
|
from skyvern.webeye.browser_factory import BrowserState
|
||||||
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
from skyvern.webeye.scraper.scraper import ElementTreeFormat, ScrapedPage, scrape_website
|
||||||
|
@ -253,6 +254,7 @@ class ForgeAgent:
|
||||||
complete_verification: bool = True,
|
complete_verification: bool = True,
|
||||||
engine: RunEngine = RunEngine.skyvern_v1,
|
engine: RunEngine = RunEngine.skyvern_v1,
|
||||||
cua_response: OpenAIResponse | None = None,
|
cua_response: OpenAIResponse | None = None,
|
||||||
|
llm_caller: LLMCaller | None = None,
|
||||||
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
) -> Tuple[Step, DetailedAgentStepOutput | None, Step | None]:
|
||||||
workflow_run: WorkflowRun | None = None
|
workflow_run: WorkflowRun | None = None
|
||||||
if task.workflow_run_id:
|
if task.workflow_run_id:
|
||||||
|
@ -378,6 +380,13 @@ class ForgeAgent:
|
||||||
if page := await browser_state.get_working_page():
|
if page := await browser_state.get_working_page():
|
||||||
await self.register_async_operations(organization, task, page)
|
await self.register_async_operations(organization, task, page)
|
||||||
|
|
||||||
|
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||||
|
if engine == RunEngine.anthropic_cua and not llm_caller:
|
||||||
|
# llm_caller = LLMCaller(llm_key="BEDROCK_ANTHROPIC_CLAUDE3.5_SONNET_INFERENCE_PROFILE")
|
||||||
|
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||||
|
if not llm_caller:
|
||||||
|
llm_caller = LLMCaller(llm_key="ANTHROPIC_CLAUDE3.5_SONNET")
|
||||||
|
LLMCallerManager.set_llm_caller(task.task_id, llm_caller)
|
||||||
step, detailed_output = await self.agent_step(
|
step, detailed_output = await self.agent_step(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
|
@ -387,6 +396,7 @@ class ForgeAgent:
|
||||||
complete_verification=complete_verification,
|
complete_verification=complete_verification,
|
||||||
engine=engine,
|
engine=engine,
|
||||||
cua_response=cua_response,
|
cua_response=cua_response,
|
||||||
|
llm_caller=llm_caller,
|
||||||
)
|
)
|
||||||
await app.AGENT_FUNCTION.post_step_execution(task, step)
|
await app.AGENT_FUNCTION.post_step_execution(task, step)
|
||||||
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
|
task = await self.update_task_errors_from_detailed_output(task, detailed_output)
|
||||||
|
@ -778,6 +788,7 @@ class ForgeAgent:
|
||||||
task_block: BaseTaskBlock | None = None,
|
task_block: BaseTaskBlock | None = None,
|
||||||
complete_verification: bool = True,
|
complete_verification: bool = True,
|
||||||
cua_response: OpenAIResponse | None = None,
|
cua_response: OpenAIResponse | None = None,
|
||||||
|
llm_caller: LLMCaller | None = None,
|
||||||
) -> tuple[Step, DetailedAgentStepOutput]:
|
) -> tuple[Step, DetailedAgentStepOutput]:
|
||||||
detailed_agent_step_output = DetailedAgentStepOutput(
|
detailed_agent_step_output = DetailedAgentStepOutput(
|
||||||
scraped_page=None,
|
scraped_page=None,
|
||||||
|
@ -821,8 +832,17 @@ class ForgeAgent:
|
||||||
step=step,
|
step=step,
|
||||||
scraped_page=scraped_page,
|
scraped_page=scraped_page,
|
||||||
previous_response=cua_response,
|
previous_response=cua_response,
|
||||||
|
engine=engine,
|
||||||
)
|
)
|
||||||
detailed_agent_step_output.cua_response = new_cua_response
|
detailed_agent_step_output.cua_response = new_cua_response
|
||||||
|
elif engine == RunEngine.anthropic_cua:
|
||||||
|
assert llm_caller is not None
|
||||||
|
actions = await self._generate_anthropic_actions(
|
||||||
|
task=task,
|
||||||
|
step=step,
|
||||||
|
scraped_page=scraped_page,
|
||||||
|
llm_caller=llm_caller,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
using_cached_action_plan = False
|
using_cached_action_plan = False
|
||||||
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
if not task.navigation_goal and not isinstance(task_block, ValidationBlock):
|
||||||
|
@ -834,7 +854,7 @@ class ForgeAgent:
|
||||||
):
|
):
|
||||||
using_cached_action_plan = True
|
using_cached_action_plan = True
|
||||||
else:
|
else:
|
||||||
if engine != RunEngine.openai_cua:
|
if engine in CUA_ENGINES:
|
||||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.llm)
|
||||||
json_response = await app.LLM_API_HANDLER(
|
json_response = await app.LLM_API_HANDLER(
|
||||||
prompt=extract_action_prompt,
|
prompt=extract_action_prompt,
|
||||||
|
@ -1219,7 +1239,8 @@ class ForgeAgent:
|
||||||
step: Step,
|
step: Step,
|
||||||
scraped_page: ScrapedPage,
|
scraped_page: ScrapedPage,
|
||||||
previous_response: OpenAIResponse | None = None,
|
previous_response: OpenAIResponse | None = None,
|
||||||
) -> tuple[list[Action], OpenAIResponse]:
|
engine: RunEngine = RunEngine.openai_cua,
|
||||||
|
) -> tuple[list[Action], OpenAIResponse | None]:
|
||||||
if not previous_response:
|
if not previous_response:
|
||||||
# this is the first step
|
# this is the first step
|
||||||
first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create(
|
first_response: OpenAIResponse = await app.OPENAI_CLIENT.responses.create(
|
||||||
|
@ -1377,6 +1398,48 @@ class ForgeAgent:
|
||||||
|
|
||||||
return await parse_cua_actions(task, step, current_response), current_response
|
return await parse_cua_actions(task, step, current_response), current_response
|
||||||
|
|
||||||
|
async def _generate_anthropic_actions(
    self,
    task: Task,
    step: Step,
    scraped_page: ScrapedPage,
    llm_caller: LLMCaller,
) -> list[Action]:
    """Run one Anthropic computer-use turn and parse the reply into actions.

    Args:
        task: Task being executed; its ``navigation_goal`` is the prompt of
            the opening turn.
        step: Step the parsed actions are attributed to.
        scraped_page: Source of the screenshots sent with the request.
        llm_caller: Conversation holder; its ``message_history`` and
            ``current_tool_results`` are read and updated here.

    Returns:
        Whatever ``parse_anthropic_actions`` produces from the ``content``
        of the raw response.
    """
    # Tool results collected since the previous turn are replayed to the
    # model as a user message, then the pending buffer is reset.
    pending_results = llm_caller.current_tool_results
    if pending_results:
        llm_caller.message_history.append({"role": "user", "content": pending_results})
        llm_caller.clear_tool_results()

    computer_tool = {
        "type": "computer_20250124",
        "name": "computer",
        "display_height_px": settings.BROWSER_HEIGHT,
        "display_width_px": settings.BROWSER_WIDTH,
    }

    # Only a conversation with no history yet carries the navigation goal;
    # later turns send no prompt (the parameter's default is None).
    opening_prompt = None if llm_caller.message_history else task.navigation_goal
    llm_response = await llm_caller.call(
        prompt=opening_prompt,
        screenshots=scraped_page.screenshots,
        use_message_history=True,
        tools=[computer_tool],
        raw_response=True,
        betas=["computer-use-2025-01-24"],
    )
    LOG.info("Anthropic response", llm_response=llm_response)

    # Record the assistant turn so the next call sees it in the history.
    assistant_content = llm_response["content"]
    llm_caller.message_history.append({"role": "assistant", "content": assistant_content})

    return await parse_anthropic_actions(task, step, assistant_content)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
|
async def complete_verify(page: Page, scraped_page: ScrapedPage, task: Task, step: Step) -> CompleteVerifyResult:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
|
@ -1387,7 +1450,7 @@ class ForgeAgent:
|
||||||
)
|
)
|
||||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||||
scroll = True
|
scroll = True
|
||||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||||
scroll = False
|
scroll = False
|
||||||
|
|
||||||
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
|
scraped_page_refreshed = await scraped_page.refresh(draw_boxes=False, scroll=scroll)
|
||||||
|
@ -1454,7 +1517,7 @@ class ForgeAgent:
|
||||||
raise BrowserStateMissingPage()
|
raise BrowserStateMissingPage()
|
||||||
|
|
||||||
fullpage_screenshot = True
|
fullpage_screenshot = True
|
||||||
if engine == RunEngine.openai_cua:
|
if engine in CUA_ENGINES:
|
||||||
fullpage_screenshot = False
|
fullpage_screenshot = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -1580,7 +1643,7 @@ class ForgeAgent:
|
||||||
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
|
max_screenshot_number = settings.MAX_NUM_SCREENSHOTS
|
||||||
draw_boxes = True
|
draw_boxes = True
|
||||||
scroll = True
|
scroll = True
|
||||||
if engine == RunEngine.openai_cua:
|
if engine in CUA_ENGINES:
|
||||||
max_screenshot_number = 1
|
max_screenshot_number = 1
|
||||||
draw_boxes = False
|
draw_boxes = False
|
||||||
scroll = False
|
scroll = False
|
||||||
|
@ -1602,7 +1665,7 @@ class ForgeAgent:
|
||||||
engine: RunEngine,
|
engine: RunEngine,
|
||||||
) -> tuple[ScrapedPage, str]:
|
) -> tuple[ScrapedPage, str]:
|
||||||
# start the async tasks while running scrape_website
|
# start the async tasks while running scrape_website
|
||||||
if engine != RunEngine.openai_cua:
|
if engine not in CUA_ENGINES:
|
||||||
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
|
self.async_operation_pool.run_operation(task.task_id, AgentPhase.scrape)
|
||||||
|
|
||||||
# Scrape the web page and get the screenshot and the elements
|
# Scrape the web page and get the screenshot and the elements
|
||||||
|
@ -1653,7 +1716,7 @@ class ForgeAgent:
|
||||||
element_tree_format = ElementTreeFormat.HTML
|
element_tree_format = ElementTreeFormat.HTML
|
||||||
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
|
element_tree_in_prompt: str = scraped_page.build_element_tree(element_tree_format)
|
||||||
extract_action_prompt = ""
|
extract_action_prompt = ""
|
||||||
if engine != RunEngine.openai_cua:
|
if engine not in CUA_ENGINES:
|
||||||
extract_action_prompt = await self._build_extract_action_prompt(
|
extract_action_prompt = await self._build_extract_action_prompt(
|
||||||
task,
|
task,
|
||||||
step,
|
step,
|
||||||
|
@ -2371,7 +2434,7 @@ class ForgeAgent:
|
||||||
|
|
||||||
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
run_obj = await app.DATABASE.get_run(run_id=task.task_id, organization_id=task.organization_id)
|
||||||
scroll = True
|
scroll = True
|
||||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
if run_obj and run_obj.task_run_type in CUA_RUN_TYPES:
|
||||||
scroll = False
|
scroll = False
|
||||||
|
|
||||||
screenshots: list[bytes] = []
|
screenshots: list[bytes] = []
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Awaitable, Callable
|
from typing import Awaitable, Callable
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic, AsyncAnthropicBedrock
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||||
|
|
||||||
|
@ -41,6 +42,9 @@ if SettingsManager.get_settings().ENABLE_AZURE_CUA:
|
||||||
azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
|
azure_endpoint=SettingsManager.get_settings().AZURE_CUA_ENDPOINT,
|
||||||
azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
|
azure_deployment=SettingsManager.get_settings().AZURE_CUA_DEPLOYMENT,
|
||||||
)
|
)
|
||||||
|
# Client used for direct Anthropic SDK calls. Exactly one client is built:
# the Bedrock-backed one when ENABLE_BEDROCK_ANTHROPIC is set, otherwise the
# API-key one. Building the API-key client first and then overwriting it
# would leave an unused client object behind that nothing ever closes.
ANTHROPIC_CLIENT: AsyncAnthropic | AsyncAnthropicBedrock
if SettingsManager.get_settings().ENABLE_BEDROCK_ANTHROPIC:
    ANTHROPIC_CLIENT = AsyncAnthropicBedrock()
else:
    ANTHROPIC_CLIENT = AsyncAnthropic(api_key=SettingsManager.get_settings().ANTHROPIC_API_KEY)
|
||||||
|
|
||||||
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
SECONDARY_LLM_API_HANDLER = LLMAPIHandlerFactory.get_llm_api_handler(
|
||||||
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
|
SETTINGS_MANAGER.SECONDARY_LLM_KEY if SETTINGS_MANAGER.SECONDARY_LLM_KEY else SETTINGS_MANAGER.LLM_KEY
|
||||||
|
|
|
@ -6,7 +6,9 @@ from typing import Any
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
import structlog
|
import structlog
|
||||||
|
from anthropic.types.message import Message as AnthropicMessage
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
from litellm.utils import CustomStreamWrapper, ModelResponse
|
||||||
|
|
||||||
from skyvern.config import settings
|
from skyvern.config import settings
|
||||||
from skyvern.exceptions import SkyvernContextWindowExceededError
|
from skyvern.exceptions import SkyvernContextWindowExceededError
|
||||||
|
@ -456,11 +458,18 @@ class LLMCaller:
|
||||||
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
self.llm_config = LLMConfigRegistry.get_config(llm_key)
|
||||||
self.base_parameters = base_parameters
|
self.base_parameters = base_parameters
|
||||||
self.message_history: list[dict[str, Any]] = []
|
self.message_history: list[dict[str, Any]] = []
|
||||||
|
self.current_tool_results: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def add_tool_result(self, tool_result: dict[str, Any]) -> None:
    """Append ``tool_result`` to the end of ``current_tool_results``."""
    pending = self.current_tool_results
    pending.append(tool_result)
|
||||||
|
|
||||||
|
def clear_tool_results(self) -> None:
    """Point ``current_tool_results`` at a brand-new empty list."""
    # Rebind rather than clear in place: the previous list object is left
    # untouched, so anything still holding a reference keeps its contents.
    self.current_tool_results = list()
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str | None = None,
|
||||||
prompt_name: str,
|
prompt_name: str | None = None,
|
||||||
step: Step | None = None,
|
step: Step | None = None,
|
||||||
task_v2: TaskV2 | None = None,
|
task_v2: TaskV2 | None = None,
|
||||||
thought: Thought | None = None,
|
thought: Thought | None = None,
|
||||||
|
@ -469,6 +478,8 @@ class LLMCaller:
|
||||||
parameters: dict[str, Any] | None = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
tools: list | None = None,
|
tools: list | None = None,
|
||||||
use_message_history: bool = False,
|
use_message_history: bool = False,
|
||||||
|
raw_response: bool = False,
|
||||||
|
**extra_parameters: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
active_parameters = self.base_parameters or {}
|
active_parameters = self.base_parameters or {}
|
||||||
|
@ -476,6 +487,8 @@ class LLMCaller:
|
||||||
parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
|
parameters = LLMAPIHandlerFactory.get_api_parameters(self.llm_config)
|
||||||
|
|
||||||
active_parameters.update(parameters)
|
active_parameters.update(parameters)
|
||||||
|
if extra_parameters:
|
||||||
|
active_parameters.update(extra_parameters)
|
||||||
if self.llm_config.litellm_params: # type: ignore
|
if self.llm_config.litellm_params: # type: ignore
|
||||||
active_parameters.update(self.llm_config.litellm_params) # type: ignore
|
active_parameters.update(self.llm_config.litellm_params) # type: ignore
|
||||||
|
|
||||||
|
@ -491,7 +504,7 @@ class LLMCaller:
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=prompt.encode("utf-8"),
|
data=prompt.encode("utf-8") if prompt else b"",
|
||||||
artifact_type=ArtifactType.LLM_PROMPT,
|
artifact_type=ArtifactType.LLM_PROMPT,
|
||||||
screenshots=screenshots,
|
screenshots=screenshots,
|
||||||
step=step,
|
step=step,
|
||||||
|
@ -525,8 +538,7 @@ class LLMCaller:
|
||||||
)
|
)
|
||||||
t_llm_request = time.perf_counter()
|
t_llm_request = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
response = await litellm.acompletion(
|
response = await self._dispatch_llm_call(
|
||||||
model=self.llm_config.model_name,
|
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=tools,
|
tools=tools,
|
||||||
timeout=settings.LLM_CONFIG_TIMEOUT,
|
timeout=settings.LLM_CONFIG_TIMEOUT,
|
||||||
|
@ -603,6 +615,21 @@ class LLMCaller:
|
||||||
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
cached_token_count=cached_tokens if cached_tokens > 0 else None,
|
||||||
thought_cost=llm_cost,
|
thought_cost=llm_cost,
|
||||||
)
|
)
|
||||||
|
# Track LLM API handler duration
|
||||||
|
duration_seconds = time.perf_counter() - start_time
|
||||||
|
LOG.info(
|
||||||
|
"LLM API handler duration metrics",
|
||||||
|
llm_key=self.llm_key,
|
||||||
|
prompt_name=prompt_name,
|
||||||
|
model=self.llm_config.model_name,
|
||||||
|
duration_seconds=duration_seconds,
|
||||||
|
step_id=step.step_id if step else None,
|
||||||
|
thought_id=thought.observer_thought_id if thought else None,
|
||||||
|
organization_id=step.organization_id if step else (thought.organization_id if thought else None),
|
||||||
|
)
|
||||||
|
if raw_response:
|
||||||
|
return response.model_dump()
|
||||||
|
|
||||||
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
|
parsed_response = parse_api_response(response, self.llm_config.add_assistant_prefix)
|
||||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||||
|
@ -626,17 +653,53 @@ class LLMCaller:
|
||||||
ai_suggestion=ai_suggestion,
|
ai_suggestion=ai_suggestion,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Track LLM API handler duration
|
return parsed_response
|
||||||
duration_seconds = time.perf_counter() - start_time
|
|
||||||
LOG.info(
|
async def _dispatch_llm_call(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: Any,
) -> ModelResponse | CustomStreamWrapper | AnthropicMessage:
    """Send ``messages`` to the backend selected by ``self.llm_key``.

    Keys that start with ``"ANTHROPIC"`` go through ``_call_anthropic``;
    every other key goes through ``litellm.acompletion``.

    Args:
        messages: Chat messages to send.
        tools: Optional tool definitions forwarded to the backend.
        timeout: Request timeout forwarded to the backend.
        **active_parameters: Extra request parameters. They are forwarded
            to whichever backend is selected.

    Returns:
        The backend's raw response object.
    """
    if self.llm_key and self.llm_key.startswith("ANTHROPIC"):
        # Forward the extra parameters too: _call_anthropic reads
        # max_completion_tokens / max_tokens / betas from them, so dropping
        # them here would silently discard the caller's settings.
        return await self._call_anthropic(messages, tools, timeout, **active_parameters)

    return await litellm.acompletion(
        model=self.llm_config.model_name, messages=messages, tools=tools, timeout=timeout, **active_parameters
    )
|
||||||
|
|
||||||
return parsed_response
|
async def _call_anthropic(
    self,
    messages: list[dict[str, Any]],
    tools: list | None = None,
    timeout: float = settings.LLM_CONFIG_TIMEOUT,
    **active_parameters: Any,
) -> AnthropicMessage:
    """Call the Anthropic SDK client directly with ``messages``.

    Args:
        messages: Chat messages to send.
        tools: Optional tool definitions; omitted from the request when empty.
        timeout: Request timeout passed to the SDK.
        **active_parameters: Only ``max_completion_tokens``, ``max_tokens``
            and ``betas`` are read; other keys are ignored.

    Returns:
        The message object returned by the SDK.
    """
    # First non-empty token limit wins; 4096 is the fallback.
    max_tokens = active_parameters.get("max_completion_tokens") or active_parameters.get("max_tokens") or 4096
    # The SDK is given the model name without the "bedrock/" or "anthropic/" prefix.
    model_name = self.llm_config.model_name.replace("bedrock/", "").replace("anthropic/", "")

    request: dict[str, Any] = {
        "max_tokens": max_tokens,
        "messages": messages,
        "model": model_name,
        "timeout": timeout,
    }
    # Leave optional fields out entirely instead of sending an explicit None.
    if tools:
        request["tools"] = tools

    betas = active_parameters.get("betas")
    if betas:
        # NOTE(review): per the SDK reference, `betas` is a parameter of
        # `beta.messages.create` only, so beta features must use that
        # endpoint. The beta endpoint's message model is presumably
        # interchangeable for callers that only dump it — confirm.
        request["betas"] = betas
        return await app.ANTHROPIC_CLIENT.beta.messages.create(**request)

    return await app.ANTHROPIC_CLIENT.messages.create(**request)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMCallerManager:
    """Class-level registry that maps an identifier to an ``LLMCaller``."""

    # One dictionary shared through the class itself; every method below is
    # a classmethod, so no instance is needed to use the registry.
    _llm_callers: dict[str, LLMCaller] = {}

    @classmethod
    def get_llm_caller(cls, uid: str) -> LLMCaller | None:
        """Return the caller stored under ``uid``, or ``None`` if there is none."""
        registry = cls._llm_callers
        return registry[uid] if uid in registry else None

    @classmethod
    def set_llm_caller(cls, uid: str, llm_caller: LLMCaller) -> None:
        """Store ``llm_caller`` under ``uid``, replacing any previous entry."""
        registry = cls._llm_callers
        registry[uid] = llm_caller

    @classmethod
    def clear_llm_caller(cls, uid: str) -> None:
        """Remove the entry for ``uid``; do nothing when it is not present."""
        cls._llm_callers.pop(uid, None)
|
||||||
|
|
|
@ -47,19 +47,22 @@ async def llm_messages_builder(
|
||||||
|
|
||||||
|
|
||||||
async def llm_messages_builder_with_history(
|
async def llm_messages_builder_with_history(
|
||||||
prompt: str,
|
prompt: str | None = None,
|
||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
message_history: list[dict[str, Any]] | None = None,
|
message_history: list[dict[str, Any]] | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
messages: list[dict[str, Any]] = []
|
messages: list[dict[str, Any]] = []
|
||||||
if message_history:
|
if message_history:
|
||||||
messages = copy.deepcopy(message_history)
|
messages = copy.deepcopy(message_history)
|
||||||
current_user_messages: list[dict[str, Any]] = [
|
|
||||||
{
|
current_user_messages: list[dict[str, Any]] = []
|
||||||
"type": "text",
|
if prompt:
|
||||||
"text": prompt,
|
current_user_messages.append(
|
||||||
}
|
{
|
||||||
]
|
"type": "text",
|
||||||
|
"text": prompt,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if screenshots:
|
if screenshots:
|
||||||
for screenshot in screenshots:
|
for screenshot in screenshots:
|
||||||
|
|
|
@ -96,6 +96,8 @@ class BackgroundTaskExecutor(AsyncExecutor):
|
||||||
engine = RunEngine.skyvern_v1
|
engine = RunEngine.skyvern_v1
|
||||||
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
if run_obj and run_obj.task_run_type == RunType.openai_cua:
|
||||||
engine = RunEngine.openai_cua
|
engine = RunEngine.openai_cua
|
||||||
|
elif run_obj and run_obj.task_run_type == RunType.anthropic_cua:
|
||||||
|
engine = RunEngine.anthropic_cua
|
||||||
|
|
||||||
context: SkyvernContext = skyvern_context.ensure_context()
|
context: SkyvernContext = skyvern_context.ensure_context()
|
||||||
context.task_id = task.task_id
|
context.task_id = task.task_id
|
||||||
|
|
|
@ -62,6 +62,7 @@ from skyvern.forge.sdk.workflow.models.workflow import (
|
||||||
)
|
)
|
||||||
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
from skyvern.forge.sdk.workflow.models.yaml import WorkflowCreateYAMLRequest
|
||||||
from skyvern.schemas.runs import (
|
from skyvern.schemas.runs import (
|
||||||
|
CUA_ENGINES,
|
||||||
RunEngine,
|
RunEngine,
|
||||||
RunResponse,
|
RunResponse,
|
||||||
RunType,
|
RunType,
|
||||||
|
@ -1466,7 +1467,7 @@ async def run_task(
|
||||||
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
|
analytics.capture("skyvern-oss-run-task", data={"url": run_request.url})
|
||||||
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
|
await PermissionCheckerFactory.get_instance().check(current_org, browser_session_id=run_request.browser_session_id)
|
||||||
|
|
||||||
if run_request.engine in [RunEngine.skyvern_v1, RunEngine.openai_cua]:
|
if run_request.engine in CUA_ENGINES:
|
||||||
# create task v1
|
# create task v1
|
||||||
# if there's no url, call task generation first to generate the url, data schema if any
|
# if there's no url, call task generation first to generate the url, data schema if any
|
||||||
url = run_request.url
|
url = run_request.url
|
||||||
|
@ -1480,7 +1481,7 @@ async def run_task(
|
||||||
)
|
)
|
||||||
url = url or task_generation.url
|
url = url or task_generation.url
|
||||||
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
navigation_goal = task_generation.navigation_goal or run_request.prompt
|
||||||
if run_request.engine == RunEngine.openai_cua:
|
if run_request.engine in CUA_ENGINES:
|
||||||
navigation_goal = run_request.prompt
|
navigation_goal = run_request.prompt
|
||||||
navigation_payload = task_generation.navigation_payload
|
navigation_payload = task_generation.navigation_payload
|
||||||
data_extraction_goal = task_generation.data_extraction_goal
|
data_extraction_goal = task_generation.data_extraction_goal
|
||||||
|
@ -1511,6 +1512,8 @@ async def run_task(
|
||||||
run_type = RunType.task_v1
|
run_type = RunType.task_v1
|
||||||
if run_request.engine == RunEngine.openai_cua:
|
if run_request.engine == RunEngine.openai_cua:
|
||||||
run_type = RunType.openai_cua
|
run_type = RunType.openai_cua
|
||||||
|
elif run_request.engine == RunEngine.anthropic_cua:
|
||||||
|
run_type = RunType.anthropic_cua
|
||||||
# build the task run response
|
# build the task run response
|
||||||
return TaskRunResponse(
|
return TaskRunResponse(
|
||||||
run_id=task_v1_response.task_id,
|
run_id=task_v1_response.task_id,
|
||||||
|
@ -1586,8 +1589,6 @@ async def run_task(
|
||||||
publish_workflow=run_request.publish_workflow,
|
publish_workflow=run_request.publish_workflow,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if run_request.engine == RunEngine.openai_cua:
|
|
||||||
pass
|
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
|
raise HTTPException(status_code=400, detail=f"Invalid agent engine: {run_request.engine}")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -93,12 +93,18 @@ class RunType(StrEnum):
|
||||||
task_v2 = "task_v2"
|
task_v2 = "task_v2"
|
||||||
workflow_run = "workflow_run"
|
workflow_run = "workflow_run"
|
||||||
openai_cua = "openai_cua"
|
openai_cua = "openai_cua"
|
||||||
|
anthropic_cua = "anthropic_cua"
|
||||||
|
|
||||||
|
|
||||||
class RunEngine(StrEnum):
|
class RunEngine(StrEnum):
|
||||||
skyvern_v1 = "skyvern-1.0"
|
skyvern_v1 = "skyvern-1.0"
|
||||||
skyvern_v2 = "skyvern-2.0"
|
skyvern_v2 = "skyvern-2.0"
|
||||||
openai_cua = "openai-cua"
|
openai_cua = "openai-cua"
|
||||||
|
anthropic_cua = "anthropic-cua"
|
||||||
|
|
||||||
|
|
||||||
|
CUA_ENGINES = [RunEngine.openai_cua, RunEngine.anthropic_cua]
|
||||||
|
CUA_RUN_TYPES = [RunType.openai_cua, RunType.anthropic_cua]
|
||||||
|
|
||||||
|
|
||||||
class RunStatus(StrEnum):
|
class RunStatus(StrEnum):
|
||||||
|
@ -217,8 +223,8 @@ class BaseRunResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class TaskRunResponse(BaseRunResponse):
|
class TaskRunResponse(BaseRunResponse):
|
||||||
run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua] = Field(
|
run_type: Literal[RunType.task_v1, RunType.task_v2, RunType.openai_cua, RunType.anthropic_cua] = Field(
|
||||||
description="Types of a task run - task_v1, task_v2, openai_cua"
|
description="Types of a task run - task_v1, task_v2, openai_cua, anthropic_cua"
|
||||||
)
|
)
|
||||||
run_request: TaskRunRequest | None = Field(
|
run_request: TaskRunRequest | None = Field(
|
||||||
default=None, description="The original request parameters used to start this task run"
|
default=None, description="The original request parameters used to start this task run"
|
||||||
|
|
|
@ -13,7 +13,11 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
||||||
if not run:
|
if not run:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if run.task_run_type == RunType.task_v1 or run.task_run_type == RunType.openai_cua:
|
if (
|
||||||
|
run.task_run_type == RunType.task_v1
|
||||||
|
or run.task_run_type == RunType.openai_cua
|
||||||
|
or run.task_run_type == RunType.anthropic_cua
|
||||||
|
):
|
||||||
# fetch task v1 from db and transform to task run response
|
# fetch task v1 from db and transform to task run response
|
||||||
task_v1 = await app.DATABASE.get_task(run.run_id, organization_id=organization_id)
|
task_v1 = await app.DATABASE.get_task(run.run_id, organization_id=organization_id)
|
||||||
if not task_v1:
|
if not task_v1:
|
||||||
|
@ -21,6 +25,8 @@ async def get_run_response(run_id: str, organization_id: str | None = None) -> R
|
||||||
run_engine = RunEngine.skyvern_v1
|
run_engine = RunEngine.skyvern_v1
|
||||||
if run.task_run_type == RunType.openai_cua:
|
if run.task_run_type == RunType.openai_cua:
|
||||||
run_engine = RunEngine.openai_cua
|
run_engine = RunEngine.openai_cua
|
||||||
|
elif run.task_run_type == RunType.anthropic_cua:
|
||||||
|
run_engine = RunEngine.anthropic_cua
|
||||||
return TaskRunResponse(
|
return TaskRunResponse(
|
||||||
run_id=run.run_id,
|
run_id=run.run_id,
|
||||||
run_type=run.task_run_type,
|
run_type=run.task_run_type,
|
||||||
|
@ -136,7 +142,7 @@ async def cancel_run(run_id: str, organization_id: str | None = None, api_key: s
|
||||||
detail=f"Run not found {run_id}",
|
detail=f"Run not found {run_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
if run.task_run_type in [RunType.task_v1, RunType.openai_cua]:
|
if run.task_run_type in [RunType.task_v1, RunType.openai_cua, RunType.anthropic_cua]:
|
||||||
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
|
await cancel_task_v1(run_id, organization_id=organization_id, api_key=api_key)
|
||||||
elif run.task_run_type == RunType.task_v2:
|
elif run.task_run_type == RunType.task_v2:
|
||||||
await cancel_task_v2(run_id, organization_id=organization_id)
|
await cancel_task_v2(run_id, organization_id=organization_id)
|
||||||
|
|
|
@ -87,6 +87,8 @@ async def run_task(
|
||||||
run_type = RunType.task_v1
|
run_type = RunType.task_v1
|
||||||
if engine == RunEngine.openai_cua:
|
if engine == RunEngine.openai_cua:
|
||||||
run_type = RunType.openai_cua
|
run_type = RunType.openai_cua
|
||||||
|
elif engine == RunEngine.anthropic_cua:
|
||||||
|
run_type = RunType.anthropic_cua
|
||||||
await app.DATABASE.create_task_run(
|
await app.DATABASE.create_task_run(
|
||||||
task_run_type=run_type,
|
task_run_type=run_type,
|
||||||
organization_id=organization.organization_id,
|
organization_id=organization.organization_id,
|
||||||
|
|
|
@ -113,6 +113,7 @@ class Action(BaseModel):
|
||||||
element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None
|
element_id: Annotated[str, Field(coerce_numbers_to_str=True)] | None = None
|
||||||
skyvern_element_hash: str | None = None
|
skyvern_element_hash: str | None = None
|
||||||
skyvern_element_data: dict[str, Any] | None = None
|
skyvern_element_data: dict[str, Any] | None = None
|
||||||
|
tool_call_id: str | None = None
|
||||||
|
|
||||||
# DecisiveAction (CompleteAction, TerminateAction) fields
|
# DecisiveAction (CompleteAction, TerminateAction) fields
|
||||||
errors: list[UserDefinedError] | None = None
|
errors: list[UserDefinedError] | None = None
|
||||||
|
|
|
@ -59,6 +59,7 @@ from skyvern.forge.sdk.api.files import (
|
||||||
list_files_in_directory,
|
list_files_in_directory,
|
||||||
wait_for_download_finished,
|
wait_for_download_finished,
|
||||||
)
|
)
|
||||||
|
from skyvern.forge.sdk.api.llm.api_handler_factory import LLMCallerManager
|
||||||
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
||||||
from skyvern.forge.sdk.core import skyvern_context
|
from skyvern.forge.sdk.core import skyvern_context
|
||||||
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
from skyvern.forge.sdk.core.aiohttp_helper import aiohttp_post
|
||||||
|
@ -363,9 +364,26 @@ class ActionHandler:
|
||||||
handler = ActionHandler._handled_action_types[action.action_type]
|
handler = ActionHandler._handled_action_types[action.action_type]
|
||||||
results = await handler(action, page, scraped_page, task, step)
|
results = await handler(action, page, scraped_page, task, step)
|
||||||
actions_result.extend(results)
|
actions_result.extend(results)
|
||||||
|
llm_caller = LLMCallerManager.get_llm_caller(task.task_id)
|
||||||
if not results or not isinstance(actions_result[-1], ActionSuccess):
|
if not results or not isinstance(actions_result[-1], ActionSuccess):
|
||||||
|
if llm_caller and action.tool_call_id:
|
||||||
|
# add failure message to the llm caller
|
||||||
|
tool_call_result = {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": action.tool_call_id,
|
||||||
|
"content": {"result": "Tool execution failed"},
|
||||||
|
}
|
||||||
|
llm_caller.add_tool_result(tool_call_result)
|
||||||
return actions_result
|
return actions_result
|
||||||
|
|
||||||
|
if llm_caller and action.tool_call_id:
|
||||||
|
tool_call_result = {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": action.tool_call_id,
|
||||||
|
"content": {"result": "Tool executed successfully"},
|
||||||
|
}
|
||||||
|
llm_caller.add_tool_result(tool_call_result)
|
||||||
|
|
||||||
# do the teardown
|
# do the teardown
|
||||||
teardown = ActionHandler._teardown_action_types.get(action.action_type)
|
teardown = ActionHandler._teardown_action_types.get(action.action_type)
|
||||||
if teardown:
|
if teardown:
|
||||||
|
@ -1532,7 +1550,7 @@ async def handle_keypress_action(
|
||||||
) -> list[ActionResult]:
|
) -> list[ActionResult]:
|
||||||
updated_keys = []
|
updated_keys = []
|
||||||
for key in action.keys:
|
for key in action.keys:
|
||||||
if key.lower() == "enter":
|
if key.lower() in ("enter", "return"):
|
||||||
updated_keys.append("Enter")
|
updated_keys.append("Enter")
|
||||||
elif key.lower() == "space":
|
elif key.lower() == "space":
|
||||||
updated_keys.append(" ")
|
updated_keys.append(" ")
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
@ -448,3 +449,133 @@ async def parse_cua_actions(
|
||||||
action.action_order = 0
|
action.action_order = 0
|
||||||
return [action]
|
return [action]
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
async def parse_anthropic_actions(
    task: Task,
    step: Step,
    assistant_content: list[dict[str, Any]],
) -> list[Action]:
    """Translate the ``tool_use`` blocks of an assistant message into actions.

    Only content blocks whose ``type`` is ``"tool_use"`` are considered. Each
    one is decoded with ``_parse_anthropic_computer_args``; blocks that do not
    decode to computer-tool arguments are skipped silently.

    Supported computer actions and the action each one produces:

    * ``mouse_move`` -> ``MoveAction`` at the given coordinate
    * ``left_click`` -> ``ClickAction`` (left button)
    * ``type``       -> ``InputTextAction``
    * ``key``        -> ``KeypressAction`` with a single key entry

    Args:
        task: Task the actions belong to; supplies organization, workflow run
            and task ids.
        step: Step the actions belong to; supplies step id and order.
        assistant_content: Content blocks of the assistant message.

    Returns:
        The parsed actions, in tool-call order. ``action_order`` of every
        action is the index of its tool call among the ``tool_use`` blocks,
        so the ordering is consistent across all action kinds.
    """
    tool_calls = [block for block in assistant_content if block["type"] == "tool_use"]
    actions: list[Action] = []

    for idx, tool_call in enumerate(tool_calls):
        parsed_args = _parse_anthropic_computer_args(tool_call)
        if not parsed_args:
            continue

        # Bookkeeping fields shared by every action built from this tool call.
        # action_order is always the tool call's own index; previously the
        # "type"/"key" branches used the already-incremented index and could
        # collide with the order of the following action.
        common: dict[str, Any] = {
            "organization_id": task.organization_id,
            "workflow_run_id": task.workflow_run_id,
            "task_id": task.task_id,
            "step_id": step.step_id,
            "step_order": step.order,
            "action_order": idx,
            "tool_call_id": tool_call["id"],
        }

        action = parsed_args.get("action")

        if action == "mouse_move":
            coordinate = parsed_args.get("coordinate")
            if not coordinate:
                # Mirror the other branches: warn and skip instead of raising.
                LOG.warning(
                    "Mouse move action has no coordinate",
                    tool_call=tool_call,
                )
                continue
            x, y = coordinate
            actions.append(MoveAction(x=x, y=y, **common))

        elif action == "left_click":
            # A click directly preceded by a mouse_move targets the position
            # the pointer was just moved to; otherwise use the click's own
            # coordinate.
            coordinate = None
            if idx > 0:
                prev_parsed_args = _parse_anthropic_computer_args(tool_calls[idx - 1])
                if prev_parsed_args and prev_parsed_args.get("action") == "mouse_move":
                    coordinate = prev_parsed_args.get("coordinate")
            if not coordinate:
                coordinate = parsed_args.get("coordinate")
            if not coordinate:
                LOG.warning(
                    "Left click action has no coordinate and it doesn't have mouse_move before it",
                    tool_call=tool_call,
                )
                continue
            x, y = coordinate
            actions.append(
                ClickAction(
                    element_id="",
                    x=x,
                    y=y,
                    button="left",
                    **common,
                )
            )

        elif action == "type":
            text = parsed_args.get("text")
            if not text:
                LOG.warning(
                    "Type action has no text",
                    tool_call=tool_call,
                )
                continue
            actions.append(InputTextAction(element_id="", text=text, **common))

        elif action == "key":
            text = parsed_args.get("text")
            if not text:
                LOG.warning(
                    "Key action has no text",
                    tool_call=tool_call,
                )
                continue
            actions.append(KeypressAction(element_id="", keys=[text], **common))

        else:
            LOG.error(
                "Unsupported action",
                tool_call=tool_call,
            )

    return actions
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_anthropic_computer_args(tool_call: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
tool_call_type = tool_call["type"]
|
||||||
|
if tool_call_type != "function":
|
||||||
|
return None
|
||||||
|
tool_call_name = tool_call["function"]["name"]
|
||||||
|
if tool_call_name != "computer":
|
||||||
|
return None
|
||||||
|
tool_call_arguments = tool_call["function"]["arguments"]
|
||||||
|
return json.loads(tool_call_arguments)
|
||||||
|
|
|
@ -18,6 +18,7 @@ class ActionResult(BaseModel):
|
||||||
interacted_with_sibling: bool | None = None
|
interacted_with_sibling: bool | None = None
|
||||||
interacted_with_parent: bool | None = None
|
interacted_with_parent: bool | None = None
|
||||||
skip_remaining_actions: bool | None = None
|
skip_remaining_actions: bool | None = None
|
||||||
|
tool_call_result: dict[str, Any] | None = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
results = [f"ActionResult(success={self.success}"]
|
results = [f"ActionResult(success={self.success}"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue