mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-16 10:19:42 +00:00
feat: encrypt org auth tokens with AES (#3104)
This commit is contained in:
parent
977c9d4f13
commit
02576e5be3
8 changed files with 192 additions and 9 deletions
|
@ -0,0 +1,42 @@
|
||||||
|
"""add_encrypt_token_and_method
|
||||||
|
|
||||||
|
Revision ID: dd29417b397c
|
||||||
|
Revises: 1eedd7a957d1
|
||||||
|
Create Date: 2025-08-05 04:32:13.735805+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "dd29417b397c"
|
||||||
|
down_revision: Union[str, None] = "1eedd7a957d1"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column("organization_auth_tokens", sa.Column("encrypted_token", sa.String(), nullable=True))
|
||||||
|
op.add_column("organization_auth_tokens", sa.Column("encrypted_method", sa.String(), nullable=True))
|
||||||
|
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=True)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_organization_auth_tokens_encrypted_token"),
|
||||||
|
"organization_auth_tokens",
|
||||||
|
["encrypted_token"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f("ix_organization_auth_tokens_encrypted_token"), table_name="organization_auth_tokens")
|
||||||
|
op.alter_column("organization_auth_tokens", "token", existing_type=sa.VARCHAR(), nullable=False)
|
||||||
|
op.drop_column("organization_auth_tokens", "encrypted_method")
|
||||||
|
op.drop_column("organization_auth_tokens", "encrypted_token")
|
||||||
|
# ### end Alembic commands ###
|
|
@ -309,6 +309,10 @@ class Settings(BaseSettings):
|
||||||
Otherwise we'll consider the persistent browser session to be expired.
|
Otherwise we'll consider the persistent browser session to be expired.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ENCRYPTOR_AES_SECRET_KEY: str = "fillmein"
|
||||||
|
ENCRYPTOR_AES_SALT: str | None = None
|
||||||
|
ENCRYPTOR_AES_IV: str | None = None
|
||||||
|
|
||||||
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
|
def get_model_name_to_llm_key(self) -> dict[str, dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
Keys are model names available to blocks in the frontend. These map to key names
|
Keys are model names available to blocks in the frontend. These map to key names
|
||||||
|
|
|
@ -66,6 +66,8 @@ from skyvern.forge.sdk.db.utils import (
|
||||||
convert_to_workflow_run_parameter,
|
convert_to_workflow_run_parameter,
|
||||||
hydrate_action,
|
hydrate_action,
|
||||||
)
|
)
|
||||||
|
from skyvern.forge.sdk.encrypt import encryptor
|
||||||
|
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||||
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
|
from skyvern.forge.sdk.log_artifacts import save_workflow_run_logs
|
||||||
from skyvern.forge.sdk.models import Step, StepStatus
|
from skyvern.forge.sdk.models import Step, StepStatus
|
||||||
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
from skyvern.forge.sdk.schemas.ai_suggestions import AISuggestion
|
||||||
|
@ -867,7 +869,7 @@ class AgentDB:
|
||||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||||
)
|
)
|
||||||
).first():
|
).first():
|
||||||
return convert_to_organization_auth_token(token)
|
return await convert_to_organization_auth_token(token)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
|
@ -893,7 +895,7 @@ class AgentDB:
|
||||||
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
.order_by(OrganizationAuthTokenModel.created_at.desc())
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
return [convert_to_organization_auth_token(token) for token in tokens]
|
return [await convert_to_organization_auth_token(token) for token in tokens]
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
LOG.error("SQLAlchemyError", exc_info=True)
|
LOG.error("SQLAlchemyError", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
@ -907,19 +909,27 @@ class AgentDB:
|
||||||
token_type: OrganizationAuthTokenType,
|
token_type: OrganizationAuthTokenType,
|
||||||
token: str,
|
token: str,
|
||||||
valid: bool | None = True,
|
valid: bool | None = True,
|
||||||
|
encrypted_method: EncryptMethod | None = None,
|
||||||
) -> OrganizationAuthToken | None:
|
) -> OrganizationAuthToken | None:
|
||||||
try:
|
try:
|
||||||
|
encrypted_token = ""
|
||||||
|
if encrypted_method is not None:
|
||||||
|
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||||
|
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
query = (
|
query = (
|
||||||
select(OrganizationAuthTokenModel)
|
select(OrganizationAuthTokenModel)
|
||||||
.filter_by(organization_id=organization_id)
|
.filter_by(organization_id=organization_id)
|
||||||
.filter_by(token_type=token_type)
|
.filter_by(token_type=token_type)
|
||||||
.filter_by(token=token)
|
|
||||||
)
|
)
|
||||||
|
if encrypted_token:
|
||||||
|
query = query.filter_by(encrypted_token=encrypted_token)
|
||||||
|
else:
|
||||||
|
query = query.filter_by(token=token)
|
||||||
if valid is not None:
|
if valid is not None:
|
||||||
query = query.filter_by(valid=valid)
|
query = query.filter_by(valid=valid)
|
||||||
if token_obj := (await session.scalars(query)).first():
|
if token_obj := (await session.scalars(query)).first():
|
||||||
return convert_to_organization_auth_token(token_obj)
|
return await convert_to_organization_auth_token(token_obj)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
|
@ -934,18 +944,28 @@ class AgentDB:
|
||||||
organization_id: str,
|
organization_id: str,
|
||||||
token_type: OrganizationAuthTokenType,
|
token_type: OrganizationAuthTokenType,
|
||||||
token: str,
|
token: str,
|
||||||
|
encrypted_method: EncryptMethod | None = None,
|
||||||
) -> OrganizationAuthToken:
|
) -> OrganizationAuthToken:
|
||||||
|
plaintext_token = token
|
||||||
|
encrypted_token = ""
|
||||||
|
|
||||||
|
if encrypted_method is not None:
|
||||||
|
encrypted_token = await encryptor.encrypt(token, encrypted_method)
|
||||||
|
plaintext_token = ""
|
||||||
|
|
||||||
async with self.Session() as session:
|
async with self.Session() as session:
|
||||||
auth_token = OrganizationAuthTokenModel(
|
auth_token = OrganizationAuthTokenModel(
|
||||||
organization_id=organization_id,
|
organization_id=organization_id,
|
||||||
token_type=token_type,
|
token_type=token_type,
|
||||||
token=token,
|
token=plaintext_token,
|
||||||
|
encrypted_token=encrypted_token,
|
||||||
|
encrypted_method=encrypted_method.value if encrypted_method is not None else "",
|
||||||
)
|
)
|
||||||
session.add(auth_token)
|
session.add(auth_token)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(auth_token)
|
await session.refresh(auth_token)
|
||||||
|
|
||||||
return convert_to_organization_auth_token(auth_token)
|
return await convert_to_organization_auth_token(auth_token)
|
||||||
|
|
||||||
async def get_artifacts_for_task_v2(
|
async def get_artifacts_for_task_v2(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -167,7 +167,9 @@ class OrganizationAuthTokenModel(Base):
|
||||||
|
|
||||||
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
|
organization_id = Column(String, ForeignKey("organizations.organization_id"), index=True, nullable=False)
|
||||||
token_type = Column(String, nullable=False)
|
token_type = Column(String, nullable=False)
|
||||||
token = Column(String, index=True, nullable=False)
|
token = Column(String, index=True, nullable=True)
|
||||||
|
encrypted_token = Column(String, index=True, nullable=True)
|
||||||
|
encrypted_method = Column(String, nullable=True)
|
||||||
valid = Column(Boolean, nullable=False, default=True)
|
valid = Column(Boolean, nullable=False, default=True)
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
|
|
|
@ -26,6 +26,8 @@ from skyvern.forge.sdk.db.models import (
|
||||||
WorkflowRunOutputParameterModel,
|
WorkflowRunOutputParameterModel,
|
||||||
WorkflowRunParameterModel,
|
WorkflowRunParameterModel,
|
||||||
)
|
)
|
||||||
|
from skyvern.forge.sdk.encrypt import encryptor
|
||||||
|
from skyvern.forge.sdk.encrypt.base import EncryptMethod
|
||||||
from skyvern.forge.sdk.models import Step, StepStatus
|
from skyvern.forge.sdk.models import Step, StepStatus
|
||||||
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthToken
|
||||||
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
from skyvern.forge.sdk.schemas.tasks import Task, TaskStatus
|
||||||
|
@ -190,14 +192,18 @@ def convert_to_organization(org_model: OrganizationModel) -> Organization:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_organization_auth_token(
|
async def convert_to_organization_auth_token(
|
||||||
org_auth_token: OrganizationAuthTokenModel,
|
org_auth_token: OrganizationAuthTokenModel,
|
||||||
) -> OrganizationAuthToken:
|
) -> OrganizationAuthToken:
|
||||||
|
token = org_auth_token.token
|
||||||
|
if org_auth_token.encrypted_token and org_auth_token.encrypted_method:
|
||||||
|
token = await encryptor.decrypt(org_auth_token.encrypted_token, EncryptMethod(org_auth_token.encrypted_method))
|
||||||
|
|
||||||
return OrganizationAuthToken(
|
return OrganizationAuthToken(
|
||||||
id=org_auth_token.id,
|
id=org_auth_token.id,
|
||||||
organization_id=org_auth_token.organization_id,
|
organization_id=org_auth_token.organization_id,
|
||||||
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
token_type=OrganizationAuthTokenType(org_auth_token.token_type),
|
||||||
token=org_auth_token.token,
|
token=token,
|
||||||
valid=org_auth_token.valid,
|
valid=org_auth_token.valid,
|
||||||
created_at=org_auth_token.created_at,
|
created_at=org_auth_token.created_at,
|
||||||
modified_at=org_auth_token.modified_at,
|
modified_at=org_auth_token.modified_at,
|
||||||
|
|
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
26
skyvern/forge/sdk/encrypt/__init__.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||||
|
|
||||||
|
|
||||||
|
class Encryptor(BaseModel):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._methods: dict[EncryptMethod, BaseEncryptor] = {}
|
||||||
|
|
||||||
|
def add_encrypt_method(self, encrypt_method: BaseEncryptor) -> None:
|
||||||
|
self._methods[encrypt_method.method()] = encrypt_method
|
||||||
|
|
||||||
|
async def encrypt(self, plaintext: str, method: EncryptMethod) -> str:
|
||||||
|
if method not in self._methods:
|
||||||
|
raise ValueError(f"encrypt method not registered: {method}")
|
||||||
|
|
||||||
|
return await self._methods[method].encrypt(plaintext)
|
||||||
|
|
||||||
|
async def decrypt(self, ciphertext: str, method: EncryptMethod) -> str:
|
||||||
|
if method not in self._methods:
|
||||||
|
raise ValueError(f"encrypt method not registered: {method}")
|
||||||
|
|
||||||
|
return await self._methods[method].decrypt(ciphertext)
|
||||||
|
|
||||||
|
|
||||||
|
encryptor = Encryptor()
|
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
63
skyvern/forge/sdk/encrypt/aes.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||||
|
|
||||||
|
from skyvern.forge.sdk.encrypt.base import BaseEncryptor, EncryptMethod
|
||||||
|
|
||||||
|
default_iv = hashlib.md5(b"deterministic_iv_0123456789").digest()
|
||||||
|
default_salt = hashlib.md5(b"deterministic_salt_0123456789").digest()
|
||||||
|
|
||||||
|
|
||||||
|
class AES(BaseEncryptor):
|
||||||
|
def __init__(self, *, secret_key: str, salt: str | None = None, iv: str | None = None) -> None:
|
||||||
|
self.secret_key = hashlib.md5(secret_key.encode("utf-8")).digest()
|
||||||
|
self.salt = hashlib.md5(salt.encode("utf-8")).digest() if salt else default_salt
|
||||||
|
self.iv = hashlib.md5(iv.encode("utf-8")).digest() if iv else default_iv
|
||||||
|
|
||||||
|
def method(self) -> EncryptMethod:
|
||||||
|
return EncryptMethod.AES
|
||||||
|
|
||||||
|
def _derive_key(self) -> bytes:
|
||||||
|
kdf = PBKDF2HMAC(
|
||||||
|
algorithm=hashes.SHA256(),
|
||||||
|
length=32,
|
||||||
|
salt=self.salt,
|
||||||
|
iterations=100000,
|
||||||
|
)
|
||||||
|
return kdf.derive(self.secret_key)
|
||||||
|
|
||||||
|
async def encrypt(self, plaintext: str) -> str:
|
||||||
|
try:
|
||||||
|
key = self._derive_key()
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||||
|
encryptor = cipher.encryptor()
|
||||||
|
padded_plaintext = self._pad(plaintext.encode("utf-8"))
|
||||||
|
ciphertext = encryptor.update(padded_plaintext) + encryptor.finalize()
|
||||||
|
return base64.b64encode(ciphertext).decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Failed to encrypt token") from e
|
||||||
|
|
||||||
|
async def decrypt(self, ciphertext: str) -> str:
|
||||||
|
try:
|
||||||
|
encrypted_data = base64.b64decode(ciphertext.encode("utf-8"))
|
||||||
|
key = self._derive_key()
|
||||||
|
cipher = Cipher(algorithms.AES(key), modes.CBC(self.iv))
|
||||||
|
decryptor = cipher.decryptor()
|
||||||
|
padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||||
|
plaintext = self._unpad(padded_plaintext)
|
||||||
|
return plaintext.decode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception("Failed to decrypt token") from e
|
||||||
|
|
||||||
|
def _pad(self, data: bytes) -> bytes:
|
||||||
|
block_size = 16
|
||||||
|
padding_length = block_size - (len(data) % block_size)
|
||||||
|
padding = bytes([padding_length] * padding_length)
|
||||||
|
return data + padding
|
||||||
|
|
||||||
|
def _unpad(self, data: bytes) -> bytes:
|
||||||
|
padding_length = data[-1]
|
||||||
|
return data[:-padding_length]
|
20
skyvern/forge/sdk/encrypt/base.py
Normal file
20
skyvern/forge/sdk/encrypt/base.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EncryptMethod(Enum):
|
||||||
|
AES = "aes"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEncryptor(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def method(self) -> EncryptMethod:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def encrypt(self, plaintext: str) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def decrypt(self, ciphertext: str) -> str:
|
||||||
|
pass
|
Loading…
Add table
Add a link
Reference in a new issue