add organization_id filter for get_workflow and get_workflow_run (#1422)

This commit is contained in:
Shuchang Zheng 2024-12-22 17:49:33 -08:00 committed by GitHub
parent b256bace6a
commit 2e37542218
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 44 additions and 18 deletions

View file

@ -191,7 +191,7 @@ class WorkflowService:
) -> WorkflowRun:
"""Execute a workflow."""
organization_id = organization.organization_id
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
# Set workflow run status to running, create workflow run parameters
@ -219,7 +219,8 @@ class WorkflowService:
for block_idx, block in enumerate(blocks):
try:
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run.workflow_run_id
workflow_run_id=workflow_run.workflow_run_id,
organization_id=organization_id,
)
if refreshed_workflow_run and refreshed_workflow_run.status == WorkflowRunStatus.canceled:
LOG.info(
@ -358,7 +359,10 @@ class WorkflowService:
await self.clean_up_workflow(workflow=workflow, workflow_run=workflow_run, api_key=api_key)
return workflow_run
refreshed_workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run.workflow_run_id)
refreshed_workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run.workflow_run_id,
organization_id=organization_id,
)
if refreshed_workflow_run and refreshed_workflow_run.status not in (
WorkflowRunStatus.canceled,
WorkflowRunStatus.failed,
@ -570,8 +574,11 @@ class WorkflowService:
status=WorkflowRunStatus.canceled,
)
async def get_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
workflow_run = await app.DATABASE.get_workflow_run(workflow_run_id=workflow_run_id)
async def get_workflow_run(self, workflow_run_id: str, organization_id: str | None = None) -> WorkflowRun:
workflow_run = await app.DATABASE.get_workflow_run(
workflow_run_id=workflow_run_id,
organization_id=organization_id,
)
if not workflow_run:
raise WorkflowRunNotFound(workflow_run_id)
return workflow_run
@ -734,7 +741,7 @@ class WorkflowService:
workflow_run_id: str,
organization_id: str,
) -> WorkflowRunStatusResponse:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
if workflow_run is None:
LOG.error(f"Workflow run {workflow_run_id} not found")
raise WorkflowRunNotFound(workflow_run_id=workflow_run_id)
@ -756,7 +763,7 @@ class WorkflowService:
LOG.error(f"Workflow {workflow_permanent_id} not found")
raise WorkflowNotFound(workflow_permanent_id=workflow_permanent_id)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id, organization_id=organization_id)
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
screenshot_artifacts = []
screenshot_urls: list[str] | None = None