feat: encrypt org auth tokens with AES (#3104)
Some checks are pending
Run tests and pre-commit / Run tests and pre-commit hooks (push) Waiting to run
Run tests and pre-commit / Frontend Lint and Build (push) Waiting to run
Publish Fern Docs / run (push) Waiting to run

This commit is contained in:
LawyZheng 2025-08-05 12:36:24 +08:00 committed by GitHub
parent 977c9d4f13
commit 02576e5be3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 192 additions and 9 deletions

View file

@ -0,0 +1,42 @@
"""add_encrypt_token_and_method
Revision ID: dd29417b397c
Revises: 1eedd7a957d1
Create Date: 2025-08-05 04:32:13.735805+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "dd29417b397c"
down_revision: Union[str, None] = "1eedd7a957d1"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("organization_auth_tokens", sa.Column("encrypted_token", sa.String(), nullable=True))
op.add_column("organization_auth_tokens", sa.Column("encrypted_method", sa.String(), nullable=True))
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=True)
op.create_index(
op.f("ix_organization_auth_tokens_encrypted_token"),
"organization_auth_tokens",
["encrypted_token"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_organization_auth_tokens_encrypted_token"), table_name="organization_auth_tokens")
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=False)
op.drop_column("organization_auth_tokens", "encrypted_method")
op.drop_column("organization_auth_tokens", "encrypted_token")
# ### end Alembic commands ###

View file

@ -309,6 +309,10 @@ class Settings(BaseSettings):
Otherwise we'll consider the persistent browser session to be expired. Otherwise we'll consider the persistent browser session to be expired.
""" """
ENCRYPTOR_AES_SECRET_KEY: str = "fillmein"
ENCRYPTOR_AES_SALT: str | None = None
ENCRYPTOR_AES_IV: str | None = None
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]: def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
""" """
Keys are model names available to blocks in the frontend. These map to key names Keys are model names available to blocks in the frontend. These map to key names

View file

@ -66,6 +66,8 @@ from skyvern.forge.sdk.db.utils import (
convert_to_workflow_run_parameter, convert_to_workflow_run_parameter,
hydrate_action, hydrate_action,
) )
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
@ -867,7 +869,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc()) .order_by(OrganizationAuthTokenModel.created_at.desc())
) )
).first(): ).first():
return convert_to_organization_auth_token(token) return await convert_to_organization_auth_token(token)
else: else:
return None return None
except SQLAlchemyError: except SQLAlchemyError:
@ -893,7 +895,7 @@ class AgentDB:
.order_by(OrganizationAuthTokenModel.created_at.desc()) .order_by(OrganizationAuthTokenModel.created_at.desc())
) )
).all() ).all()
return [convert_to_organization_auth_token(token) for token in tokens] return [await convert_to_organization_auth_token(token) for token in tokens]
except SQLAlchemyError: except SQLAlchemyError:
LOG.error("SQLAlchemyError", exc_info=True) LOG.error("SQLAlchemyError", exc_info=True)
raise raise
@ -907,19 +909,27 @@ class AgentDB:
token_type: OrganizationAuthTokenType, token_type: OrganizationAuthTokenType,
token: str, token: str,
valid: bool | None = True, valid: bool | None = True,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken | None: ) -> OrganizationAuthToken | None:
try: try:
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
async with self.Session() as session: async with self.Session() as session:
query = ( query = (
select(OrganizationAuthTokenModel) select(OrganizationAuthTokenModel)
.filter_by(organization_id=organization_id) .filter_by(organization_id=organization_id)
.filter_by(token_type=token_type) .filter_by(token_type=token_type)
.filter_by(token=token)
) )
if encrypted_token:
query = query.filter_by(encrypted_token=encrypted_token)
else:
query = query.filter_by(token=token)
if valid is not None: if valid is not None:
query = query.filter_by(valid=valid) query = query.filter_by(valid=valid)
if token_obj := (await session.scalars(query)).first(): if token_obj := (await session.scalars(query)).first():
return convert_to_organization_auth_token(token_obj) return await convert_to_organization_auth_token(token_obj)
else: else:
return None return None
except SQLAlchemyError: except SQLAlchemyError:
@ -934,18 +944,28 @@ class AgentDB:
organization_id: str, organization_id: str,
token_type: OrganizationAuthTokenType, token_type: OrganizationAuthTokenType,
token: str, token: str,
encrypted_method: EncryptMethod | None = None,
) -> OrganizationAuthToken: ) -> OrganizationAuthToken:
plaintext_token = token
encrypted_token = ""
if encrypted_method is not None:
encrypted_token = await encryptor.encrypt(token, encrypted_method)
plaintext_token = ""
async with self.Session() as session: async with self.Session() as session:
auth_token = OrganizationAuthTokenModel( auth_token = OrganizationAuthTokenModel(
organization_id=organization_id, organization_id=organization_id,
token_type=token_type, token_type=token_type,
token=token, token=plaintext_token,
encrypted_token=encrypted_token,
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
) )
session.add(auth_token) session.add(auth_token)
await session.commit() await session.commit()
await session.refresh(auth_token) await session.refresh(auth_token)
return convert_to_organization_auth_token(auth_token) return await convert_to_organization_auth_token(auth_token)
async def get_artifacts_for_task_v2( async def get_artifacts_for_task_v2(
self, self,

View file

@ -167,7 +167,9 @@ class OrganizationAuthTokenModel(Base):
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False) organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
token_type = Column(String, nullable=False) token_type = Column(String, nullable=False)
token = Column(String, index=True, nullable=False) token = Column(String, index=True, nullable=True)
encrypted_token = Column(String, index=True, nullable=True)
encrypted_method = Column(String, nullable=True)
valid = Column(Boolean, nullable=False, default=True) valid = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)

View file

@ -26,6 +26,8 @@ from skyvern.forge.sdk.db.models import (
WorkflowRunOutputParameterModel, WorkflowRunOutputParameterModel,
WorkflowRunParameterModel, WorkflowRunParameterModel,
) )
from skyvern.forge.sdk.encrypt import encryptor
from skyvern.forge.sdk.encrypt.base import EncryptMethod
from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.models import Step, StepStatus
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
@ -190,14 +192,18 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
) )
def convert_to_organization_auth_token( async def convert_to_organization_auth_token(
org_auth_token: OrganizationAuthTokenModel, org_auth_token: OrganizationAuthTokenModel,
) -> OrganizationAuthToken: ) -> OrganizationAuthToken:
token = org_auth_token.token
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
return OrganizationAuthToken( return OrganizationAuthToken(
id=org_auth_token.id, id=org_auth_token.id,
organization_id=org_auth_token.organization_id, organization_id=org_auth_token.organization_id,
token_type=OrganizationAuthTokenType(org_auth_token.token_type), token_type=OrganizationAuthTokenType(org_auth_token.token_type),
token=org_auth_token.token, token=token,
valid=org_auth_token.valid, valid=org_auth_token.valid,
created_at=org_auth_token.created_at, created_at=org_auth_token.created_at,
modified_at=org_auth_token.modified_at, modified_at=org_auth_token.modified_at,

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
class Encryptor(BaseModel):
def __init__(self) -> None:
self._methods: dict[EncryptMethod, BaseEncryptor] = {}
def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None:
self._methods[encrypt_method.method()] = encrypt_method
async def encrypt(self, plaintext: str, method: EncryptMethod) -> str:
if method not in self._methods:
raise ValueError(f"encrypt method not registered: {method}")
return await self._methods[method].encrypt(plaintext)
async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str:
if method not in self._methods:
raise ValueError(f"encrypt method not registered: {method}")
return await self._methods[method].decrypt(ciphertext)
encryptor = Encryptor()

View file

@ -0,0 +1,63 @@
import base64
import hashlib
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
class AES(BaseEncryptor):
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
def method(self) -> EncryptMethod:
return EncryptMethod.AES
def _derive_key(self) -> bytes:
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=self.salt,
iterations=100000,
)
return kdf.derive(self.secret_key)
async def encrypt(self, plaintext: str) -> str:
try:
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
encryptor = cipher.encryptor()
padded_plaintext = self._pad(plaintext.encode("utf-8"))
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
return base64.b64encode(ciphertext).decode("utf-8")
except Exception as e:
raise Exception("Failed to encrypt token") from e
async def decrypt(self, ciphertext: str) -> str:
try:
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
key = self._derive_key()
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
decryptor = cipher.decryptor()
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
plaintext = self._unpad(padded_plaintext)
return plaintext.decode("utf-8")
except Exception as e:
raise Exception("Failed to decrypt token") from e
def _pad(self, data: bytes) -> bytes:
block_size = 16
padding_length = block_size - (len(data) % block_size)
padding = bytes([padding_length] * padding_length)
return data + padding
def _unpad(self, data: bytes) -> bytes:
padding_length = data[-1]
return data[:-padding_length]

View file

@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from enum import Enum
class EncryptMethod(Enum):
AES = "aes"
class BaseEncryptor(ABC):
@abstractmethod
def method(self) -> EncryptMethod:
pass
@abstractmethod
async def encrypt(self, plaintext: str) -> str:
pass
@abstractmethod
async def decrypt(self, ciphertext: str) -> str:
pass