mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
380 lines
16 KiB
Python
380 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timedelta
|
|
from typing import TYPE_CHECKING, Literal, overload
|
|
|
|
from sqlalchemy import select, update
|
|
|
|
from skyvern.forge.sdk.db._error_handling import db_operation
|
|
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
|
|
from skyvern.forge.sdk.db.exceptions import NotFoundError
|
|
from skyvern.forge.sdk.db.mixins.base import read_retry
|
|
from skyvern.forge.sdk.db.models import (
|
|
OrganizationAuthTokenModel,
|
|
OrganizationModel,
|
|
TaskModel,
|
|
WorkflowRunModel,
|
|
)
|
|
from skyvern.forge.sdk.db.utils import (
|
|
convert_to_organization,
|
|
convert_to_organization_auth_token,
|
|
)
|
|
from skyvern.forge.sdk.encrypt import encryptor
|
|
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
|
from skyvern.forge.sdk.schemas.organizations import (
|
|
AzureClientSecretCredential,
|
|
AzureOrganizationAuthToken,
|
|
BitwardenCredential,
|
|
BitwardenOrganizationAuthToken,
|
|
Organization,
|
|
OrganizationAuthToken,
|
|
)
|
|
from skyvern.forge.sdk.schemas.tasks import TaskStatus
|
|
from skyvern.forge.sdk.workflow.models.workflow import WorkflowRunStatus
|
|
|
|
if TYPE_CHECKING:
|
|
from skyvern.forge.sdk.db.base_alchemy_db import _SessionFactory
|
|
|
|
|
|
class OrganizationsMixin:
|
|
Session: _SessionFactory
|
|
"""Database operations for organization and auth-token management."""
|
|
|
|
@read_retry()
|
|
@db_operation("get_active_verification_requests", log_errors=False)
|
|
async def get_active_verification_requests(self, organization_id: str) -> list[dict]:
|
|
"""Return active 2FA verification requests for an organization.
|
|
|
|
Queries both tasks and workflow runs where waiting_for_verification_code=True.
|
|
Used to provide initial state when a WebSocket notification client connects.
|
|
"""
|
|
results: list[dict] = []
|
|
async with self.Session() as session:
|
|
# Tasks waiting for verification (exclude finalized tasks)
|
|
finalized_task_statuses = [s.value for s in TaskStatus if s.is_final()]
|
|
task_rows = (
|
|
await session.scalars(
|
|
select(TaskModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(waiting_for_verification_code=True)
|
|
.filter_by(workflow_run_id=None)
|
|
.filter(TaskModel.status.not_in(finalized_task_statuses))
|
|
.filter(TaskModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
|
)
|
|
).all()
|
|
for t in task_rows:
|
|
results.append(
|
|
{
|
|
"task_id": t.task_id,
|
|
"workflow_run_id": None,
|
|
"verification_code_identifier": t.verification_code_identifier,
|
|
"verification_code_polling_started_at": (
|
|
t.verification_code_polling_started_at.isoformat()
|
|
if t.verification_code_polling_started_at
|
|
else None
|
|
),
|
|
}
|
|
)
|
|
# Workflow runs waiting for verification (exclude finalized runs)
|
|
finalized_wr_statuses = [s.value for s in WorkflowRunStatus if s.is_final()]
|
|
wr_rows = (
|
|
await session.scalars(
|
|
select(WorkflowRunModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(waiting_for_verification_code=True)
|
|
.filter(WorkflowRunModel.status.not_in(finalized_wr_statuses))
|
|
.filter(WorkflowRunModel.created_at > datetime.utcnow() - timedelta(hours=1))
|
|
)
|
|
).all()
|
|
for wr in wr_rows:
|
|
results.append(
|
|
{
|
|
"task_id": None,
|
|
"workflow_run_id": wr.workflow_run_id,
|
|
"verification_code_identifier": wr.verification_code_identifier,
|
|
"verification_code_polling_started_at": (
|
|
wr.verification_code_polling_started_at.isoformat()
|
|
if wr.verification_code_polling_started_at
|
|
else None
|
|
),
|
|
}
|
|
)
|
|
return results
|
|
|
|
@db_operation("get_all_organizations")
|
|
async def get_all_organizations(self) -> list[Organization]:
|
|
async with self.Session() as session:
|
|
organizations = (await session.scalars(select(OrganizationModel))).all()
|
|
return [convert_to_organization(organization) for organization in organizations]
|
|
|
|
@db_operation("get_organization")
|
|
async def get_organization(self, organization_id: str) -> Organization | None:
|
|
async with self.Session() as session:
|
|
if organization := (
|
|
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
|
).first():
|
|
return convert_to_organization(organization)
|
|
else:
|
|
return None
|
|
|
|
@db_operation("get_organization_by_domain")
|
|
async def get_organization_by_domain(self, domain: str) -> Organization | None:
|
|
async with self.Session() as session:
|
|
if organization := (await session.scalars(select(OrganizationModel).filter_by(domain=domain))).first():
|
|
return convert_to_organization(organization)
|
|
return None
|
|
|
|
@db_operation("create_organization")
|
|
async def create_organization(
|
|
self,
|
|
organization_name: str,
|
|
webhook_callback_url: str | None = None,
|
|
max_steps_per_run: int | None = None,
|
|
max_retries_per_step: int | None = None,
|
|
domain: str | None = None,
|
|
organization_id: str | None = None,
|
|
) -> Organization:
|
|
async with self.Session() as session:
|
|
org = OrganizationModel(
|
|
organization_id=organization_id,
|
|
organization_name=organization_name,
|
|
webhook_callback_url=webhook_callback_url,
|
|
max_steps_per_run=max_steps_per_run,
|
|
max_retries_per_step=max_retries_per_step,
|
|
domain=domain,
|
|
)
|
|
session.add(org)
|
|
await session.commit()
|
|
await session.refresh(org)
|
|
|
|
return convert_to_organization(org)
|
|
|
|
@db_operation("update_organization")
|
|
async def update_organization(
|
|
self,
|
|
organization_id: str,
|
|
organization_name: str | None = None,
|
|
webhook_callback_url: str | None = None,
|
|
max_steps_per_run: int | None = None,
|
|
max_retries_per_step: int | None = None,
|
|
) -> Organization:
|
|
async with self.Session() as session:
|
|
organization = (
|
|
await session.scalars(select(OrganizationModel).filter_by(organization_id=organization_id))
|
|
).first()
|
|
if not organization:
|
|
raise NotFoundError
|
|
if organization_name:
|
|
organization.organization_name = organization_name
|
|
if webhook_callback_url:
|
|
organization.webhook_callback_url = webhook_callback_url
|
|
if max_steps_per_run:
|
|
organization.max_steps_per_run = max_steps_per_run
|
|
if max_retries_per_step:
|
|
organization.max_retries_per_step = max_retries_per_step
|
|
await session.commit()
|
|
await session.refresh(organization)
|
|
return Organization.model_validate(organization)
|
|
|
|
@overload
|
|
async def get_valid_org_auth_token(
|
|
self,
|
|
organization_id: str,
|
|
token_type: Literal["api", "onepassword_service_account", "custom_credential_service"],
|
|
) -> OrganizationAuthToken | None: ...
|
|
|
|
@overload
|
|
async def get_valid_org_auth_token( # type: ignore
|
|
self,
|
|
organization_id: str,
|
|
token_type: Literal["azure_client_secret_credential"],
|
|
) -> AzureOrganizationAuthToken | None: ...
|
|
|
|
@overload
|
|
async def get_valid_org_auth_token( # type: ignore
|
|
self,
|
|
organization_id: str,
|
|
token_type: Literal["bitwarden_credential"],
|
|
) -> BitwardenOrganizationAuthToken | None: ...
|
|
|
|
@db_operation("get_valid_org_auth_token")
|
|
async def get_valid_org_auth_token(
|
|
self,
|
|
organization_id: str,
|
|
token_type: Literal[
|
|
"api",
|
|
"onepassword_service_account",
|
|
"azure_client_secret_credential",
|
|
"bitwarden_credential",
|
|
"custom_credential_service",
|
|
],
|
|
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken | None:
|
|
async with self.Session() as session:
|
|
if token := (
|
|
await session.scalars(
|
|
select(OrganizationAuthTokenModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(token_type=token_type)
|
|
.filter_by(valid=True)
|
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
|
)
|
|
).first():
|
|
return await convert_to_organization_auth_token(token, token_type)
|
|
else:
|
|
return None
|
|
|
|
@db_operation("get_valid_org_auth_tokens")
|
|
async def get_valid_org_auth_tokens(
|
|
self,
|
|
organization_id: str,
|
|
token_type: OrganizationAuthTokenType,
|
|
) -> list[OrganizationAuthToken]:
|
|
async with self.Session() as session:
|
|
tokens = (
|
|
await session.scalars(
|
|
select(OrganizationAuthTokenModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(token_type=token_type)
|
|
.filter_by(valid=True)
|
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
|
)
|
|
).all()
|
|
return [await convert_to_organization_auth_token(token, token_type) for token in tokens]
|
|
|
|
@db_operation("validate_org_auth_token")
|
|
async def validate_org_auth_token(
|
|
self,
|
|
organization_id: str,
|
|
token_type: OrganizationAuthTokenType,
|
|
token: str,
|
|
valid: bool | None = True,
|
|
encrypted_method: EncryptMethod | None = None,
|
|
) -> OrganizationAuthToken | None:
|
|
encrypted_token = ""
|
|
if encrypted_method is not None:
|
|
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
|
|
|
async with self.Session() as session:
|
|
query = (
|
|
select(OrganizationAuthTokenModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(token_type=token_type)
|
|
)
|
|
if encrypted_token:
|
|
query = query.filter_by(encrypted_token=encrypted_token)
|
|
else:
|
|
query = query.filter_by(token=token)
|
|
if valid is not None:
|
|
query = query.filter_by(valid=valid)
|
|
if token_obj := (await session.scalars(query)).first():
|
|
return await convert_to_organization_auth_token(token_obj, token_type)
|
|
else:
|
|
return None
|
|
|
|
@db_operation("create_org_auth_token")
|
|
async def create_org_auth_token(
|
|
self,
|
|
organization_id: str,
|
|
token_type: OrganizationAuthTokenType,
|
|
token: str | AzureClientSecretCredential | BitwardenCredential,
|
|
encrypted_method: EncryptMethod | None = None,
|
|
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
|
|
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
|
if not isinstance(token, AzureClientSecretCredential):
|
|
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
|
plaintext_token = token.model_dump_json()
|
|
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
|
|
if not isinstance(token, BitwardenCredential):
|
|
raise TypeError("Expected BitwardenCredential for this token_type")
|
|
plaintext_token = token.model_dump_json()
|
|
else:
|
|
if not isinstance(token, str):
|
|
raise TypeError("Expected str token for this token_type")
|
|
plaintext_token = token
|
|
|
|
encrypted_token = ""
|
|
|
|
if encrypted_method is not None:
|
|
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
|
plaintext_token = ""
|
|
|
|
async with self.Session() as session:
|
|
auth_token = OrganizationAuthTokenModel(
|
|
organization_id=organization_id,
|
|
token_type=token_type,
|
|
token=plaintext_token,
|
|
encrypted_token=encrypted_token,
|
|
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
|
)
|
|
session.add(auth_token)
|
|
await session.commit()
|
|
await session.refresh(auth_token)
|
|
|
|
return await convert_to_organization_auth_token(auth_token, token_type)
|
|
|
|
@db_operation("invalidate_org_auth_tokens")
|
|
async def invalidate_org_auth_tokens(
|
|
self,
|
|
organization_id: str,
|
|
token_type: OrganizationAuthTokenType,
|
|
) -> None:
|
|
"""Invalidate all existing tokens of a specific type for an organization."""
|
|
async with self.Session() as session:
|
|
await session.execute(
|
|
update(OrganizationAuthTokenModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(token_type=token_type)
|
|
.filter_by(valid=True)
|
|
.values(valid=False)
|
|
)
|
|
await session.commit()
|
|
|
|
@db_operation("replace_org_auth_token")
|
|
async def replace_org_auth_token(
|
|
self,
|
|
organization_id: str,
|
|
token_type: OrganizationAuthTokenType,
|
|
token: str | AzureClientSecretCredential | BitwardenCredential,
|
|
encrypted_method: EncryptMethod | None = None,
|
|
) -> OrganizationAuthToken | AzureOrganizationAuthToken | BitwardenOrganizationAuthToken:
|
|
"""Atomically invalidate existing tokens and create a new one in a single transaction."""
|
|
if token_type is OrganizationAuthTokenType.azure_client_secret_credential:
|
|
if not isinstance(token, AzureClientSecretCredential):
|
|
raise TypeError("Expected AzureClientSecretCredential for this token_type")
|
|
plaintext_token = token.model_dump_json()
|
|
elif token_type is OrganizationAuthTokenType.bitwarden_credential:
|
|
if not isinstance(token, BitwardenCredential):
|
|
raise TypeError("Expected BitwardenCredential for this token_type")
|
|
plaintext_token = token.model_dump_json()
|
|
else:
|
|
if not isinstance(token, str):
|
|
raise TypeError("Expected str token for this token_type")
|
|
plaintext_token = token
|
|
|
|
encrypted_token = ""
|
|
if encrypted_method is not None:
|
|
encrypted_token = await encryptor.encrypt(plaintext_token, encrypted_method)
|
|
plaintext_token = ""
|
|
|
|
async with self.Session() as session:
|
|
# Invalidate existing tokens
|
|
await session.execute(
|
|
update(OrganizationAuthTokenModel)
|
|
.filter_by(organization_id=organization_id)
|
|
.filter_by(token_type=token_type)
|
|
.filter_by(valid=True)
|
|
.values(valid=False)
|
|
)
|
|
# Create new token
|
|
auth_token = OrganizationAuthTokenModel(
|
|
organization_id=organization_id,
|
|
token_type=token_type,
|
|
token=plaintext_token,
|
|
encrypted_token=encrypted_token,
|
|
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
|
)
|
|
session.add(auth_token)
|
|
await session.commit()
|
|
await session.refresh(auth_token)
|
|
|
|
return await convert_to_organization_auth_token(auth_token, token_type)
|