add organization_id to workflow service and db query (#320)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-15 08:43:36 -07:00 committed by GitHub
parent 6110fa4a44
commit 164a4da03a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 14 deletions

View file

@ -740,12 +740,13 @@ class AgentDB:
await session.refresh(workflow) await session.refresh(workflow)
return convert_to_workflow(workflow, self.debug_enabled) return convert_to_workflow(workflow, self.debug_enabled)
async def get_workflow(self, workflow_id: str) -> Workflow | None: async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow | None:
try: try:
async with self.Session() as session: async with self.Session() as session:
if workflow := ( get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id)) if organization_id:
).first(): get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
return convert_to_workflow(workflow, self.debug_enabled) return convert_to_workflow(workflow, self.debug_enabled)
return None return None
except SQLAlchemyError: except SQLAlchemyError:
@ -755,15 +756,17 @@ class AgentDB:
async def update_workflow( async def update_workflow(
self, self,
workflow_id: str, workflow_id: str,
organization_id: str | None = None,
title: str | None = None, title: str | None = None,
description: str | None = None, description: str | None = None,
workflow_definition: dict[str, Any] | None = None, workflow_definition: dict[str, Any] | None = None,
) -> Workflow: ) -> Workflow:
try: try:
async with self.Session() as session: async with self.Session() as session:
if workflow := ( get_workflow_query = select(WorkflowModel).filter_by(workflow_id=workflow_id)
await session.scalars(select(WorkflowModel).filter_by(workflow_id=workflow_id)) if organization_id:
).first(): get_workflow_query = get_workflow_query.filter_by(organization_id=organization_id)
if workflow := (await session.scalars(get_workflow_query)).first():
if title: if title:
workflow.title = title workflow.title = title
if description: if description:

View file

@ -82,7 +82,7 @@ class WorkflowService:
:return: The created workflow run. :return: The created workflow run.
""" """
# Validate the workflow and the organization # Validate the workflow and the organization
workflow = await self.get_workflow(workflow_id=workflow_id) workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if workflow is None: if workflow is None:
LOG.error(f"Workflow {workflow_id} not found") LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id) raise WorkflowNotFound(workflow_id=workflow_id)
@ -141,10 +141,11 @@ class WorkflowService:
self, self,
workflow_run_id: str, workflow_run_id: str,
api_key: str, api_key: str,
organization_id: str | None = None,
) -> WorkflowRun: ) -> WorkflowRun:
"""Execute a workflow.""" """Execute a workflow."""
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id) workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id) workflow = await self.get_workflow(workflow_id=workflow_run.workflow_id, organization_id=organization_id)
# Set workflow run status to running, create workflow run parameters # Set workflow run status to running, create workflow run parameters
await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id) await self.mark_workflow_run_as_running(workflow_run_id=workflow_run.workflow_run_id)
@ -270,11 +271,11 @@ class WorkflowService:
organization_id=organization_id, organization_id=organization_id,
title=title, title=title,
description=description, description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None, workflow_definition=workflow_definition.model_dump(),
) )
async def get_workflow(self, workflow_id: str) -> Workflow: async def get_workflow(self, workflow_id: str, organization_id: str | None = None) -> Workflow:
workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id) workflow = await app.DATABASE.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if not workflow: if not workflow:
raise WorkflowNotFound(workflow_id) raise WorkflowNotFound(workflow_id)
return workflow return workflow
@ -282,6 +283,7 @@ class WorkflowService:
async def update_workflow( async def update_workflow(
self, self,
workflow_id: str, workflow_id: str,
organization_id: str | None = None,
title: str | None = None, title: str | None = None,
description: str | None = None, description: str | None = None,
workflow_definition: WorkflowDefinition | None = None, workflow_definition: WorkflowDefinition | None = None,
@ -290,6 +292,7 @@ class WorkflowService:
workflow_definition.validate() workflow_definition.validate()
return await app.DATABASE.update_workflow( return await app.DATABASE.update_workflow(
workflow_id=workflow_id, workflow_id=workflow_id,
organization_id=organization_id,
title=title, title=title,
description=description, description=description,
workflow_definition=workflow_definition.model_dump() if workflow_definition else None, workflow_definition=workflow_definition.model_dump() if workflow_definition else None,
@ -449,9 +452,13 @@ class WorkflowService:
return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id) return await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
async def build_workflow_run_status_response( async def build_workflow_run_status_response(
self, workflow_id: str, workflow_run_id: str, last_block_result: BlockResult | None, organization_id: str self,
workflow_id: str,
workflow_run_id: str,
last_block_result: BlockResult | None,
organization_id: str,
) -> WorkflowRunStatusResponse: ) -> WorkflowRunStatusResponse:
workflow = await self.get_workflow(workflow_id=workflow_id) workflow = await self.get_workflow(workflow_id=workflow_id, organization_id=organization_id)
if workflow is None: if workflow is None:
LOG.error(f"Workflow {workflow_id} not found") LOG.error(f"Workflow {workflow_id} not found")
raise WorkflowNotFound(workflow_id=workflow_id) raise WorkflowNotFound(workflow_id=workflow_id)
@ -756,6 +763,7 @@ class WorkflowService:
workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks) workflow_definition = WorkflowDefinition(parameters=parameters.values(), blocks=blocks)
workflow = await self.update_workflow( workflow = await self.update_workflow(
workflow_id=workflow.workflow_id, workflow_id=workflow.workflow_id,
organization_id=organization_id,
workflow_definition=workflow_definition, workflow_definition=workflow_definition,
) )
LOG.info( LOG.info(