mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
155 lines
6.1 KiB
Python
155 lines
6.1 KiB
Python
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

from sqlalchemy import and_, asc, select

from skyvern.config import settings
from skyvern.forge.sdk.db._error_handling import db_operation
from skyvern.forge.sdk.db.models import TOTPCodeModel
from skyvern.forge.sdk.schemas.totp_codes import OTPType, TOTPCode

if TYPE_CHECKING:
    from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
|
|
|
|
|
class OTPMixin:
    """Database operations for OTP/TOTP management."""

    # Declared but not assigned here; presumably supplied by the concrete DB class that
    # mixes this in (confirm against base_alchemy_db).
    Session: _SessionFactory

    @staticmethod
    def _otp_created_after(valid_lifespan_minutes: int) -> datetime:
        """Return the naive-UTC cutoff; only codes created strictly after it are kept.

        Produces the same naive UTC value ``datetime.utcnow()`` did, without the
        DeprecationWarning that call emits on Python 3.12+.
        NOTE(review): assumes ``TOTPCodeModel.created_at`` is stored as naive UTC, as the
        previous ``utcnow()`` comparison implied — confirm against the model.
        """
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return now_utc - timedelta(minutes=valid_lifespan_minutes)

    @db_operation("get_otp_codes")
    async def get_otp_codes(
        self,
        organization_id: str,
        totp_identifier: str,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        limit: int | None = None,
    ) -> list[TOTPCode]:
        """Return unexpired OTP codes for an organization and TOTP identifier.

        1. Filter by ``organization_id`` and ``totp_identifier``, plus ``otp_type`` (only
           when truthy) and ``workflow_run_id`` (only when not ``None``).
        2. Keep only codes whose ``created_at`` is within the last
           ``valid_lifespan_minutes`` minutes.
        3. Sort codes linked to a task, workflow or workflow run ahead of codes where all
           three of ``task_id``/``workflow_id``/``workflow_run_id`` are NULL, then by
           ``created_at`` newest first.
        4. Apply ``limit`` at the DB layer when it is not ``None``.

        Args:
            organization_id: Organization that owns the codes.
            totp_identifier: Identifier the codes were submitted under.
            valid_lifespan_minutes: Maximum code age in minutes. The default is read from
                ``settings`` once, when this function is defined, not on every call.
            otp_type: Optional OTP type filter.
            workflow_run_id: Optional workflow-run filter.
            limit: Optional maximum number of rows.

        Returns:
            Matching codes validated as ``TOTPCode`` models, in the order described above.
        """
        # True only for "unscoped" codes; ascending order places them after scoped codes.
        all_null = and_(
            TOTPCodeModel.task_id.is_(None),
            TOTPCodeModel.workflow_id.is_(None),
            TOTPCodeModel.workflow_run_id.is_(None),
        )
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter_by(totp_identifier=totp_identifier)
                .filter(TOTPCodeModel.created_at > self._otp_created_after(valid_lifespan_minutes))
            )
            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            query = query.order_by(asc(all_null), TOTPCodeModel.created_at.desc())
            if limit is not None:
                query = query.limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(code) for code in totp_codes]

    @db_operation("get_otp_codes_by_run")
    async def get_otp_codes_by_run(
        self,
        organization_id: str,
        task_id: str | None = None,
        workflow_run_id: str | None = None,
        valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
        limit: int = 1,
    ) -> list[TOTPCode]:
        """Get OTP codes matching a specific task or workflow run (no totp_identifier required).

        Used when the agent detects a 2FA page but no TOTP credentials are pre-configured.
        The user submits codes manually via the UI, and this method finds them by run context.

        Args:
            organization_id: Organization that owns the codes.
            task_id: Task to match. It is used only when ``workflow_run_id`` is falsy.
            workflow_run_id: Workflow run to match. It takes precedence over ``task_id``.
            valid_lifespan_minutes: Maximum code age in minutes. The default is read from
                ``settings`` once, when this function is defined.
            limit: Maximum number of rows to return (newest first).

        Returns:
            Matching ``TOTPCode`` models, newest first. Returns ``[]`` without querying
            when neither ``workflow_run_id`` nor ``task_id`` is given.
        """
        if not workflow_run_id and not task_id:
            return []
        async with self.Session() as session:
            query = (
                select(TOTPCodeModel)
                .filter_by(organization_id=organization_id)
                .filter(TOTPCodeModel.created_at > self._otp_created_after(valid_lifespan_minutes))
            )
            if workflow_run_id:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            elif task_id:
                query = query.filter(TOTPCodeModel.task_id == task_id)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            results = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(r) for r in results]

    @db_operation("get_recent_otp_codes")
    async def get_recent_otp_codes(
        self,
        organization_id: str,
        limit: int = 50,
        valid_lifespan_minutes: int | None = None,
        otp_type: OTPType | None = None,
        workflow_run_id: str | None = None,
        totp_identifier: str | None = None,
    ) -> list[TOTPCode]:
        """Return an organization's most recent OTP codes, newest first.

        Args:
            organization_id: Organization that owns the codes.
            limit: Maximum number of rows to return.
            valid_lifespan_minutes: When not ``None``, only codes created within this many
                minutes are returned. ``None`` applies no age limit.
            otp_type: Optional OTP type filter (applied only when truthy).
            workflow_run_id: Optional workflow-run filter (applied when not ``None``).
            totp_identifier: Optional identifier filter. It is applied only when truthy,
                so an empty string means "no filter".

        Returns:
            Up to ``limit`` ``TOTPCode`` models ordered by ``created_at`` descending.
        """
        async with self.Session() as session:
            query = select(TOTPCodeModel).filter_by(organization_id=organization_id)

            if valid_lifespan_minutes is not None:
                query = query.filter(TOTPCodeModel.created_at > self._otp_created_after(valid_lifespan_minutes))

            if otp_type:
                query = query.filter(TOTPCodeModel.otp_type == otp_type)
            if workflow_run_id is not None:
                query = query.filter(TOTPCodeModel.workflow_run_id == workflow_run_id)
            if totp_identifier:
                query = query.filter(TOTPCodeModel.totp_identifier == totp_identifier)
            query = query.order_by(TOTPCodeModel.created_at.desc()).limit(limit)
            totp_codes = (await session.scalars(query)).all()
            return [TOTPCode.model_validate(totp_code) for totp_code in totp_codes]

    @db_operation("create_otp_code")
    async def create_otp_code(
        self,
        organization_id: str,
        totp_identifier: str,
        content: str,
        code: str,
        otp_type: OTPType,
        task_id: str | None = None,
        workflow_id: str | None = None,
        workflow_run_id: str | None = None,
        source: str | None = None,
        expired_at: datetime | None = None,
    ) -> TOTPCode:
        """Insert a new OTP code row, commit it, and return it as a ``TOTPCode``.

        Args:
            organization_id: Organization that owns the code.
            totp_identifier: Identifier the code is submitted under.
            content: Raw content the code came from.
            code: The OTP value itself.
            otp_type: Type of OTP.
            task_id: Optional task the code is scoped to.
            workflow_id: Optional workflow the code is scoped to.
            workflow_run_id: Optional workflow run the code is scoped to.
            source: Optional free-form source label.
            expired_at: Optional expiry timestamp stored on the row.

        Returns:
            The persisted row validated as a ``TOTPCode``.
        """
        async with self.Session() as session:
            new_totp_code = TOTPCodeModel(
                organization_id=organization_id,
                totp_identifier=totp_identifier,
                content=content,
                code=code,
                task_id=task_id,
                workflow_id=workflow_id,
                workflow_run_id=workflow_run_id,
                source=source,
                expired_at=expired_at,
                otp_type=otp_type,
            )
            session.add(new_totp_code)
            await session.commit()
            # Reload DB-populated columns (presumably id/created_at defaults) before validating.
            await session.refresh(new_totp_code)
            return TOTPCode.model_validate(new_totp_code)
|