Add application column to tasks (#1291)

Co-authored-by: Muhammed Salih Altun <muhammedsalihaltun@gmail.com>
This commit is contained in:
Shuchang Zheng 2024-11-29 05:43:02 -08:00 committed by GitHub
parent fe0f971842
commit 379d5a30cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 42 additions and 0 deletions

View file

@ -0,0 +1,27 @@
"""Add application column to tasks
Revision ID: a5feab7712fe
Revises: 56085e451bec
Create Date: 2024-11-29 13:32:58.845703+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "a5feab7712fe"
down_revision: Union[str, None] = "56085e451bec"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("tasks", sa.Column("application", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("tasks", "application")

View file

@ -216,6 +216,7 @@ class ForgeAgent:
proxy_location=task_request.proxy_location, proxy_location=task_request.proxy_location,
extracted_information_schema=task_request.extracted_information_schema, extracted_information_schema=task_request.extracted_information_schema,
error_code_mapping=task_request.error_code_mapping, error_code_mapping=task_request.error_code_mapping,
application=task_request.application,
) )
LOG.info( LOG.info(
"Created new task", "Created new task",

View file

@ -114,6 +114,7 @@ class AgentDB:
max_steps_per_run: int | None = None, max_steps_per_run: int | None = None,
error_code_mapping: dict[str, str] | None = None, error_code_mapping: dict[str, str] | None = None,
task_type: str = TaskType.general, task_type: str = TaskType.general,
application: str | None = None,
) -> Task: ) -> Task:
try: try:
async with self.Session() as session: async with self.Session() as session:
@ -138,6 +139,7 @@ class AgentDB:
retry=retry, retry=retry,
max_steps_per_run=max_steps_per_run, max_steps_per_run=max_steps_per_run,
error_code_mapping=error_code_mapping, error_code_mapping=error_code_mapping,
application=application,
) )
session.add(new_task) session.add(new_task)
await session.commit() await session.commit()
@ -478,6 +480,7 @@ class AgentDB:
workflow_run_id: str | None = None, workflow_run_id: str | None = None,
organization_id: str | None = None, organization_id: str | None = None,
only_standalone_tasks: bool = False, only_standalone_tasks: bool = False,
application: str | None = None,
order_by_column: OrderBy = OrderBy.created_at, order_by_column: OrderBy = OrderBy.created_at,
order: SortDirection = SortDirection.desc, order: SortDirection = SortDirection.desc,
) -> list[Task]: ) -> list[Task]:
@ -505,6 +508,8 @@ class AgentDB:
query = query.filter(TaskModel.workflow_run_id == workflow_run_id) query = query.filter(TaskModel.workflow_run_id == workflow_run_id)
if only_standalone_tasks: if only_standalone_tasks:
query = query.filter(TaskModel.workflow_run_id.is_(None)) query = query.filter(TaskModel.workflow_run_id.is_(None))
if application:
query = query.filter(TaskModel.application == application)
order_by_col = getattr(TaskModel, order_by_column) order_by_col = getattr(TaskModel, order_by_column)
query = ( query = (
query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc()) query.order_by(order_by_col.desc() if order == SortDirection.desc else order_by_col.asc())

View file

@ -71,6 +71,7 @@ class TaskModel(Base):
error_code_mapping = Column(JSON, nullable=True) error_code_mapping = Column(JSON, nullable=True)
errors = Column(JSON, default=[], nullable=False) errors = Column(JSON, default=[], nullable=False)
max_steps_per_run = Column(Integer, nullable=True) max_steps_per_run = Column(Integer, nullable=True)
application = Column(String, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column( modified_at = Column(
DateTime, DateTime,

View file

@ -82,6 +82,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
max_steps_per_run=task_obj.max_steps_per_run, max_steps_per_run=task_obj.max_steps_per_run,
error_code_mapping=task_obj.error_code_mapping, error_code_mapping=task_obj.error_code_mapping,
errors=task_obj.errors, errors=task_obj.errors,
application=task_obj.application,
) )
return task return task

View file

@ -421,6 +421,7 @@ async def get_agent_tasks(
workflow_run_id: Annotated[str | None, Query()] = None, workflow_run_id: Annotated[str | None, Query()] = None,
current_org: Organization = Depends(org_auth_service.get_current_org), current_org: Organization = Depends(org_auth_service.get_current_org),
only_standalone_tasks: bool = Query(False), only_standalone_tasks: bool = Query(False),
application: Annotated[str | None, Query()] = None,
sort: OrderBy = Query(OrderBy.created_at), sort: OrderBy = Query(OrderBy.created_at),
order: SortDirection = Query(SortDirection.desc), order: SortDirection = Query(SortDirection.desc),
) -> Response: ) -> Response:
@ -451,6 +452,7 @@ async def get_agent_tasks(
only_standalone_tasks=only_standalone_tasks, only_standalone_tasks=only_standalone_tasks,
order=order, order=order,
order_by_column=sort, order_by_column=sort,
application=application,
) )
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks]) return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])

View file

@ -91,6 +91,11 @@ class TaskBase(BaseModel):
description="The type of the task", description="The type of the task",
examples=[TaskType.general, TaskType.validation], examples=[TaskType.general, TaskType.validation],
) )
application: str | None = Field(
default=None,
description="The application for which the task is running",
examples=["forms"],
)
class TaskRequest(TaskBase): class TaskRequest(TaskBase):