Skyvern/skyvern/forge/sdk/services/org_auth_service.py
2025-06-03 10:58:28 -07:00

142 lines
4.6 KiB
Python

import time
from typing import Annotated
import structlog
from asyncache import cached
from cachetools import TTLCache
from fastapi import Header, HTTPException, status
from jose import jwt
from jose.exceptions import JWTError
from pydantic import ValidationError
from skyvern.config import settings
from skyvern.forge import app
from skyvern.forge.sdk.core import skyvern_context
from skyvern.forge.sdk.db.client import AgentDB
from skyvern.forge.sdk.models import TokenPayload
from skyvern.forge.sdk.schemas.organizations import Organization, OrganizationAuthTokenType
# Module-level structlog logger used for authentication diagnostics in this module.
LOG = structlog.get_logger()
# How long (in seconds) a successful API-key authentication stays cached: one hour.
AUTHENTICATION_TTL: int = 3_600
# Upper bound on the number of cached API-key authentication results.
CACHE_SIZE: int = 128
# Only algorithm accepted when decoding API-key JWTs.
ALGORITHM: str = "HS256"
async def get_current_org(
    x_api_key: Annotated[
        str | None,
        Header(
            description="Skyvern API key for authentication. API key can be found at https://app.skyvern.com/settings."
        ),
    ] = None,
    authorization: Annotated[str | None, Header(include_in_schema=False)] = None,
) -> Organization:
    """Resolve the calling organization from an API key or an ``Authorization`` header.

    If both headers are supplied, the API key takes precedence and ``authorization``
    is ignored.

    Args:
        x_api_key: Skyvern API key (a JWT) from the API-key header.
        authorization: ``"<scheme> <token>"`` header handled by the app's pluggable
            authentication function.

    Returns:
        The authenticated organization.

    Raises:
        HTTPException: 403 when neither header is supplied (or both are empty);
            otherwise whatever the selected authentication path raises.
    """
    if x_api_key:
        return await _get_current_org_cached(x_api_key, app.DATABASE)
    if authorization:
        return await _authenticate_helper(authorization)
    # Neither credential was provided.
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
async def get_current_org_with_api_key(
    x_api_key: Annotated[str | None, Header()] = None,
) -> Organization:
    """Resolve the calling organization strictly from the API-key header.

    Raises:
        HTTPException: 403 when the header is missing or empty; otherwise whatever
            ``_get_current_org_cached`` raises.
    """
    if x_api_key:
        return await _get_current_org_cached(x_api_key, app.DATABASE)
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
async def get_current_org_with_authentication(
    authorization: Annotated[str | None, Header()] = None,
) -> Organization:
    """Resolve the calling organization strictly from the ``Authorization`` header.

    Raises:
        HTTPException: 403 when the header is missing or empty; otherwise whatever
            ``_authenticate_helper`` raises.
    """
    if authorization:
        return await _authenticate_helper(authorization)
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Invalid credentials",
    )
async def _authenticate_helper(authorization: str) -> Organization:
    """Authenticate an ``Authorization`` header with the app's pluggable auth function.

    The header is expected to look like ``"<scheme> <token>"`` (for example
    ``"Bearer abc"``). Only the token, the second whitespace-separated field, is passed
    to ``app.authentication_function``. The scheme itself is not checked.

    Args:
        authorization: Raw ``Authorization`` header value.

    Returns:
        The organization returned by ``app.authentication_function``.

    Raises:
        HTTPException: 403 if the header has no token part, if no authentication
            function is configured, or if the function returns no organization.
    """
    # Split on any whitespace run so that a header without a token (e.g. "Bearer")
    # is rejected with 403 instead of escaping as an IndexError (HTTP 500), and
    # repeated spaces do not produce an empty token.
    parts = authorization.split()
    if len(parts) < 2:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )
    token = parts[1]

    if not app.authentication_function:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid authentication method",
        )

    organization = await app.authentication_function(token)
    if not organization:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )
    return organization
@cached(cache=TTLCache(maxsize=CACHE_SIZE, ttl=AUTHENTICATION_TTL))
async def _authenticate_api_key_cached(x_api_key: str, db: AgentDB) -> Organization:
    """Validate an API key and return the organization it belongs to.

    Authentication is cached for one hour (``AUTHENTICATION_TTL``), keyed on the call
    arguments. On a cache hit none of the checks below run. A key that is later revoked
    or marked invalid in the database therefore keeps working until its cache entry
    expires.

    Keep this function free of per-request side effects, because they would only
    happen on cache misses.

    Args:
        x_api_key: The API key, a JWT signed with ``settings.SECRET_KEY``.
        db: Database client used to look up the organization and the stored token.

    Returns:
        The organization that owns the API key.

    Raises:
        HTTPException: 403 if the key cannot be decoded/validated, is expired, is not a
            stored API token of the organization, or is stored as invalid; 404 if the
            organization referenced by the token's ``sub`` does not exist.
    """
    try:
        payload = jwt.decode(
            x_api_key,
            settings.SECRET_KEY,
            algorithms=[ALGORITHM],
        )
        api_key_data = TokenPayload(**payload)
    except (JWTError, ValidationError):
        LOG.error("Error decoding JWT", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Could not validate credentials",
        )

    # NOTE(review): python-jose's jwt.decode verifies the `exp` claim by default, so
    # expired tokens most likely fail above; this explicit check acts as a fallback.
    if api_key_data.exp < time.time():
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Auth token is expired",
        )

    organization = await db.get_organization(organization_id=api_key_data.sub)
    if not organization:
        LOG.warning("Organization not found", organization_id=api_key_data.sub, **payload)
        raise HTTPException(status_code=404, detail="Organization not found")

    # A correctly signed JWT is not enough: the token must also be stored for this
    # organization. valid=None fetches it regardless of its validity flag, so a
    # revoked key gets a specific error message below.
    api_key_db_obj = await db.validate_org_auth_token(
        organization_id=organization.organization_id,
        token_type=OrganizationAuthTokenType.api,
        token=x_api_key,
        valid=None,
    )
    if not api_key_db_obj:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid credentials",
        )

    if api_key_db_obj.valid is False:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Your API key has expired. Please retrieve the latest one from https://app.skyvern.com/settings",
        )

    return organization


async def _get_current_org_cached(x_api_key: str, db: AgentDB) -> Organization:
    """Authenticate an API key (cached for one hour) and bind the org to the current context.

    The lookup itself goes through ``_authenticate_api_key_cached``. The skyvern
    context update happens here, outside the cache, so it runs on every request. It
    must not run only on cache misses, or requests served from the cache would have
    no ``organization_id``/``organization_name`` in their context.

    Args:
        x_api_key: The API key, a JWT signed with ``settings.SECRET_KEY``.
        db: Database client used for the (cached) lookup.

    Returns:
        The organization that owns the API key.

    Raises:
        HTTPException: Propagated unchanged from ``_authenticate_api_key_cached``.
    """
    organization = await _authenticate_api_key_cached(x_api_key, db)

    # Set organization_id in skyvern context and log context.
    context = skyvern_context.current()
    if context:
        context.organization_id = organization.organization_id
        context.organization_name = organization.organization_name
    return organization