mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-16 10:19:42 +00:00
feat: encrypt org auth tokens with AES (#3104)
This commit is contained in:
parent
977c9d4f13
commit
02576e5be3
8 changed files with 192 additions and 9 deletions
|
@ -0,0 +1,42 @@
|
|||
"""add_encrypt_token_and_method
|
||||
|
||||
Revision ID: dd29417b397c
|
||||
Revises: 1eedd7a957d1
|
||||
Create Date: 2025-08-05 04:32:13.735805+00:00
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "dd29417b397c"
|
||||
down_revision: Union[str, None] = "1eedd7a957d1"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("organization_auth_tokens", sa.Column("encrypted_token", sa.String(), nullable=True))
|
||||
op.add_column("organization_auth_tokens", sa.Column("encrypted_method", sa.String(), nullable=True))
|
||||
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.create_index(
|
||||
op.f("ix_organization_auth_tokens_encrypted_token"),
|
||||
"organization_auth_tokens",
|
||||
["encrypted_token"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f("ix_organization_auth_tokens_encrypted_token"), table_name="organization_auth_tokens")
|
||||
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.drop_column("organization_auth_tokens", "encrypted_method")
|
||||
op.drop_column("organization_auth_tokens", "encrypted_token")
|
||||
# ### end Alembic commands ###
|
|
@ -309,6 +309,10 @@ class Settings(BaseSettings):
|
|||
Otherwise we'll consider the persistent browser session to be expired.
|
||||
"""
|
||||
|
||||
ENCRYPTOR_AES_SECRET_KEY: str = "fillmein"
|
||||
ENCRYPTOR_AES_SALT: str | None = None
|
||||
ENCRYPTOR_AES_IV: str | None = None
|
||||
|
||||
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
|
||||
"""
|
||||
Keys are model names available to blocks in the frontend. These map to key names
|
||||
|
|
|
@ -66,6 +66,8 @@ from skyvern.forge.sdk.db.utils import (
|
|||
convert_to_workflow_run_parameter,
|
||||
hydrate_action,
|
||||
)
|
||||
from skyvern.forge.sdk.encrypt import encryptor
|
||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||
|
@ -867,7 +869,7 @@ class AgentDB:
|
|||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).first():
|
||||
return convert_to_organization_auth_token(token)
|
||||
return await convert_to_organization_auth_token(token)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
|
@ -893,7 +895,7 @@ class AgentDB:
|
|||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||
)
|
||||
).all()
|
||||
return [convert_to_organization_auth_token(token) for token in tokens]
|
||||
return [await convert_to_organization_auth_token(token) for token in tokens]
|
||||
except SQLAlchemyError:
|
||||
LOG.error("SQLAlchemyError", exc_info=True)
|
||||
raise
|
||||
|
@ -907,19 +909,27 @@ class AgentDB:
|
|||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
valid: bool | None = True,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken | None:
|
||||
try:
|
||||
encrypted_token = ""
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
|
||||
async with self.Session() as session:
|
||||
query = (
|
||||
select(OrganizationAuthTokenModel)
|
||||
.filter_by(organization_id=organization_id)
|
||||
.filter_by(token_type=token_type)
|
||||
.filter_by(token=token)
|
||||
)
|
||||
if encrypted_token:
|
||||
query = query.filter_by(encrypted_token=encrypted_token)
|
||||
else:
|
||||
query = query.filter_by(token=token)
|
||||
if valid is not None:
|
||||
query = query.filter_by(valid=valid)
|
||||
if token_obj := (await session.scalars(query)).first():
|
||||
return convert_to_organization_auth_token(token_obj)
|
||||
return await convert_to_organization_auth_token(token_obj)
|
||||
else:
|
||||
return None
|
||||
except SQLAlchemyError:
|
||||
|
@ -934,18 +944,28 @@ class AgentDB:
|
|||
organization_id: str,
|
||||
token_type: OrganizationAuthTokenType,
|
||||
token: str,
|
||||
encrypted_method: EncryptMethod | None = None,
|
||||
) -> OrganizationAuthToken:
|
||||
plaintext_token = token
|
||||
encrypted_token = ""
|
||||
|
||||
if encrypted_method is not None:
|
||||
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||
plaintext_token = ""
|
||||
|
||||
async with self.Session() as session:
|
||||
auth_token = OrganizationAuthTokenModel(
|
||||
organization_id=organization_id,
|
||||
token_type=token_type,
|
||||
token=token,
|
||||
token=plaintext_token,
|
||||
encrypted_token=encrypted_token,
|
||||
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
||||
)
|
||||
session.add(auth_token)
|
||||
await session.commit()
|
||||
await session.refresh(auth_token)
|
||||
|
||||
return convert_to_organization_auth_token(auth_token)
|
||||
return await convert_to_organization_auth_token(auth_token)
|
||||
|
||||
async def get_artifacts_for_task_v2(
|
||||
self,
|
||||
|
|
|
@ -167,7 +167,9 @@ class OrganizationAuthTokenModel(Base):
|
|||
|
||||
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
|
||||
token_type = Column(String, nullable=False)
|
||||
token = Column(String, index=True, nullable=False)
|
||||
token = Column(String, index=True, nullable=True)
|
||||
encrypted_token = Column(String, index=True, nullable=True)
|
||||
encrypted_method = Column(String, nullable=True)
|
||||
valid = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
|
|
|
@ -26,6 +26,8 @@ from skyvern.forge.sdk.db.models import (
|
|||
WorkflowRunOutputParameterModel,
|
||||
WorkflowRunParameterModel,
|
||||
)
|
||||
from skyvern.forge.sdk.encrypt import encryptor
|
||||
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||
from skyvern.forge.sdk.models import Step, StepStatus
|
||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||
|
@ -190,14 +192,18 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
|||
)
|
||||
|
||||
|
||||
def convert_to_organization_auth_token(
|
||||
async def convert_to_organization_auth_token(
|
||||
org_auth_token: OrganizationAuthTokenModel,
|
||||
) -> OrganizationAuthToken:
|
||||
token = org_auth_token.token
|
||||
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
||||
|
||||
return OrganizationAuthToken(
|
||||
id=org_auth_token.id,
|
||||
organization_id=org_auth_token.organization_id,
|
||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||
token=org_auth_token.token,
|
||||
token=token,
|
||||
valid=org_auth_token.valid,
|
||||
created_at=org_auth_token.created_at,
|
||||
modified_at=org_auth_token.modified_at,
|
||||
|
|
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||
|
||||
|
||||
class Encryptor(BaseModel):
|
||||
def __init__(self) -> None:
|
||||
self._methods: dict[EncryptMethod, BaseEncryptor] = {}
|
||||
|
||||
def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None:
|
||||
self._methods[encrypt_method.method()] = encrypt_method
|
||||
|
||||
async def encrypt(self, plaintext: str, method: EncryptMethod) -> str:
|
||||
if method not in self._methods:
|
||||
raise ValueError(f"encrypt method not registered: {method}")
|
||||
|
||||
return await self._methods[method].encrypt(plaintext)
|
||||
|
||||
async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str:
|
||||
if method not in self._methods:
|
||||
raise ValueError(f"encrypt method not registered: {method}")
|
||||
|
||||
return await self._methods[method].decrypt(ciphertext)
|
||||
|
||||
|
||||
encryptor = Encryptor()
|
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import base64
|
||||
import hashlib
|
||||
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||
|
||||
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
|
||||
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
|
||||
|
||||
|
||||
class AES(BaseEncryptor):
|
||||
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
|
||||
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
|
||||
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
|
||||
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
|
||||
|
||||
def method(self) -> EncryptMethod:
|
||||
return EncryptMethod.AES
|
||||
|
||||
def _derive_key(self) -> bytes:
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=self.salt,
|
||||
iterations=100000,
|
||||
)
|
||||
return kdf.derive(self.secret_key)
|
||||
|
||||
async def encrypt(self, plaintext: str) -> str:
|
||||
try:
|
||||
key = self._derive_key()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||
encryptor = cipher.encryptor()
|
||||
padded_plaintext = self._pad(plaintext.encode("utf-8"))
|
||||
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
|
||||
return base64.b64encode(ciphertext).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise Exception("Failed to encrypt token") from e
|
||||
|
||||
async def decrypt(self, ciphertext: str) -> str:
|
||||
try:
|
||||
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
|
||||
key = self._derive_key()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||
decryptor = cipher.decryptor()
|
||||
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
plaintext = self._unpad(padded_plaintext)
|
||||
return plaintext.decode("utf-8")
|
||||
except Exception as e:
|
||||
raise Exception("Failed to decrypt token") from e
|
||||
|
||||
def _pad(self, data: bytes) -> bytes:
|
||||
block_size = 16
|
||||
padding_length = block_size - (len(data) % block_size)
|
||||
padding = bytes([padding_length] * padding_length)
|
||||
return data + padding
|
||||
|
||||
def _unpad(self, data: bytes) -> bytes:
|
||||
padding_length = data[-1]
|
||||
return data[:-padding_length]
|
20
skyvern/forge/sdk/encrypt/base.py
Normal file
20
skyvern/forge/sdk/encrypt/base.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EncryptMethod(Enum):
|
||||
AES = "aes"
|
||||
|
||||
|
||||
class BaseEncryptor(ABC):
|
||||
@abstractmethod
|
||||
def method(self) -> EncryptMethod:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def encrypt(self, plaintext: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def decrypt(self, ciphertext: str) -> str:
|
||||
pass
|
Loading…
Add table
Add a link
Reference in a new issue