add cost info to the workflow run repsonse (#1456)

This commit is contained in:
Shuchang Zheng 2024-12-31 11:24:09 -08:00 committed by GitHub
parent 175ce55f06
commit 171aef6bf7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 35 additions and 1 deletions

View file

@ -318,6 +318,21 @@ class AgentDB:
LOG.error("UnexpectedError", exc_info=True) LOG.error("UnexpectedError", exc_info=True)
raise raise
async def get_steps_by_task_ids(self, task_ids: list[str], organization_id: str | None = None) -> list[Step]:
try:
async with self.Session() as session:
steps = (
await session.scalars(
select(StepModel)
.filter(StepModel.task_id.in_(task_ids))
.filter_by(organization_id=organization_id)
)
).all()
return [convert_to_step(step, debug_enabled=self.debug_enabled) for step in steps]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]: async def get_task_step_models(self, task_id: str, organization_id: str | None = None) -> Sequence[StepModel]:
try: try:
async with self.Session() as session: async with self.Session() as session:

View file

@ -724,6 +724,7 @@ async def get_workflow_run(
workflow_permanent_id=workflow_id, workflow_permanent_id=workflow_id,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
organization_id=current_org.organization_id, organization_id=current_org.organization_id,
include_cost=True,
) )

View file

@ -134,3 +134,5 @@ class WorkflowRunStatusResponse(BaseModel):
recording_url: str | None = None recording_url: str | None = None
downloaded_file_urls: list[str] | None = None downloaded_file_urls: list[str] | None = None
outputs: dict[str, Any] | None = None outputs: dict[str, Any] | None = None
total_steps: int | None = None
total_cost: float | None = None

View file

@ -22,7 +22,7 @@ from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.core.security import generate_skyvern_signature from skyvern.forge.sdk.core.security import generate_skyvern_signature
from skyvern.forge.sdk.core.skyvern_context import SkyvernContext from skyvern.forge.sdk.core.skyvern_context import SkyvernContext
from skyvern.forge.sdk.db.enums import TaskType from skyvern.forge.sdk.db.enums import TaskType
from skyvern.forge.sdk.models import Step from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task
from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock, WorkflowRunTimeline, WorkflowRunTimelineType
@ -741,6 +741,7 @@ class WorkflowService:
self, self,
workflow_run_id: str, workflow_run_id: str,
organization_id: str, organization_id: str,
include_cost: bool = False,
) -> WorkflowRunStatusResponse: ) -> WorkflowRunStatusResponse:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
if workflow_run is None: if workflow_run is None:
@ -751,6 +752,7 @@ class WorkflowService:
workflow_permanent_id=workflow_permanent_id, workflow_permanent_id=workflow_permanent_id,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
organization_id=organization_id, organization_id=organization_id,
include_cost=include_cost,
) )
async def build_workflow_run_status_response( async def build_workflow_run_status_response(
@ -758,6 +760,7 @@ class WorkflowService:
workflow_permanent_id: str, workflow_permanent_id: str,
workflow_run_id: str, workflow_run_id: str,
organization_id: str, organization_id: str,
include_cost: bool = False,
) -> WorkflowRunStatusResponse: ) -> WorkflowRunStatusResponse:
workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id, organization_id=organization_id) workflow = await self.get_workflow_by_permanent_id(workflow_permanent_id, organization_id=organization_id)
if workflow is None: if workflow is None:
@ -824,6 +827,17 @@ class WorkflowService:
if output_parameter_tuples: if output_parameter_tuples:
outputs = {output_parameter.key: output.value for output_parameter, output in output_parameter_tuples} outputs = {output_parameter.key: output.value for output_parameter, output in output_parameter_tuples}
total_steps = None
total_cost = None
if include_cost:
workflow_run_steps = await app.DATABASE.get_steps_by_task_ids(
task_ids=[task.task_id for task in workflow_run_tasks], organization_id=organization_id
)
total_steps = len(workflow_run_steps)
# TODO: This is a temporary cost calculation. We need to implement a more accurate cost calculation.
# successful steps are the ones that have a status of completed and the total count of unique step.order
successful_steps = set(step.order for step in workflow_run_steps if step.status == StepStatus.completed)
total_cost = 0.1 * len(successful_steps)
return WorkflowRunStatusResponse( return WorkflowRunStatusResponse(
workflow_id=workflow.workflow_permanent_id, workflow_id=workflow.workflow_permanent_id,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
@ -840,6 +854,8 @@ class WorkflowService:
recording_url=recording_url, recording_url=recording_url,
downloaded_file_urls=downloaded_file_urls, downloaded_file_urls=downloaded_file_urls,
outputs=outputs, outputs=outputs,
total_steps=total_steps,
total_cost=total_cost,
) )
async def clean_up_workflow( async def clean_up_workflow(