feat: encrypt org auth tokens with AES (#3104)
Some checks are pending
Run tests and pre-commit / Run tests and pre-commit hooks (push) Waiting to run
Run tests and pre-commit / Frontend Lint and Build (push) Waiting to run
Publish Fern Docs / run (push) Waiting to run

This commit is contained in:
LawyZheng 2025-08-05 12:36:24 +08:00 committed by GitHub
parent 977c9d4f13
commit 02576e5be3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 192 additions and 9 deletions

View file

@ -0,0 +1,42 @@
"""add_encrypt_token_and_method
Revision ID: dd29417b397c
Revises: 1eedd7a957d1
Create Date: 2025-08-05 04:32:13.735805+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "dd29417b397c"
down_revision: Union[str, None] = "1eedd7a957d1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("organization_auth_tokens", sa.Column("encrypted_token", sa.String(), nullable=True))
op.add_column("organization_auth_tokens", sa.Column("encrypted_method", sa.String(), nullable=True))
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=True)
op.create_index(
op.f("ix_organization_auth_tokens_encrypted_token"),
"organization_auth_tokens",
["encrypted_token"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_organization_auth_tokens_encrypted_token"), table_name="organization_auth_tokens")
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=False)
op.drop_column("organization_auth_tokens", "encrypted_method")
op.drop_column("organization_auth_tokens", "encrypted_token")
# ### end Alembic commands ###

View file

@ -309,6 +309,10 @@ class Settings(BaseSettings):
Otherwise we'll consider the persistent browser session to be expired.
"""
ENCRYPTOR_AES_SECRET_KEY: str = "fillmein"
ENCRYPTOR_AES_SALT: str | None = None
ENCRYPTOR_AES_IV: str | None = None
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
"""
Keys are model names available to blocks in the frontend. These map to key names

View file

@ -66,6 +66,8 @@ from skyvern.forge.sdk.db.utils import (
convert_to_workflow_run_parameter,
hydrate_action,
)
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
@ -867,7 +869,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).first():
return convert_to_organization_auth_token(token)
return await convert_to_organization_auth_token(token)
else:
return None
except SQLAlchemyError:
@ -893,7 +895,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc())
)
).all()
return [convert_to_organization_auth_token(token) for token in tokens]
return [await convert_to_organization_auth_token(token) for token in tokens]
except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True)
raise
@ -907,19 +909,27 @@ class AgentDB:
token_type: OrganizationAuthTokenType,
token: str,
valid: bool | None = True,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken | None:
try:
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
async with self.Session() as session:
query = (
select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id)
.filter_by(token_type=token_type)
.filter_by(token=token)
)
if encrypted_token:
query = query.filter_by(encrypted_token=encrypted_token)
else:
query = query.filter_by(token=token)
if valid is not None:
query = query.filter_by(valid=valid)
if token_obj := (await session.scalars(query)).first():
return convert_to_organization_auth_token(token_obj)
return await convert_to_organization_auth_token(token_obj)
else:
return None
except SQLAlchemyError:
@ -934,18 +944,28 @@ class AgentDB:
organization_id: str,
token_type: OrganizationAuthTokenType,
token: str,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken:
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
plaintext_token = ""
async with self.Session() as session:
auth_token = OrganizationAuthTokenModel(
organization_id=organization_id,
token_type=token_type,
token=token,
token=plaintext_token,
encrypted_token=encrypted_token,
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
)
session.add(auth_token)
await session.commit()
await session.refresh(auth_token)
return convert_to_organization_auth_token(auth_token)
return await convert_to_organization_auth_token(auth_token)
async def get_artifacts_for_task_v2(
self,

View file

@ -167,7 +167,9 @@ class OrganizationAuthTokenModel(Base):
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
token_type = Column(String, nullable=False)
token = Column(String, index=True, nullable=False)
token = Column(String, index=True, nullable=True)
encrypted_token = Column(String, index=True, nullable=True)
encrypted_method = Column(String, nullable=True)
valid = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View file

@ -26,6 +26,8 @@ from skyvern.forge.sdk.db.models import (
WorkflowRunOutputParameterModel,
WorkflowRunParameterModel,
)
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
@ -190,14 +192,18 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
)
def convert_to_organization_auth_token(
async def convert_to_organization_auth_token(
org_auth_token: OrganizationAuthTokenModel,
) -> OrganizationAuthToken:
token = org_auth_token.token
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
return OrganizationAuthToken(
id=org_auth_token.id,
organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=org_auth_token.token,
token=token,
valid=org_auth_token.valid,
created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at,

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
class Encryptor(BaseModel):
def __init__(self) -> None:
self._methods: dict[EncryptMethod, BaseEncryptor] = {}
def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None:
self._methods[encrypt_method.method()] = encrypt_method
async def encrypt(self, plaintext: str, method: EncryptMethod) -> str:
if method not in self._methods:
raise ValueError(f"encrypt method not registered: {method}")
return await self._methods[method].encrypt(plaintext)
async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str:
if method not in self._methods:
raise ValueError(f"encrypt method not registered: {method}")
return await self._methods[method].decrypt(ciphertext)
encryptor = Encryptor()

View file

@ -0,0 +1,63 @@
import base64
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
class AES(BaseEncryptor):
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
def method(self) -> EncryptMethod:
return EncryptMethod.AES
def _derive_key(self) -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=self.salt,
iterations=100000,
)
return kdf.derive(self.secret_key)
async def encrypt(self, plaintext: str) -> str:
try:
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
encryptor = cipher.encryptor()
padded_plaintext = self._pad(plaintext.encode("utf-8"))
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
return base64.b64encode(ciphertext).decode("utf-8")
except Exception as e:
raise Exception("Failed to encrypt token") from e
async def decrypt(self, ciphertext: str) -> str:
try:
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
decryptor = cipher.decryptor()
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
plaintext = self._unpad(padded_plaintext)
return plaintext.decode("utf-8")
except Exception as e:
raise Exception("Failed to decrypt token") from e
def _pad(self, data: bytes) -> bytes:
block_size = 16
padding_length = block_size - (len(data) % block_size)
padding = bytes([padding_length] * padding_length)
return data + padding
def _unpad(self, data: bytes) -> bytes:
padding_length = data[-1]
return data[:-padding_length]

View file

@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from enum import Enum
class EncryptMethod(Enum):
AES = "aes"
class BaseEncryptor(ABC):
@abstractmethod
def method(self) -> EncryptMethod:
pass
@abstractmethod
async def encrypt(self, plaintext: str) -> str:
pass
@abstractmethod
async def decrypt(self, ciphertext: str) -> str:
pass