diff --git a/alembic/versions/2025_08_05_0432-dd29417b397c_add_encrypt_token_and_method.py b/alembic/versions/2025_08_05_0432-dd29417b397c_add_encrypt_token_and_method.py new file mode 100644 index 00000000..996ffcee --- /dev/null +++ b/alembic/versions/2025_08_05_0432-dd29417b397c_add_encrypt_token_and_method.py @@ -0,0 +1,42 @@ +"""add_encrypt_token_and_method + +Revision ID: dd29417b397c +Revises: 1eedd7a957d1 +Create Date: 2025-08-05 04:32:13.735805+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dd29417b397c" +down_revision: Union[str, None] = "1eedd7a957d1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("organization_auth_tokens", sa.Column("encrypted_token", sa.String(), nullable=True)) + op.add_column("organization_auth_tokens", sa.Column("encrypted_method", sa.String(), nullable=True)) + op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=True) + op.create_index( + op.f("ix_organization_auth_tokens_encrypted_token"), + "organization_auth_tokens", + ["encrypted_token"], + unique=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_organization_auth_tokens_encrypted_token"), table_name="organization_auth_tokens") + op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=False) + op.drop_column("organization_auth_tokens", "encrypted_method") + op.drop_column("organization_auth_tokens", "encrypted_token") + # ### end Alembic commands ### diff --git a/skyvern/config.py b/skyvern/config.py index ba34c686..2b4146d6 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -309,6 +309,10 @@ class Settings(BaseSettings): Otherwise we'll consider the persistent browser session to be expired. """ + ENCRYPTOR_AES_SECRET_KEY: str = "fillmein" + ENCRYPTOR_AES_SALT: str | None = None + ENCRYPTOR_AES_IV: str | None = None + def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]: """ Keys are model names available to blocks in the frontend. These map to key names diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index a5cb68cd..a3bc4393 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -66,6 +66,8 @@ from skyvern.forge.sdk.db.utils import ( convert_to_workflow_run_parameter, hydrate_action, ) +from skyvern.forge.sdk.encrypt import encryptor +from skyvern.forge.sdk.encrypt.base import EncryptMethod from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion @@ -867,7 +869,7 @@ class AgentDB: .order_by(OrganizationAuthTokenModel.created_at.desc()) ) ).first(): - return convert_to_organization_auth_token(token) + return await convert_to_organization_auth_token(token) else: return None except SQLAlchemyError: @@ -893,7 +895,7 @@ class AgentDB: .order_by(OrganizationAuthTokenModel.created_at.desc()) ) ).all() - return [convert_to_organization_auth_token(token) for token in tokens] + return [await convert_to_organization_auth_token(token) for token in tokens] except SQLAlchemyError: LOG.error("SQLAlchemyError", exc_info=True) raise @@ -907,19 +909,27 @@ class AgentDB: token_type: OrganizationAuthTokenType, token: str, valid: bool | None = True, + encrypted_method: EncryptMethod | None = None, ) -> OrganizationAuthToken | None: try: + encrypted_token = "" + if encrypted_method is not None: + encrypted_token = await encryptor.encrypt(token, encrypted_method) + async with self.Session() as session: query = ( select(OrganizationAuthTokenModel) .filter_by(organization_id=organization_id) .filter_by(token_type=token_type) - .filter_by(token=token) ) + if encrypted_token: + query = query.filter_by(encrypted_token=encrypted_token) + else: + query = query.filter_by(token=token) if valid is not None: query = query.filter_by(valid=valid) if token_obj := (await session.scalars(query)).first(): - return convert_to_organization_auth_token(token_obj) + return await convert_to_organization_auth_token(token_obj) else: return None except SQLAlchemyError: @@ -934,18 +944,28 @@ class AgentDB: organization_id: str, token_type: OrganizationAuthTokenType, token: str, + encrypted_method: EncryptMethod | None = None, ) -> OrganizationAuthToken: + plaintext_token = token + encrypted_token = "" + + if encrypted_method is not None: + encrypted_token = await encryptor.encrypt(token, encrypted_method) + plaintext_token = "" + async with self.Session() as session: auth_token = OrganizationAuthTokenModel( organization_id=organization_id, token_type=token_type, - token=token, + token=plaintext_token, + encrypted_token=encrypted_token, + encrypted_method=encrypted_method.value if encrypted_method is not None else "", ) session.add(auth_token) await session.commit() await session.refresh(auth_token) - return convert_to_organization_auth_token(auth_token) + return await convert_to_organization_auth_token(auth_token) async def get_artifacts_for_task_v2( self, diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 618c5666..3e335b17 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -167,7 +167,9 @@ class OrganizationAuthTokenModel(Base): organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False) token_type = Column(String, nullable=False) - token = Column(String, index=True, nullable=False) + token = Column(String, index=True, nullable=True) + encrypted_token = Column(String, index=True, nullable=True) + encrypted_method = Column(String, nullable=True) valid = Column(Boolean, nullable=False, default=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index 342ba401..1ad90a11 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -26,6 +26,8 @@ from skyvern.forge.sdk.db.models import ( WorkflowRunOutputParameterModel, WorkflowRunParameterModel, ) +from skyvern.forge.sdk.encrypt import encryptor +from skyvern.forge.sdk.encrypt.base import EncryptMethod from skyvern.forge.sdk.models import Step, StepStatus from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus @@ -190,14 +192,18 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization: ) -def convert_to_organization_auth_token( +async def convert_to_organization_auth_token( org_auth_token: OrganizationAuthTokenModel, ) -> OrganizationAuthToken: + token = org_auth_token.token + if org_auth_token.encrypted_token and org_auth_token.encrypted_method: + token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method)) + return OrganizationAuthToken( id=org_auth_token.id, organization_id=org_auth_token.organization_id, token_type=OrganizationAuthTokenType(org_auth_token.token_type), - token=org_auth_token.token, + token=token, valid=org_auth_token.valid, created_at=org_auth_token.created_at, modified_at=org_auth_token.modified_at, diff --git a/skyvern/forge/sdk/encrypt/__init__.py b/skyvern/forge/sdk/encrypt/__init__.py new file mode 100644 index 00000000..02499c12 --- /dev/null +++ b/skyvern/forge/sdk/encrypt/__init__.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel + +from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod + + +class Encryptor(BaseModel): + def __init__(self) -> None: + self._methods: dict[EncryptMethod, BaseEncryptor] = {} + + def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None: + self._methods[encrypt_method.method()] = encrypt_method + + async def encrypt(self, plaintext: str, method: EncryptMethod) -> str: + if method not in self._methods: + raise ValueError(f"encrypt method not registered: {method}") + + return await self._methods[method].encrypt(plaintext) + + async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str: + if method not in self._methods: + raise ValueError(f"encrypt method not registered: {method}") + + return await self._methods[method].decrypt(ciphertext) + + +encryptor = Encryptor() diff --git a/skyvern/forge/sdk/encrypt/aes.py b/skyvern/forge/sdk/encrypt/aes.py new file mode 100644 index 00000000..31784a96 --- /dev/null +++ b/skyvern/forge/sdk/encrypt/aes.py @@ -0,0 +1,63 @@ +import base64 +import hashlib + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod + +default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest() +default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest() + + +class AES(BaseEncryptor): + def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None: + self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest() + self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt + self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv + + def method(self) -> EncryptMethod: + return EncryptMethod.AES + + def _derive_key(self) -> bytes: + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=self.salt, + iterations=100000, + ) + return kdf.derive(self.secret_key) + + async def encrypt(self, plaintext: str) -> str: + try: + key = self._derive_key() + cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv)) + encryptor = cipher.encryptor() + padded_plaintext = self._pad(plaintext.encode("utf-8")) + ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize() + return base64.b64encode(ciphertext).decode("utf-8") + except Exception as e: + raise Exception("Failed to encrypt token") from e + + async def decrypt(self, ciphertext: str) -> str: + try: + encrypted_data = base64.b64decode(ciphertext.encode("utf-8")) + key = self._derive_key() + cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv)) + decryptor = cipher.decryptor() + padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize() + plaintext = self._unpad(padded_plaintext) + return plaintext.decode("utf-8") + except Exception as e: + raise Exception("Failed to decrypt token") from e + + def _pad(self, data: bytes) -> bytes: + block_size = 16 + padding_length = block_size - (len(data) % block_size) + padding = bytes([padding_length] * padding_length) + return data + padding + + def _unpad(self, data: bytes) -> bytes: + padding_length = data[-1] + return data[:-padding_length] diff --git a/skyvern/forge/sdk/encrypt/base.py b/skyvern/forge/sdk/encrypt/base.py new file mode 100644 index 00000000..31fb8653 --- /dev/null +++ b/skyvern/forge/sdk/encrypt/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from enum import Enum + + +class EncryptMethod(Enum): + AES = "aes" + + +class BaseEncryptor(ABC): + @abstractmethod + def method(self) -> EncryptMethod: + pass + + @abstractmethod + async def encrypt(self, plaintext: str) -> str: + pass + + @abstractmethod + async def decrypt(self, ciphertext: str) -> str: + pass