create cruise related artifact in cruise api (#1355)

Shuchang Zheng 2024-12-08 21:17:58 -08:00 committed by GitHub
parent bda119027e
commit 5842bfc1fd
6 changed files with 209 additions and 75 deletions


@@ -0,0 +1,37 @@
+"""add created_at and modified_at to observer tables
+
+Revision ID: c502ecf908c6
+Revises: dc2a8facf0d7
+Create Date: 2024-12-09 00:40:30.098534+00:00
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "c502ecf908c6"
+down_revision: Union[str, None] = "dc2a8facf0d7"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("observer_cruises", sa.Column("created_at", sa.DateTime(), nullable=False))
+    op.add_column("observer_cruises", sa.Column("modified_at", sa.DateTime(), nullable=False))
+    op.add_column("observer_thoughts", sa.Column("created_at", sa.DateTime(), nullable=False))
+    op.add_column("observer_thoughts", sa.Column("modified_at", sa.DateTime(), nullable=False))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("observer_thoughts", "modified_at")
+    op.drop_column("observer_thoughts", "created_at")
+    op.drop_column("observer_cruises", "modified_at")
+    op.drop_column("observer_cruises", "created_at")
+    # ### end Alembic commands ###
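
The revision is applied with the usual alembic upgrade head. A minimal sketch (not part of this commit) of driving the same migration through Alembic's command API, assuming a standard alembic.ini at the project root:

# Sketch only: apply or roll back this revision programmatically.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed config location
command.upgrade(cfg, "c502ecf908c6")       # apply this revision
# command.downgrade(cfg, "dc2a8facf0d7")   # revert to the previous revision

Note that adding NOT NULL columns without a server default fails on tables that already contain rows; the auto-generated commands above assume the observer tables are empty or are backfilled out of band.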


@@ -20,6 +20,7 @@ from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouter
 from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
 from skyvern.forge.sdk.artifact.models import ArtifactType
 from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

 LOG = structlog.get_logger()

@@ -58,6 +59,8 @@ class LLMAPIHandlerFactory:
         async def llm_api_handler_with_router_and_fallback(
             prompt: str,
             step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
         ) -> dict[str, Any]:

@@ -76,24 +79,17 @@ class LLMAPIHandlerFactory:
             if parameters is None:
                 parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

             messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
                     {
                         "model": llm_key,
@@ -101,6 +97,10 @@ class LLMAPIHandlerFactory:
                         **parameters,
                     }
                 ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
             )

             try:
                 response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
@@ -122,12 +122,14 @@ class LLMAPIHandlerFactory:
                 )
                 raise LLMProviderError(llm_key) from e

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
+            if step:
                 llm_cost = litellm.completion_cost(completion_response=response)
                 prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
                 completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -140,11 +142,12 @@ class LLMAPIHandlerFactory:
                     incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
                 )
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             return parsed_response

@@ -162,6 +165,8 @@ class LLMAPIHandlerFactory:
         async def llm_api_handler(
             prompt: str,
             step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
             screenshots: list[bytes] | None = None,
             parameters: dict[str, Any] | None = None,
         ) -> dict[str, Any]:

@@ -173,28 +178,20 @@ class LLMAPIHandlerFactory:
             if llm_config.litellm_params:  # type: ignore
                 active_parameters.update(llm_config.litellm_params)  # type: ignore

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

-            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
             if not llm_config.supports_vision:
                 screenshots = None

             messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
                     {
                         "model": llm_config.model_name,
@@ -203,6 +200,10 @@ class LLMAPIHandlerFactory:
                         **parameters,
                     }
                 ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
            )

             t_llm_request = time.perf_counter()
             try:
@@ -231,12 +232,16 @@ class LLMAPIHandlerFactory:
             except Exception as e:
                 LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                 raise LLMProviderError(llm_key) from e

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
+            if step:
                 llm_cost = litellm.completion_cost(completion_response=response)
                 prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
                 completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -249,11 +254,12 @@ class LLMAPIHandlerFactory:
                     incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
                 )
             parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
             return parsed_response
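
With both handlers threading the optional observer context through every artifact write, a call site no longer needs a Step for its prompts, requests, and responses to be persisted. A hypothetical usage (the factory accessor name, llm_key value, and surrounding variables are illustrative, not from this commit):

# Sketch: attribute LLM artifacts to an observer cruise instead of a step.
handler = LLMAPIHandlerFactory.get_llm_api_handler("OPENAI_GPT4O")  # accessor and key assumed
parsed = await handler(
    prompt=prompt,
    observer_cruise=observer_cruise,  # an ObserverCruise loaded elsewhere
    screenshots=screenshots,
)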


@@ -4,6 +4,7 @@ from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
 from litellm import AllowedFailsPolicy

 from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
 from skyvern.forge.sdk.settings_manager import SettingsManager

@@ -78,6 +79,8 @@ class LLMAPIHandler(Protocol):
         self,
         prompt: str,
         step: Step | None = None,
+        observer_cruise: ObserverCruise | None = None,
+        observer_thought: ObserverThought | None = None,
         screenshots: list[bytes] | None = None,
         parameters: dict[str, Any] | None = None,
     ) -> Awaitable[dict[str, Any]]: ...
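
Since LLMAPIHandler is a Protocol, conformance is structural: any async callable with this signature type-checks against it, which keeps the two concrete handlers above in sync. A test stub might look like this (illustrative, not from this commit):

# Sketch: a structural match for the updated protocol.
async def fake_llm_api_handler(
    prompt: str,
    step: Step | None = None,
    observer_cruise: ObserverCruise | None = None,
    observer_thought: ObserverThought | None = None,
    screenshots: list[bytes] | None = None,
    parameters: dict[str, Any] | None = None,
) -> dict[str, Any]:
    return {"answer": "stub"}

handler: LLMAPIHandler = fake_llm_api_handler  # accepted by static type checkers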


@@ -9,7 +9,7 @@ from skyvern.forge import app
 from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
 from skyvern.forge.sdk.db.id import generate_artifact_id
 from skyvern.forge.sdk.models import Step
-from skyvern.forge.sdk.schemas.observers import ObserverThought
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

 LOG = structlog.get_logger(__name__)

@@ -103,12 +103,78 @@ class ArtifactManager:
             path=path,
         )

+    async def create_observer_cruise_artifact(
+        self,
+        observer_cruise: ObserverCruise,
+        artifact_type: ArtifactType,
+        data: bytes | None = None,
+        path: str | None = None,
+    ) -> str:
+        artifact_id = generate_artifact_id()
+        uri = app.STORAGE.build_observer_cruise_uri(artifact_id, observer_cruise, artifact_type)
+        return await self._create_artifact(
+            aio_task_primary_key=observer_cruise.observer_cruise_id,
+            artifact_id=artifact_id,
+            artifact_type=artifact_type,
+            uri=uri,
+            observer_cruise_id=observer_cruise.observer_cruise_id,
+            organization_id=observer_cruise.organization_id,
+            data=data,
+            path=path,
+        )
+
+    async def create_llm_artifact(
+        self,
+        data: bytes,
+        artifact_type: ArtifactType,
+        screenshots: list[bytes] | None = None,
+        step: Step | None = None,
+        observer_thought: ObserverThought | None = None,
+        observer_cruise: ObserverCruise | None = None,
+    ) -> None:
+        if step:
+            await self.create_artifact(
+                step=step,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+        elif observer_cruise:
+            await self.create_observer_cruise_artifact(
+                observer_cruise=observer_cruise,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_observer_cruise_artifact(
+                    observer_cruise=observer_cruise,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+        elif observer_thought:
+            await self.create_observer_thought_artifact(
+                observer_thought=observer_thought,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_observer_thought_artifact(
+                    observer_thought=observer_thought,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+
     async def update_artifact_data(
         self,
         artifact_id: str | None,
         organization_id: str | None,
         data: bytes,
-        primary_key: Literal["task_id", "observer_thought_id"] = "task_id",
+        primary_key: Literal["task_id", "observer_thought_id", "observer_cruise_id"] = "task_id",
     ) -> None:
         if not artifact_id or not organization_id:
             return None
@@ -125,6 +191,10 @@ class ArtifactManager:
             if not artifact.observer_thought_id:
                 raise ValueError("Observer Thought ID is required to update artifact data.")
             self.upload_aiotasks_map[artifact.observer_thought_id].append(aio_task)
+        elif primary_key == "observer_cruise_id":
+            if not artifact.observer_cruise_id:
+                raise ValueError("Observer Cruise ID is required to update artifact data.")
+            self.upload_aiotasks_map[artifact.observer_cruise_id].append(aio_task)

     async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
         return await app.STORAGE.retrieve_artifact(artifact)
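
create_llm_artifact dispatches on whichever context is present, checking step first, then observer_cruise, then observer_thought; any screenshots ride along as SCREENSHOT_LLM artifacts under the same owner. Mirroring the call sites in the handler factory above:

# Illustrative call: pass exactly one of step / observer_cruise / observer_thought.
await app.ARTIFACT_MANAGER.create_llm_artifact(
    data=prompt.encode("utf-8"),
    artifact_type=ArtifactType.LLM_PROMPT,
    screenshots=screenshots,
    observer_cruise=observer_cruise,
)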


@@ -1815,6 +1815,10 @@ class AgentDB:
         self,
         observer_cruise_id: str,
         status: ObserverCruiseStatus | None = None,
+        workflow_run_id: str | None = None,
+        workflow_id: str | None = None,
+        url: str | None = None,
+        prompt: str | None = None,
         organization_id: str | None = None,
     ) -> ObserverCruise:
         async with self.Session() as session:
@@ -1828,6 +1832,14 @@ class AgentDB:
             if observer_cruise:
                 if status:
                     observer_cruise.status = status
+                if workflow_run_id:
+                    observer_cruise.workflow_run_id = workflow_run_id
+                if workflow_id:
+                    observer_cruise.workflow_id = workflow_id
+                if url:
+                    observer_cruise.url = url
+                if prompt:
+                    observer_cruise.prompt = prompt
                 await session.commit()
                 await session.refresh(observer_cruise)
                 return ObserverCruise.model_validate(observer_cruise)
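
A hedged example of the broadened update, assuming app.DATABASE is the AgentDB instance (argument values are illustrative):

# Sketch: attach a cruise to a workflow run and record its inputs in one call.
observer_cruise = await app.DATABASE.update_observer_cruise(
    observer_cruise_id=observer_cruise_id,
    workflow_run_id=workflow_run_id,
    workflow_id=workflow_id,
    url="https://example.com",       # illustrative
    prompt="Find the pricing page",  # illustrative
    organization_id=organization_id,
)

Because each field is written only when truthy, callers cannot clear a previously set value by passing None or an empty string; that matches the pre-existing handling of status.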


@@ -518,6 +518,9 @@ class ObserverCruiseModel(Base):
     prompt = Column(UnicodeText, nullable=True)
     url = Column(String, nullable=True)

+    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
+    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
+

 class ObserverThoughtModel(Base):
     __tablename__ = "observer_thoughts"
@@ -532,3 +535,6 @@ class ObserverThoughtModel(Base):
     observation = Column(String, nullable=True)
     thought = Column(String, nullable=True)
     answer = Column(String, nullable=True)
+
+    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
+    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
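
One detail worth noting: default and onupdate are client-side SQLAlchemy hooks rather than server defaults, so the timestamps are maintained only for writes issued through this ORM layer. A behavior sketch (session setup is assumed, not commit code):

# Sketch: modified_at refreshes automatically on ORM updates.
async with db.Session() as session:  # db: an AgentDB instance (assumed)
    cruise = await session.get(ObserverCruiseModel, observer_cruise_id)
    cruise.url = "https://example.com"
    await session.commit()  # onupdate stamps modified_at with utcnow()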