diff --git a/alembic/versions/2024_11_29_1332-a5feab7712fe_add_application_column_to_tasks.py b/alembic/versions/2024_11_29_1332-a5feab7712fe_add_application_column_to_tasks.py new file mode 100644 index 00000000..af659e73 --- /dev/null +++ b/alembic/versions/2024_11_29_1332-a5feab7712fe_add_application_column_to_tasks.py @@ -0,0 +1,27 @@ +"""Add application column to tasks + +Revision ID: a5feab7712fe +Revises: 56085e451bec +Create Date: 2024-11-29 13:32:58.845703+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a5feab7712fe" +down_revision: Union[str, None] = "56085e451bec" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("tasks", sa.Column("application", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("tasks", "application") diff --git a/skyvern/forge/agent.py b/skyvern/forge/agent.py index 4d2704a3..55f50c96 100644 --- a/skyvern/forge/agent.py +++ b/skyvern/forge/agent.py @@ -216,6 +216,7 @@ class ForgeAgent: proxy_location=task_request.proxy_location, extracted_information_schema=task_request.extracted_information_schema, error_code_mapping=task_request.error_code_mapping, + application=task_request.application, ) LOG.info( "Created new task", diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 52589758..7cad9dce 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -114,6 +114,7 @@ class AgentDB: max_steps_per_run: int | None = None, error_code_mapping: dict[str, str] | None = None, task_type: str = TaskType.general, + application: str | None = None, ) -> Task: try: async with self.Session() as session: @@ -138,6 +139,7 @@ class AgentDB: retry=retry, max_steps_per_run=max_steps_per_run, error_code_mapping=error_code_mapping, + application=application, ) session.add(new_task) await session.commit() @@ -478,6 +480,7 @@ class AgentDB: workflow_run_id: str | None = None, organization_id: str | None = None, only_standalone_tasks: bool = False, + application: str | None = None, order_by_column: OrderBy = OrderBy.created_at, order: SortDirection = SortDirection.desc, ) -> list[Task]: @@ -505,6 +508,8 @@ class AgentDB: query = query.filter(TaskModel.workflow_run_id == workflow_run_id) if only_standalone_tasks: query = query.filter(TaskModel.workflow_run_id.is_(None)) + if application: + query = query.filter(TaskModel.application == application) order_by_col = getattr(TaskModel, order_by_column) query = ( query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc()) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 05ad24ab..1fbe81f3 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -71,6 +71,7 @@ class TaskModel(Base): error_code_mapping = Column(JSON, nullable=True) errors = Column(JSON, default=[], nullable=False) max_steps_per_run = Column(Integer, nullable=True) + application = Column(String, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True) modified_at = Column( DateTime, diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 19ea6cf8..7c11c26e 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -82,6 +82,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task: max_steps_per_run=task_obj.max_steps_per_run, error_code_mapping=task_obj.error_code_mapping, errors=task_obj.errors, + application=task_obj.application, ) return task diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index d2b42b65..4690a3de 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -421,6 +421,7 @@ async def get_agent_tasks( workflow_run_id: Annotated[str | None, Query()] = None, current_org: Organization = Depends(org_auth_service.get_current_org), only_standalone_tasks: bool = Query(False), + application: Annotated[str | None, Query()] = None, sort: OrderBy = Query(OrderBy.created_at), order: SortDirection = Query(SortDirection.desc), ) -> Response: @@ -451,6 +452,7 @@ async def get_agent_tasks( only_standalone_tasks=only_standalone_tasks, order=order, order_by_column=sort, + application=application, ) return ORJSONResponse([task.to_task_response().model_dump() for task in tasks]) diff --git a/skyvern/forge/sdk/schemas/tasks.py b/skyvern/forge/sdk/schemas/tasks.py index 1df94347..17654658 100644 --- a/skyvern/forge/sdk/schemas/tasks.py +++ b/skyvern/forge/sdk/schemas/tasks.py @@ -91,6 +91,11 @@ class TaskBase(BaseModel): description="The type of the task", examples=[TaskType.general, TaskType.validation], ) + application: str | None = Field( + default=None, + description="The application for which the task is running", + examples=["forms"], + ) class TaskRequest(TaskBase):