add organization_id to workflow service and db query (#320)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-15 08:43:36 -07:00 committed by GitHub
parent 6110fa4a44
commit 164a4da03a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 14 deletions

View file

@ -82,7 +82,7 @@ class WorkflowService:
:return: The created workflow run.
"""
# Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id)
workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if workflow is None:
LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id)
@ -141,10 +141,11 @@ class WorkflowService:
self,
workflow_run_id: str,
api_key: str,
organization_id: str | None = None,
) -> WorkflowRun:
"""Execute a workflow."""
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
# Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
@ -270,11 +271,11 @@ class WorkflowService:
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
workflow_definition=workflow_definition.model_dump(),
)
async def get_workflow(self, workflow_id: str) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id)
async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if not workflow:
raise WorkflowNotFound(workflow_id)
return workflow
@ -282,6 +283,7 @@ class WorkflowService:
async def update_workflow(
self,
workflow_id: str,
organization_id: str | None = None,
title: str | None = None,
description: str | None = None,
workflow_definition: WorkflowDefinition | None = None,
@ -290,6 +292,7 @@ class WorkflowService:
workflow_definition.validate()
return await app.DATABASE.update_workflow(
workflow_id=workflow_id,
organization_id=organization_id,
title=title,
description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
@ -449,9 +452,13 @@ class WorkflowService:
return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
async def build_workflow_run_status_response(
self, workflow_id: str, workflow_run_id: str, last_block_result: BlockResult | None, organization_id: str
self,
workflow_id: str,
workflow_run_id: str,
last_block_result: BlockResult | None,
organization_id: str,
) -> WorkflowRunStatusResponse:
workflow = await self.get_workflow(workflow_id=workflow_id)
workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if workflow is None:
LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id)
@ -756,6 +763,7 @@ class WorkflowService:
workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks)
workflow = await self.update_workflow(
workflow_id=workflow.workflow_id,
organization_id=organization_id,
workflow_definition=workflow_definition,
)
LOG.info(