Skyvern/tests/unit/test_org_auth_service.py

177 lines
7.2 KiB
Python

from types import SimpleNamespace
import jwt
import pytest
from fastapi import HTTPException
from skyvern.config import settings
from skyvern.forge.sdk.core.security import create_access_token
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.services.org_auth_service import (
_get_api_key_debug_fields,
_normalize_api_key_with_flags,
)
def test_normalize_api_key_strips_whitespace() -> None:
    """Expect padding around the key to be removed and flagged in the debug fields."""
    padded_key = " token.value.parts \n"

    cleaned, flags = _normalize_api_key_with_flags(padded_key)
    fields = _get_api_key_debug_fields(padded_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    # Both flags must be the literal True, not merely truthy.
    for flag_name in ("api_key_had_whitespace_padding", "api_key_was_normalized"):
        assert fields[flag_name] is True
def test_normalize_api_key_strips_outer_quotes() -> None:
    """Expect a double-quoted key to be unwrapped and the quote flag to be set."""
    quoted_key = '"token.value.parts"'

    cleaned, flags = _normalize_api_key_with_flags(quoted_key)
    fields = _get_api_key_debug_fields(quoted_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    # Both flags must be the literal True, not merely truthy.
    for flag_name in ("api_key_had_outer_quotes", "api_key_was_normalized"):
        assert fields[flag_name] is True
def test_normalize_api_key_strips_bearer_prefix() -> None:
    """Expect a leading ``Bearer `` to be dropped, leaving a three-segment key."""
    prefixed_key = "Bearer token.value.parts"

    cleaned, flags = _normalize_api_key_with_flags(prefixed_key)
    fields = _get_api_key_debug_fields(prefixed_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    bearer_flag = fields["api_key_had_bearer_prefix"]
    assert bearer_flag is True
    segment_count = fields["api_key_normalized_segment_count"]
    assert segment_count == 3
def test_normalize_api_key_handles_quoted_bearer_value() -> None:
    """Expect quotes wrapping a ``Bearer`` value to be removed along with the prefix."""
    wrapped_key = '"Bearer token.value.parts"'

    cleaned, flags = _normalize_api_key_with_flags(wrapped_key)
    fields = _get_api_key_debug_fields(wrapped_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    # Each wrapper layer should be reported independently.
    for flag_name in ("api_key_had_bearer_prefix", "api_key_had_outer_quotes"):
        assert fields[flag_name] is True
def test_normalize_api_key_tracks_whitespace_removed_after_wrapper_stripping() -> None:
    """Expect padding found inside the quotes (after the prefix) to still be flagged."""
    nested_key = 'Bearer " token.value.parts "'

    cleaned, flags = _normalize_api_key_with_flags(nested_key)
    fields = _get_api_key_debug_fields(nested_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    # The whitespace only becomes visible once the prefix and quotes are gone,
    # yet all three flags are expected to be the literal True.
    expected_true_flags = (
        "api_key_had_whitespace_padding",
        "api_key_had_bearer_prefix",
        "api_key_had_outer_quotes",
    )
    for flag_name in expected_true_flags:
        assert fields[flag_name] is True
def test_debug_fields_report_no_shadow_decode_for_unchanged_value() -> None:
    """Expect the decode-related debug fields to be None when the key needed no cleanup."""
    clean_key = "token.value.parts"

    cleaned, flags = _normalize_api_key_with_flags(clean_key)
    fields = _get_api_key_debug_fields(clean_key, cleaned, flags)

    assert cleaned == "token.value.parts"
    assert fields["api_key_was_normalized"] is False
    # With nothing normalized, the decode diagnostics are expected to stay unset.
    unset_fields = (
        "normalized_api_key_decodes",
        "normalized_api_key_would_be_expired",
        "normalized_api_key_error_type",
    )
    for field_name in unset_fields:
        assert fields[field_name] is None
def test_debug_fields_show_when_normalized_token_would_decode(monkeypatch) -> None:
    """Expect a real token behind a ``Bearer`` prefix to be reported as decodable."""
    access_token = create_access_token("o_test")
    # Stub `time` on the `time` object that org_auth_service references so it
    # returns 0; presumably this keeps the expiry check deterministic.
    monkeypatch.setattr(org_auth_service.time, "time", lambda: 0)
    prefixed_key = "Bearer {}".format(access_token)

    cleaned, flags = _normalize_api_key_with_flags(prefixed_key)
    fields = _get_api_key_debug_fields(prefixed_key, cleaned, flags)

    assert fields["api_key_had_bearer_prefix"] is True
    decodes = fields["normalized_api_key_decodes"]
    assert decodes is True
    would_be_expired = fields["normalized_api_key_would_be_expired"]
    assert would_be_expired is False
    error_type = fields["normalized_api_key_error_type"]
    assert error_type is None
def test_debug_fields_show_when_normalized_token_still_fails() -> None:
    """Expect the debug fields to name the decode error when the cleaned key is no JWT."""
    bogus_key = '"Bearer definitely-not-a-jwt"'

    cleaned, flags = _normalize_api_key_with_flags(bogus_key)
    fields = _get_api_key_debug_fields(bogus_key, cleaned, flags)

    assert cleaned == "definitely-not-a-jwt"
    assert fields["normalized_api_key_decodes"] is False
    # The error's type and reason are reported in separate fields.
    error_type = fields["normalized_api_key_error_type"]
    assert error_type == "DecodeError"
    error_reason = fields["normalized_api_key_error_reason"]
    assert error_reason == "Not enough segments"
def test_normalize_api_key_handles_empty_string() -> None:
    """Expect an empty key to pass through unchanged with zero raw segments."""
    empty_key = ""

    cleaned, flags = _normalize_api_key_with_flags(empty_key)
    fields = _get_api_key_debug_fields(empty_key, cleaned, flags)

    assert cleaned == ""
    raw_segments = fields["api_key_raw_segment_count"]
    assert raw_segments == 0
    decodes = fields["normalized_api_key_decodes"]
    assert decodes is None
def test_normalize_api_key_handles_single_character() -> None:
    """Expect a lone double-quote to be left alone rather than treated as a quote pair."""
    lone_quote = '"'

    cleaned, flags = _normalize_api_key_with_flags(lone_quote)
    fields = _get_api_key_debug_fields(lone_quote, cleaned, flags)

    assert cleaned == '"'
    had_quotes = fields["api_key_had_outer_quotes"]
    assert had_quotes is False
    decodes = fields["normalized_api_key_decodes"]
    assert decodes is None
def test_debug_fields_reports_validation_error_for_missing_claims() -> None:
    """Expect a signed token with an empty payload to surface a validation error."""
    # Encode an empty claim set with the configured secret key.
    claimless_token = jwt.encode({}, settings.SECRET_KEY, algorithm='HS256')
    prefixed_key = "Bearer {}".format(claimless_token)

    cleaned, flags = _normalize_api_key_with_flags(prefixed_key)
    fields = _get_api_key_debug_fields(prefixed_key, cleaned, flags)

    assert fields["normalized_api_key_decodes"] is False
    error_type = fields["normalized_api_key_error_type"]
    assert error_type == "ValidationError"
    error_reason = fields["normalized_api_key_error_reason"]
    assert error_reason == "2 validation error(s): [('sub',), ('exp',)]"
def test_debug_fields_handles_none_inputs() -> None:
    """Expect the debug helper to accept all-None arguments and report None fields."""
    fields = _get_api_key_debug_fields(None, None, None)

    # Each of these fields is expected to be None.
    fields_expected_none = (
        "api_key_original_length",
        "normalized_api_key_decodes",
        "normalized_api_key_error_type",
        "normalized_api_key_error_reason",
    )
    for field_name in fields_expected_none:
        assert fields[field_name] is None
@pytest.mark.asyncio
async def test_resolve_org_from_api_key_logs_decode_error_reason(monkeypatch) -> None:
    """Expect a non-JWT key to yield HTTP 403 and an error log carrying the decode reason."""
    captured: dict[str, object] = {}

    def record_error(_message: str, **kwargs: object) -> None:
        # Keep only the keyword fields passed to LOG.error.
        captured.update(kwargs)

    monkeypatch.setattr(org_auth_service.LOG, "error", record_error)

    with pytest.raises(HTTPException) as exc_info:
        await org_auth_service.resolve_org_from_api_key("definitely-not-a-jwt", SimpleNamespace(), ())

    raised = exc_info.value
    assert raised.status_code == 403
    assert captured["error_type"] == "DecodeError"
    assert captured["error_reason"] == "Not enough segments"
@pytest.mark.asyncio
async def test_resolve_org_from_api_key_returns_403_when_diagnostic_helper_fails(monkeypatch) -> None:
    """Expect HTTP 403 and a warning log even when the diagnostic helper itself raises."""
    captured_warning: dict[str, object] = {}

    def record_warning(_message: str, **kwargs: object) -> None:
        # Keep only the keyword fields passed to LOG.warning.
        captured_warning.update(kwargs)

    def exploding_helper(*_args: object, **_kwargs: object) -> object:
        # Stand-in for the diagnostic helper that always fails.
        raise RuntimeError("boom")

    monkeypatch.setattr(org_auth_service.LOG, "warning", record_warning)
    monkeypatch.setattr(org_auth_service, "_get_api_key_debug_fields", exploding_helper)

    with pytest.raises(HTTPException) as exc_info:
        await org_auth_service.resolve_org_from_api_key("definitely-not-a-jwt", SimpleNamespace(), ())

    raised = exc_info.value
    assert raised.status_code == 403
    assert captured_warning["diagnostic_error_type"] == "RuntimeError"