create cruise related artifact in cruise api (#1355)

Shuchang Zheng 2024-12-08 21:17:58 -08:00, committed by GitHub
parent bda119027e
commit 5842bfc1fd
6 changed files with 209 additions and 75 deletions

View file

@@ -0,0 +1,37 @@
"""add created_at and modified_at to observer tables

Revision ID: c502ecf908c6
Revises: dc2a8facf0d7
Create Date: 2024-12-09 00:40:30.098534+00:00

"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c502ecf908c6"
down_revision: Union[str, None] = "dc2a8facf0d7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("observer_cruises", sa.Column("created_at", sa.DateTime(), nullable=False))
    op.add_column("observer_cruises", sa.Column("modified_at", sa.DateTime(), nullable=False))
    op.add_column("observer_thoughts", sa.Column("created_at", sa.DateTime(), nullable=False))
    op.add_column("observer_thoughts", sa.Column("modified_at", sa.DateTime(), nullable=False))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("observer_thoughts", "modified_at")
    op.drop_column("observer_thoughts", "created_at")
    op.drop_column("observer_cruises", "modified_at")
    op.drop_column("observer_cruises", "created_at")
    # ### end Alembic commands ###
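Note: `op.add_column` with `nullable=False` and no server default fails on any table that already contains rows. If the observer tables could already hold data, a backfill-friendly variant would be needed. A minimal sketch (not what this commit ships):

    # Sketch only: add the column with a temporary server default so existing
    # rows are backfilled, then drop the default so the app-level default applies.
    import sqlalchemy as sa
    from alembic import op


    def upgrade() -> None:
        op.add_column(
            "observer_cruises",
            sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
        )
        op.alter_column("observer_cruises", "created_at", server_default=None)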

View file

@@ -20,6 +20,7 @@ from skyvern.forge.sdk.api.llm.models import LLMAPIHandler, LLMConfig, LLMRouter
from skyvern.forge.sdk.api.llm.utils import llm_messages_builder, parse_api_response
from skyvern.forge.sdk.artifact.models import ArtifactType
from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger()
@@ -58,6 +59,8 @@ class LLMAPIHandlerFactory:
        async def llm_api_handler_with_router_and_fallback(
            prompt: str,
            step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
@@ -76,32 +79,29 @@ class LLMAPIHandlerFactory:
            if parameters is None:
                parameters = LLMAPIHandlerFactory.get_api_parameters(llm_config)

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
-                        {
-                            "model": llm_key,
-                            "messages": messages,
-                            **parameters,
-                        }
-                    ).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
+                    {
+                        "model": llm_key,
+                        "messages": messages,
+                        **parameters,
+                    }
+                ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
            try:
                response = await router.acompletion(model=main_model_group, messages=messages, **parameters)
                LOG.info("LLM API call successful", llm_key=llm_key, model=llm_config.model_name)
@@ -122,12 +122,14 @@ class LLMAPIHandlerFactory:
                )
                raise LLMProviderError(llm_key) from e

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

            llm_cost = litellm.completion_cost(completion_response=response)
            prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
            completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -140,12 +142,13 @@ class LLMAPIHandlerFactory:
                incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
            )
            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
            return parsed_response

        return llm_api_handler_with_router_and_fallback
@@ -162,6 +165,8 @@ class LLMAPIHandlerFactory:
        async def llm_api_handler(
            prompt: str,
            step: Step | None = None,
+            observer_cruise: ObserverCruise | None = None,
+            observer_thought: ObserverThought | None = None,
            screenshots: list[bytes] | None = None,
            parameters: dict[str, Any] | None = None,
        ) -> dict[str, Any]:
@@ -173,37 +178,33 @@ class LLMAPIHandlerFactory:
            if llm_config.litellm_params:  # type: ignore
                active_parameters.update(llm_config.litellm_params)  # type: ignore

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_PROMPT,
-                    data=prompt.encode("utf-8"),
-                )
-                for screenshot in screenshots or []:
-                    await app.ARTIFACT_MANAGER.create_artifact(
-                        step=step,
-                        artifact_type=ArtifactType.SCREENSHOT_LLM,
-                        data=screenshot,
-                    )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=prompt.encode("utf-8"),
+                artifact_type=ArtifactType.LLM_PROMPT,
+                screenshots=screenshots,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

            # TODO (kerem): instead of overriding the screenshots, should we just not take them in the first place?
            if not llm_config.supports_vision:
                screenshots = None

            messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_REQUEST,
-                    data=json.dumps(
-                        {
-                            "model": llm_config.model_name,
-                            "messages": messages,
-                            # we're not using active_parameters here because it may contain sensitive information
-                            **parameters,
-                        }
-                    ).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(
+                    {
+                        "model": llm_config.model_name,
+                        "messages": messages,
+                        # we're not using active_parameters here because it may contain sensitive information
+                        **parameters,
+                    }
+                ).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_REQUEST,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
            t_llm_request = time.perf_counter()
            try:
                # TODO (kerem): add a timeout to this call
@@ -231,12 +232,16 @@ class LLMAPIHandlerFactory:
            except Exception as e:
                LOG.exception("LLM request failed unexpectedly", llm_key=llm_key)
                raise LLMProviderError(llm_key) from e

-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE,
-                    data=response.model_dump_json(indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=response.model_dump_json(indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )

            llm_cost = litellm.completion_cost(completion_response=response)
            prompt_tokens = response.get("usage", {}).get("prompt_tokens", 0)
            completion_tokens = response.get("usage", {}).get("completion_tokens", 0)
@@ -249,12 +254,13 @@ class LLMAPIHandlerFactory:
                incremental_output_tokens=completion_tokens if completion_tokens > 0 else None,
            )
            parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
-            if step:
-                await app.ARTIFACT_MANAGER.create_artifact(
-                    step=step,
-                    artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
-                    data=json.dumps(parsed_response, indent=2).encode("utf-8"),
-                )
+            await app.ARTIFACT_MANAGER.create_llm_artifact(
+                data=json.dumps(parsed_response, indent=2).encode("utf-8"),
+                artifact_type=ArtifactType.LLM_RESPONSE_PARSED,
+                step=step,
+                observer_cruise=observer_cruise,
+                observer_thought=observer_thought,
+            )
            return parsed_response

        return llm_api_handler
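For context, a sketch of how a caller in the cruise API might use the extended signature (the factory accessor name and the surrounding objects here are assumptions, not shown in this diff):

    # Hypothetical call site: route an observer-cruise prompt through the
    # handler so its artifacts attach to the cruise rather than to a step.
    llm_handler = LLMAPIHandlerFactory.get_llm_api_handler_with_router(llm_key)
    response = await llm_handler(
        prompt=prompt,
        observer_cruise=observer_cruise,  # no step passed, so artifacts go to the cruise
        screenshots=screenshots,
    )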

View file

@@ -4,6 +4,7 @@ from typing import Any, Awaitable, Literal, Optional, Protocol, TypedDict
from litellm import AllowedFailsPolicy

from skyvern.forge.sdk.models import Step
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought
from skyvern.forge.sdk.settings_manager import SettingsManager
@@ -78,6 +79,8 @@ class LLMAPIHandler(Protocol):
        self,
        prompt: str,
        step: Step | None = None,
+        observer_cruise: ObserverCruise | None = None,
+        observer_thought: ObserverThought | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> Awaitable[dict[str, Any]]: ...
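(The annotation for `observer_thought` is `ObserverThought` above; typing it as `ObserverCruise` would make the concrete handlers fail the protocol check.) Because `LLMAPIHandler` is a `Protocol`, any callable with a matching signature satisfies it structurally, with no inheritance needed. A minimal test stub, reusing the imports above (`fake_llm_api_handler` is hypothetical, not part of this commit):

    from typing import Any

    # Hypothetical stub for tests/mocks; structural typing makes it a valid handler.
    async def fake_llm_api_handler(
        prompt: str,
        step: Step | None = None,
        observer_cruise: ObserverCruise | None = None,
        observer_thought: ObserverThought | None = None,
        screenshots: list[bytes] | None = None,
        parameters: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        return {"response": "stubbed"}

    handler: LLMAPIHandler = fake_llm_api_handler  # accepted by static type checkers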

View file

@@ -9,7 +9,7 @@ from skyvern.forge import app
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.id import generate_artifact_id
from skyvern.forge.sdk.models import Step
-from skyvern.forge.sdk.schemas.observers import ObserverThought
+from skyvern.forge.sdk.schemas.observers import ObserverCruise, ObserverThought

LOG = structlog.get_logger(__name__)
@@ -103,12 +103,78 @@ class ArtifactManager:
            path=path,
        )

+    async def create_observer_cruise_artifact(
+        self,
+        observer_cruise: ObserverCruise,
+        artifact_type: ArtifactType,
+        data: bytes | None = None,
+        path: str | None = None,
+    ) -> str:
+        artifact_id = generate_artifact_id()
+        uri = app.STORAGE.build_observer_cruise_uri(artifact_id, observer_cruise, artifact_type)
+        return await self._create_artifact(
+            aio_task_primary_key=observer_cruise.observer_cruise_id,
+            artifact_id=artifact_id,
+            artifact_type=artifact_type,
+            uri=uri,
+            observer_cruise_id=observer_cruise.observer_cruise_id,
+            organization_id=observer_cruise.organization_id,
+            data=data,
+            path=path,
+        )
+
+    async def create_llm_artifact(
+        self,
+        data: bytes,
+        artifact_type: ArtifactType,
+        screenshots: list[bytes] | None = None,
+        step: Step | None = None,
+        observer_thought: ObserverThought | None = None,
+        observer_cruise: ObserverCruise | None = None,
+    ) -> None:
+        if step:
+            await self.create_artifact(
+                step=step,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_artifact(
+                    step=step,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+        elif observer_cruise:
+            await self.create_observer_cruise_artifact(
+                observer_cruise=observer_cruise,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_observer_cruise_artifact(
+                    observer_cruise=observer_cruise,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+        elif observer_thought:
+            await self.create_observer_thought_artifact(
+                observer_thought=observer_thought,
+                artifact_type=artifact_type,
+                data=data,
+            )
+            for screenshot in screenshots or []:
+                await self.create_observer_thought_artifact(
+                    observer_thought=observer_thought,
+                    artifact_type=ArtifactType.SCREENSHOT_LLM,
+                    data=screenshot,
+                )
+
    async def update_artifact_data(
        self,
        artifact_id: str | None,
        organization_id: str | None,
        data: bytes,
-        primary_key: Literal["task_id", "observer_thought_id"] = "task_id",
+        primary_key: Literal["task_id", "observer_thought_id", "observer_cruise_id"] = "task_id",
    ) -> None:
        if not artifact_id or not organization_id:
            return None
@@ -125,6 +191,10 @@ class ArtifactManager:
            if not artifact.observer_thought_id:
                raise ValueError("Observer Thought ID is required to update artifact data.")
            self.upload_aiotasks_map[artifact.observer_thought_id].append(aio_task)
+        elif primary_key == "observer_cruise_id":
+            if not artifact.observer_cruise_id:
+                raise ValueError("Observer Cruise ID is required to update artifact data.")
+            self.upload_aiotasks_map[artifact.observer_cruise_id].append(aio_task)

    async def retrieve_artifact(self, artifact: Artifact) -> bytes | None:
        return await app.STORAGE.retrieve_artifact(artifact)
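The dispatch in `create_llm_artifact` is first-match-wins: `step` takes precedence over `observer_cruise`, which takes precedence over `observer_thought`, and if none of the three is set the artifact is silently dropped (the if/elif chain has no else). A sketch of a call from the cruise path, using only names shown in this diff:

    # With no step passed, the observer_cruise branch runs and the prompt
    # plus any screenshots are stored against the cruise.
    await app.ARTIFACT_MANAGER.create_llm_artifact(
        data=prompt.encode("utf-8"),
        artifact_type=ArtifactType.LLM_PROMPT,
        screenshots=screenshots,
        observer_cruise=observer_cruise,
    )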

View file

@@ -1815,6 +1815,10 @@ class AgentDB:
        self,
        observer_cruise_id: str,
        status: ObserverCruiseStatus | None = None,
+        workflow_run_id: str | None = None,
+        workflow_id: str | None = None,
+        url: str | None = None,
+        prompt: str | None = None,
        organization_id: str | None = None,
    ) -> ObserverCruise:
        async with self.Session() as session:
@@ -1828,6 +1832,14 @@ class AgentDB:
            if observer_cruise:
                if status:
                    observer_cruise.status = status
+                if workflow_run_id:
+                    observer_cruise.workflow_run_id = workflow_run_id
+                if workflow_id:
+                    observer_cruise.workflow_id = workflow_id
+                if url:
+                    observer_cruise.url = url
+                if prompt:
+                    observer_cruise.prompt = prompt
                await session.commit()
                await session.refresh(observer_cruise)
                return ObserverCruise.model_validate(observer_cruise)
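One caveat of the truthiness checks: a field can be set through this method but never cleared, since passing `None` or an empty string leaves the stored value untouched. A hypothetical call site that records the workflow created for a cruise (the surrounding objects are assumptions):

    # Hypothetical usage in the cruise API: attach the generated workflow
    # and the originating prompt/URL to the cruise record.
    observer_cruise = await app.DATABASE.update_observer_cruise(
        observer_cruise_id=observer_cruise.observer_cruise_id,
        workflow_run_id=workflow_run.workflow_run_id,
        workflow_id=workflow.workflow_id,
        url=url,
        prompt=user_prompt,
        organization_id=organization.organization_id,
    )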

View file

@@ -518,6 +518,9 @@ class ObserverCruiseModel(Base):
    prompt = Column(UnicodeText, nullable=True)
    url = Column(String, nullable=True)

+    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
+    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)


class ObserverThoughtModel(Base):
    __tablename__ = "observer_thoughts"
@@ -532,3 +535,6 @@ class ObserverThoughtModel(Base):
    observation = Column(String, nullable=True)
    thought = Column(String, nullable=True)
    answer = Column(String, nullable=True)
+
+    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
+    modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
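These `default`/`onupdate` callables run in the application at flush time rather than as database-level defaults, which is why the migration above adds the columns itself. `datetime.datetime.utcnow` also returns naive datetimes (and is deprecated in recent Python versions). A timezone-aware alternative (a sketch of a different design choice, not what this commit uses):

    import datetime

    from sqlalchemy import Column, DateTime


    def _utcnow() -> datetime.datetime:
        # Aware UTC timestamp instead of the naive utcnow().
        return datetime.datetime.now(datetime.timezone.utc)


    created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False)
    modified_at = Column(DateTime(timezone=True), default=_utcnow, onupdate=_utcnow, nullable=False)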