mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-01 18:20:06 +00:00
create cruise related artifact in cruise api (#1355)
This commit is contained in:
parent
bda119027e
commit
5842bfc1fd
6 changed files with 209 additions and 75 deletions
|
@ -0,0 +1,37 @@
|
|||
"""add created_at and modified_at to observer tables;
|
||||
|
||||
Revision ID: c502ecf908c6
|
||||
Revises: dc2a8facf0d7
|
||||
Create Date: 2024-12-09 00:40:30.098534+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c502ecf908c6"
|
||||
down_revision: Union[str, None] = "dc2a8facf0d7"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("observer_cruises", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("observer_cruises", sa.Column("modified_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("observer_thoughts", sa.Column("created_at", sa.DateTime(), nullable=False))
|
||||
op.add_column("observer_thoughts", sa.Column("modified_at", sa.DateTime(), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("observer_thoughts", "modified_at")
|
||||
op.drop_column("observer_thoughts", "created_at")
|
||||
op.drop_column("observer_cruises", "modified_at")
|
||||
op.drop_column("observer_cruises", "created_at")
|
||||
# ### end Alembic commands ###
|
|
@ -20,6 +20,7 @@ from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouter
|
|||
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
|
||||
from skyvern.forge.sdk.artifact.models import ArtifactType
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
|
||||
|
||||
LOG = structlog.get_logger()
|
||||
|
||||
|
@ -58,6 +59,8 @@ class LLMAPIHandlerFactory:
|
|||
async def llm_api_handler_with_router_and_fallback(
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
observer_cruise: ObserverCruise | None = None,
|
||||
observer_thought: ObserverThought | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
@ -76,32 +79,29 @@ class LLMAPIHandlerFactory:
|
|||
if parameters is None:
|
||||
parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)
|
||||
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
data=prompt.encode("utf-8"),
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=prompt.encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
screenshots=screenshots,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": llm_key,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
)
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": llm_key,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
try:
|
||||
response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
|
||||
LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
|
||||
|
@ -122,12 +122,14 @@ class LLMAPIHandlerFactory:
|
|||
)
|
||||
raise LLMProviderError(llm_key) from e
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
)
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||
|
@ -140,12 +142,13 @@ class LLMAPIHandlerFactory:
|
|||
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||
)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
return llm_api_handler_with_router_and_fallback
|
||||
|
@ -162,6 +165,8 @@ class LLMAPIHandlerFactory:
|
|||
async def llm_api_handler(
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
observer_cruise: ObserverCruise | None = None,
|
||||
observer_thought: ObserverThought | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
@ -173,37 +178,33 @@ class LLMAPIHandlerFactory:
|
|||
if llm_config.litellm_params: # type: ignore
|
||||
active_parameters.update(llm_config.litellm_params) # type: ignore
|
||||
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
data=prompt.encode("utf-8"),
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=prompt.encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_PROMPT,
|
||||
screenshots=screenshots,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
|
||||
# TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
|
||||
if not llm_config.supports_vision:
|
||||
screenshots = None
|
||||
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": llm_config.model_name,
|
||||
"messages": messages,
|
||||
# we're not using active_parameters here because it may contain sensitive information
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(
|
||||
{
|
||||
"model": llm_config.model_name,
|
||||
"messages": messages,
|
||||
# we're not using active_parameters here because it may contain sensitive information
|
||||
**parameters,
|
||||
}
|
||||
).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_REQUEST,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
t_llm_request = time.perf_counter()
|
||||
try:
|
||||
# TODO (kerem): add a timeout to this call
|
||||
|
@ -231,12 +232,16 @@ class LLMAPIHandlerFactory:
|
|||
except Exception as e:
|
||||
LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
|
||||
raise LLMProviderError(llm_key) from e
|
||||
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE,
|
||||
data=response.model_dump_json(indent=2).encode("utf-8"),
|
||||
)
|
||||
llm_cost = litellm.completion_cost(completion_response=response)
|
||||
prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
|
||||
|
@ -249,12 +254,13 @@ class LLMAPIHandlerFactory:
|
|||
incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
|
||||
)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
)
|
||||
await app.ARTIFACT_MANAGER.create_llm_artifact(
|
||||
data=json.dumps(parsed_response, indent=2).encode("utf-8"),
|
||||
artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
|
||||
step=step,
|
||||
observer_cruise=observer_cruise,
|
||||
observer_thought=observer_thought,
|
||||
)
|
||||
return parsed_response
|
||||
|
||||
return llm_api_handler
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
|
|||
from litellm import AllowedFailsPolicy
|
||||
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverCruise
|
||||
from skyvern.forge.sdk.settings_manager import SettingsManager
|
||||
|
||||
|
||||
|
@ -78,6 +79,8 @@ class LLMAPIHandler(Protocol):
|
|||
self,
|
||||
prompt: str,
|
||||
step: Step | None = None,
|
||||
observer_cruise: ObserverCruise | None = None,
|
||||
observer_thought: ObserverCruise | None = None,
|
||||
screenshots: list[bytes] | None = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
) -> Awaitable[dict[str, Any]]: ...
|
||||
|
|
|
@ -9,7 +9,7 @@ from skyvern.forge import app
|
|||
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
||||
from skyvern.forge.sdk.db.id import generate_artifact_id
|
||||
from skyvern.forge.sdk.models import Step
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverThought
|
||||
from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
|
||||
|
||||
LOG = structlog.get_logger(__name__)
|
||||
|
||||
|
@ -103,12 +103,78 @@ class ArtifactManager:
|
|||
path=path,
|
||||
)
|
||||
|
||||
async def create_observer_cruise_artifact(
|
||||
self,
|
||||
observer_cruise: ObserverCruise,
|
||||
artifact_type: ArtifactType,
|
||||
data: bytes | None = None,
|
||||
path: str | None = None,
|
||||
) -> str:
|
||||
artifact_id = generate_artifact_id()
|
||||
uri = app.STORAGE.build_observer_cruise_uri(artifact_id, observer_cruise, artifact_type)
|
||||
return await self._create_artifact(
|
||||
aio_task_primary_key=observer_cruise.observer_cruise_id,
|
||||
artifact_id=artifact_id,
|
||||
artifact_type=artifact_type,
|
||||
uri=uri,
|
||||
observer_cruise_id=observer_cruise.observer_cruise_id,
|
||||
organization_id=observer_cruise.organization_id,
|
||||
data=data,
|
||||
path=path,
|
||||
)
|
||||
|
||||
async def create_llm_artifact(
|
||||
self,
|
||||
data: bytes,
|
||||
artifact_type: ArtifactType,
|
||||
screenshots: list[bytes] | None = None,
|
||||
step: Step | None = None,
|
||||
observer_thought: ObserverThought | None = None,
|
||||
observer_cruise: ObserverCruise | None = None,
|
||||
) -> None:
|
||||
if step:
|
||||
await self.create_artifact(
|
||||
step=step,
|
||||
artifact_type=artifact_type,
|
||||
data=data,
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await self.create_artifact(
|
||||
step=step,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
elif observer_cruise:
|
||||
await self.create_observer_cruise_artifact(
|
||||
observer_cruise=observer_cruise,
|
||||
artifact_type=artifact_type,
|
||||
data=data,
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await self.create_observer_cruise_artifact(
|
||||
observer_cruise=observer_cruise,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
elif observer_thought:
|
||||
await self.create_observer_thought_artifact(
|
||||
observer_thought=observer_thought,
|
||||
artifact_type=artifact_type,
|
||||
data=data,
|
||||
)
|
||||
for screenshot in screenshots or []:
|
||||
await self.create_observer_thought_artifact(
|
||||
observer_thought=observer_thought,
|
||||
artifact_type=ArtifactType.SCREENSHOT_LLM,
|
||||
data=screenshot,
|
||||
)
|
||||
|
||||
async def update_artifact_data(
|
||||
self,
|
||||
artifact_id: str | None,
|
||||
organization_id: str | None,
|
||||
data: bytes,
|
||||
primary_key: Literal["task_id", "observer_thought_id"] = "task_id",
|
||||
primary_key: Literal["task_id", "observer_thought_id", "observer_cruise_id"] = "task_id",
|
||||
) -> None:
|
||||
if not artifact_id or not organization_id:
|
||||
return None
|
||||
|
@ -125,6 +191,10 @@ class ArtifactManager:
|
|||
if not artifact.observer_thought_id:
|
||||
raise ValueError("Observer Thought ID is required to update artifact data.")
|
||||
self.upload_aiotasks_map[artifact.observer_thought_id].append(aio_task)
|
||||
elif primary_key == "observer_cruise_id":
|
||||
if not artifact.observer_cruise_id:
|
||||
raise ValueError("Observer Cruise ID is required to update artifact data.")
|
||||
self.upload_aiotasks_map[artifact.observer_cruise_id].append(aio_task)
|
||||
|
||||
async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
|
||||
return await app.STORAGE.retrieve_artifact(artifact)
|
||||
|
|
|
@ -1815,6 +1815,10 @@ class AgentDB:
|
|||
self,
|
||||
observer_cruise_id: str,
|
||||
status: ObserverCruiseStatus | None = None,
|
||||
workflow_run_id: str | None = None,
|
||||
workflow_id: str | None = None,
|
||||
url: str | None = None,
|
||||
prompt: str | None = None,
|
||||
organization_id: str | None = None,
|
||||
) -> ObserverCruise:
|
||||
async with self.Session() as session:
|
||||
|
@ -1828,6 +1832,14 @@ class AgentDB:
|
|||
if observer_cruise:
|
||||
if status:
|
||||
observer_cruise.status = status
|
||||
if workflow_run_id:
|
||||
observer_cruise.workflow_run_id = workflow_run_id
|
||||
if workflow_id:
|
||||
observer_cruise.workflow_id = workflow_id
|
||||
if url:
|
||||
observer_cruise.url = url
|
||||
if prompt:
|
||||
observer_cruise.prompt = prompt
|
||||
await session.commit()
|
||||
await session.refresh(observer_cruise)
|
||||
return ObserverCruise.model_validate(observer_cruise)
|
||||
|
|
|
@ -518,6 +518,9 @@ class ObserverCruiseModel(Base):
|
|||
prompt = Column(UnicodeText, nullable=True)
|
||||
url = Column(String, nullable=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
class ObserverThoughtModel(Base):
|
||||
__tablename__ = "observer_thoughts"
|
||||
|
@ -532,3 +535,6 @@ class ObserverThoughtModel(Base):
|
|||
observation = Column(String, nullable=True)
|
||||
thought = Column(String, nullable=True)
|
||||
answer = Column(String, nullable=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
|
|
Loading…
Add table
Reference in a new issue