mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-04 03:30:15 +00:00
remove AgentMiddleware (#211)
This commit is contained in:
parent
9091a6716e
commit
02db2a90e6
2 changed files with 3 additions and 46 deletions
|
@ -47,8 +47,6 @@ class Agent:
|
||||||
|
|
||||||
app.include_router(router, prefix="/api/v1")
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
app.add_middleware(AgentMiddleware, agent=self)
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
RawContextMiddleware,
|
RawContextMiddleware,
|
||||||
plugins=(
|
plugins=(
|
||||||
|
@ -85,20 +83,6 @@ class Agent:
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
class AgentMiddleware:
|
|
||||||
"""
|
|
||||||
Middleware that injects the agent instance into the request scope.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app: FastAPI, agent: Agent):
|
|
||||||
self.app = app
|
|
||||||
self.agent = agent
|
|
||||||
|
|
||||||
async def __call__(self, scope, receive, send): # type: ignore
|
|
||||||
scope["agent"] = self.agent
|
|
||||||
await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionDatePlugin(Plugin):
|
class ExecutionDatePlugin(Plugin):
|
||||||
key = "execution_date"
|
key = "execution_date"
|
||||||
|
|
||||||
|
|
|
@ -83,19 +83,17 @@ async def check_server_status() -> Response:
|
||||||
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
|
@base_router.post("/tasks", tags=["agent"], response_model=CreateTaskResponse)
|
||||||
async def create_agent_task(
|
async def create_agent_task(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
request: Request,
|
|
||||||
task: TaskRequest,
|
task: TaskRequest,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
x_api_key: Annotated[str | None, Header()] = None,
|
x_api_key: Annotated[str | None, Header()] = None,
|
||||||
x_max_steps_override: Annotated[int | None, Header()] = None,
|
x_max_steps_override: Annotated[int | None, Header()] = None,
|
||||||
) -> CreateTaskResponse:
|
) -> CreateTaskResponse:
|
||||||
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
analytics.capture("skyvern-oss-agent-task-create", data={"url": task.url})
|
||||||
agent = request["agent"]
|
|
||||||
|
|
||||||
if current_org and current_org.organization_name == "CoverageCat":
|
if current_org and current_org.organization_name == "CoverageCat":
|
||||||
task.proxy_location = ProxyLocation.RESIDENTIAL
|
task.proxy_location = ProxyLocation.RESIDENTIAL
|
||||||
|
|
||||||
created_task = await agent.create_task(task, current_org.organization_id)
|
created_task = await app.agent.create_task(task, current_org.organization_id)
|
||||||
if x_max_steps_override:
|
if x_max_steps_override:
|
||||||
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
LOG.info("Overriding max steps per run", max_steps_override=x_max_steps_override)
|
||||||
await AsyncExecutorFactory.get_executor().execute_task(
|
await AsyncExecutorFactory.get_executor().execute_task(
|
||||||
|
@ -121,13 +119,11 @@ async def create_agent_task(
|
||||||
summary="Executes the next step",
|
summary="Executes the next step",
|
||||||
)
|
)
|
||||||
async def execute_agent_task_step(
|
async def execute_agent_task_step(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
step_id: str | None = None,
|
step_id: str | None = None,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
analytics.capture("skyvern-oss-agent-task-step-execute")
|
analytics.capture("skyvern-oss-agent-task-step-execute")
|
||||||
agent = request["agent"]
|
|
||||||
task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
task = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -171,7 +167,7 @@ async def execute_agent_task_step(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=f"No step found with id {step_id}",
|
detail=f"No step found with id {step_id}",
|
||||||
)
|
)
|
||||||
step, _, _ = await agent.execute_step(current_org, task, step)
|
step, _, _ = await app.agent.execute_step(current_org, task, step)
|
||||||
return Response(
|
return Response(
|
||||||
content=step.model_dump_json() if step else "",
|
content=step.model_dump_json() if step else "",
|
||||||
status_code=200,
|
status_code=200,
|
||||||
|
@ -181,12 +177,10 @@ async def execute_agent_task_step(
|
||||||
|
|
||||||
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
|
@base_router.get("/tasks/{task_id}", response_model=TaskResponse)
|
||||||
async def get_task(
|
async def get_task(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> TaskResponse:
|
) -> TaskResponse:
|
||||||
analytics.capture("skyvern-oss-agent-task-get")
|
analytics.capture("skyvern-oss-agent-task-get")
|
||||||
request["agent"]
|
|
||||||
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
||||||
if not task_obj:
|
if not task_obj:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -270,13 +264,11 @@ async def get_task(
|
||||||
response_model=TaskResponse,
|
response_model=TaskResponse,
|
||||||
)
|
)
|
||||||
async def retry_webhook(
|
async def retry_webhook(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
x_api_key: Annotated[str | None, Header()] = None,
|
x_api_key: Annotated[str | None, Header()] = None,
|
||||||
) -> TaskResponse:
|
) -> TaskResponse:
|
||||||
analytics.capture("skyvern-oss-agent-task-retry-webhook")
|
analytics.capture("skyvern-oss-agent-task-retry-webhook")
|
||||||
agent = request["agent"]
|
|
||||||
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
task_obj = await app.DATABASE.get_task(task_id, organization_id=current_org.organization_id)
|
||||||
if not task_obj:
|
if not task_obj:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -290,20 +282,18 @@ async def retry_webhook(
|
||||||
return task_obj.to_task_response()
|
return task_obj.to_task_response()
|
||||||
|
|
||||||
# retry the webhook
|
# retry the webhook
|
||||||
await agent.execute_task_webhook(task=task_obj, last_step=latest_step, api_key=x_api_key)
|
await app.agent.execute_task_webhook(task=task_obj, last_step=latest_step, api_key=x_api_key)
|
||||||
|
|
||||||
return task_obj.to_task_response()
|
return task_obj.to_task_response()
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/internal/tasks/{task_id}", response_model=list[Task])
|
@base_router.get("/internal/tasks/{task_id}", response_model=list[Task])
|
||||||
async def get_task_internal(
|
async def get_task_internal(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Get all tasks.
|
Get all tasks.
|
||||||
:param request:
|
|
||||||
:param page: Starting page, defaults to 1
|
:param page: Starting page, defaults to 1
|
||||||
:param page_size:
|
:param page_size:
|
||||||
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
||||||
|
@ -321,80 +311,68 @@ async def get_task_internal(
|
||||||
|
|
||||||
@base_router.get("/tasks", tags=["agent"], response_model=list[Task])
|
@base_router.get("/tasks", tags=["agent"], response_model=list[Task])
|
||||||
async def get_agent_tasks(
|
async def get_agent_tasks(
|
||||||
request: Request,
|
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: int = Query(10, ge=1),
|
page_size: int = Query(10, ge=1),
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Get all tasks.
|
Get all tasks.
|
||||||
:param request:
|
|
||||||
:param page: Starting page, defaults to 1
|
:param page: Starting page, defaults to 1
|
||||||
:param page_size: Page size, defaults to 10
|
:param page_size: Page size, defaults to 10
|
||||||
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
||||||
get_agent_task endpoint.
|
get_agent_task endpoint.
|
||||||
"""
|
"""
|
||||||
analytics.capture("skyvern-oss-agent-tasks-get")
|
analytics.capture("skyvern-oss-agent-tasks-get")
|
||||||
request["agent"]
|
|
||||||
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
|
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
|
||||||
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])
|
return ORJSONResponse([task.to_task_response().model_dump() for task in tasks])
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
|
@base_router.get("/internal/tasks", tags=["agent"], response_model=list[Task])
|
||||||
async def get_agent_tasks_internal(
|
async def get_agent_tasks_internal(
|
||||||
request: Request,
|
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
page_size: int = Query(10, ge=1),
|
page_size: int = Query(10, ge=1),
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Get all tasks.
|
Get all tasks.
|
||||||
:param request:
|
|
||||||
:param page: Starting page, defaults to 1
|
:param page: Starting page, defaults to 1
|
||||||
:param page_size: Page size, defaults to 10
|
:param page_size: Page size, defaults to 10
|
||||||
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
:return: List of tasks with pagination without steps populated. Steps can be populated by calling the
|
||||||
get_agent_task endpoint.
|
get_agent_task endpoint.
|
||||||
"""
|
"""
|
||||||
analytics.capture("skyvern-oss-agent-tasks-get-internal")
|
analytics.capture("skyvern-oss-agent-tasks-get-internal")
|
||||||
request["agent"]
|
|
||||||
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
|
tasks = await app.DATABASE.get_tasks(page, page_size, organization_id=current_org.organization_id)
|
||||||
return ORJSONResponse([task.model_dump() for task in tasks])
|
return ORJSONResponse([task.model_dump() for task in tasks])
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
|
@base_router.get("/tasks/{task_id}/steps", tags=["agent"], response_model=list[Step])
|
||||||
async def get_agent_task_steps(
|
async def get_agent_task_steps(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Get all steps for a task.
|
Get all steps for a task.
|
||||||
:param request:
|
|
||||||
:param task_id:
|
:param task_id:
|
||||||
:return: List of steps for a task with pagination.
|
:return: List of steps for a task with pagination.
|
||||||
"""
|
"""
|
||||||
analytics.capture("skyvern-oss-agent-task-steps-get")
|
analytics.capture("skyvern-oss-agent-task-steps-get")
|
||||||
request["agent"]
|
|
||||||
steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
|
steps = await app.DATABASE.get_task_steps(task_id, organization_id=current_org.organization_id)
|
||||||
return ORJSONResponse([step.model_dump() for step in steps])
|
return ORJSONResponse([step.model_dump() for step in steps])
|
||||||
|
|
||||||
|
|
||||||
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
|
@base_router.get("/tasks/{task_id}/steps/{step_id}/artifacts", tags=["agent"], response_model=list[Artifact])
|
||||||
async def get_agent_task_step_artifacts(
|
async def get_agent_task_step_artifacts(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
step_id: str,
|
step_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""
|
"""
|
||||||
Get all artifacts for a list of steps.
|
Get all artifacts for a list of steps.
|
||||||
:param request:
|
|
||||||
:param task_id:
|
:param task_id:
|
||||||
:param step_id:
|
:param step_id:
|
||||||
:return: List of artifacts for a list of steps.
|
:return: List of artifacts for a list of steps.
|
||||||
"""
|
"""
|
||||||
analytics.capture("skyvern-oss-agent-task-step-artifacts-get")
|
analytics.capture("skyvern-oss-agent-task-step-artifacts-get")
|
||||||
request["agent"]
|
|
||||||
artifacts = await app.DATABASE.get_artifacts_for_task_step(
|
artifacts = await app.DATABASE.get_artifacts_for_task_step(
|
||||||
task_id,
|
task_id,
|
||||||
step_id,
|
step_id,
|
||||||
|
@ -416,12 +394,10 @@ class ActionResultTmp(BaseModel):
|
||||||
|
|
||||||
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
|
@base_router.get("/tasks/{task_id}/actions", response_model=list[ActionResultTmp])
|
||||||
async def get_task_actions(
|
async def get_task_actions(
|
||||||
request: Request,
|
|
||||||
task_id: str,
|
task_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> list[ActionResultTmp]:
|
) -> list[ActionResultTmp]:
|
||||||
analytics.capture("skyvern-oss-agent-task-actions-get")
|
analytics.capture("skyvern-oss-agent-task-actions-get")
|
||||||
request["agent"]
|
|
||||||
steps = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
|
steps = await app.DATABASE.get_task_step_models(task_id, organization_id=current_org.organization_id)
|
||||||
results: list[ActionResultTmp] = []
|
results: list[ActionResultTmp] = []
|
||||||
for step_s in steps:
|
for step_s in steps:
|
||||||
|
@ -435,7 +411,6 @@ async def get_task_actions(
|
||||||
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
|
@base_router.post("/workflows/{workflow_id}/run", response_model=RunWorkflowResponse)
|
||||||
async def execute_workflow(
|
async def execute_workflow(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
request: Request,
|
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
workflow_request: WorkflowRequestBody,
|
workflow_request: WorkflowRequestBody,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
|
@ -470,13 +445,11 @@ async def execute_workflow(
|
||||||
|
|
||||||
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
|
@base_router.get("/workflows/{workflow_id}/runs/{workflow_run_id}", response_model=WorkflowRunStatusResponse)
|
||||||
async def get_workflow_run(
|
async def get_workflow_run(
|
||||||
request: Request,
|
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
current_org: Organization = Depends(org_auth_service.get_current_org),
|
current_org: Organization = Depends(org_auth_service.get_current_org),
|
||||||
) -> WorkflowRunStatusResponse:
|
) -> WorkflowRunStatusResponse:
|
||||||
analytics.capture("skyvern-oss-agent-workflow-run-get")
|
analytics.capture("skyvern-oss-agent-workflow-run-get")
|
||||||
request["agent"]
|
|
||||||
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
|
return await app.WORKFLOW_SERVICE.build_workflow_run_status_response(
|
||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
workflow_run_id=workflow_run_id,
|
workflow_run_id=workflow_run_id,
|
||||||
|
|
Loading…
Add table
Reference in a new issue