mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
177 lines
7.2 KiB
Python
177 lines
7.2 KiB
Python
from types import SimpleNamespace
|
|
|
|
import jwt
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
|
|
from skyvern.config import settings
|
|
from skyvern.forge.sdk.core.security import create_access_token
|
|
from skyvern.forge.sdk.services import org_auth_service
|
|
from skyvern.forge.sdk.services.org_auth_service import (
|
|
_get_api_key_debug_fields,
|
|
_normalize_api_key_with_flags,
|
|
)
|
|
|
|
|
|
def test_normalize_api_key_strips_whitespace() -> None:
    """Surrounding whitespace is trimmed and both padding/normalized flags are set."""
    padded_key = " token.value.parts \n"
    cleaned, normalization_flags = _normalize_api_key_with_flags(padded_key)
    fields = _get_api_key_debug_fields(padded_key, cleaned, normalization_flags)

    assert cleaned == "token.value.parts"
    for flag_name in ("api_key_had_whitespace_padding", "api_key_was_normalized"):
        assert fields[flag_name] is True, flag_name
|
|
|
|
|
|
def test_normalize_api_key_strips_outer_quotes() -> None:
    """A key wrapped in a pair of double quotes is unwrapped and flagged as normalized."""
    quoted = '"token.value.parts"'
    unquoted, quote_flags = _normalize_api_key_with_flags(quoted)
    report = _get_api_key_debug_fields(quoted, unquoted, quote_flags)

    assert unquoted == "token.value.parts"
    # The quote flag records *why* the value changed; the normalized flag
    # records *that* it changed.
    assert report["api_key_had_outer_quotes"] is True
    assert report["api_key_was_normalized"] is True
|
|
|
|
|
|
def test_normalize_api_key_strips_bearer_prefix() -> None:
    """A leading ``Bearer `` prefix is removed and the remaining segments are counted."""
    header_value = "Bearer token.value.parts"
    stripped, prefix_flags = _normalize_api_key_with_flags(header_value)
    debug = _get_api_key_debug_fields(header_value, stripped, prefix_flags)

    assert stripped == "token.value.parts"
    assert debug["api_key_had_bearer_prefix"] is True
    # "token.value.parts" splits into three dot-separated segments.
    assert debug["api_key_normalized_segment_count"] == 3
|
|
|
|
|
|
def test_normalize_api_key_handles_quoted_bearer_value() -> None:
    """Quotes around a whole ``Bearer <token>`` string are removed together with the prefix."""
    wrapped_header = '"Bearer token.value.parts"'
    unwrapped, unwrap_flags = _normalize_api_key_with_flags(wrapped_header)
    info = _get_api_key_debug_fields(wrapped_header, unwrapped, unwrap_flags)

    assert unwrapped == "token.value.parts"
    # Both wrappers must be reported, not just the outermost one.
    for wrapper_flag in ("api_key_had_bearer_prefix", "api_key_had_outer_quotes"):
        assert info[wrapper_flag] is True, wrapper_flag
|
|
|
|
|
|
def test_normalize_api_key_tracks_whitespace_removed_after_wrapper_stripping() -> None:
    """Whitespace that only becomes outer padding after unwrapping is still flagged."""
    nested = 'Bearer " token.value.parts "'
    final_key, nested_flags = _normalize_api_key_with_flags(nested)
    fields = _get_api_key_debug_fields(nested, final_key, nested_flags)

    assert final_key == "token.value.parts"
    # The spaces sit inside the quotes, so they are only exposed once the
    # Bearer prefix and the quotes have been stripped.
    expected_true_flags = (
        "api_key_had_whitespace_padding",
        "api_key_had_bearer_prefix",
        "api_key_had_outer_quotes",
    )
    for flag in expected_true_flags:
        assert fields[flag] is True, flag
|
|
|
|
|
|
def test_debug_fields_report_no_shadow_decode_for_unchanged_value() -> None:
    """If normalization changes nothing, every shadow-decode field is left as None."""
    already_clean = "token.value.parts"
    same_key, noop_flags = _normalize_api_key_with_flags(already_clean)
    fields = _get_api_key_debug_fields(already_clean, same_key, noop_flags)

    assert same_key == "token.value.parts"
    assert fields["api_key_was_normalized"] is False
    # No decode result of the normalized key is reported when it equals the raw key.
    shadow_decode_keys = (
        "normalized_api_key_decodes",
        "normalized_api_key_would_be_expired",
        "normalized_api_key_error_type",
    )
    for key in shadow_decode_keys:
        assert fields[key] is None, key
|
|
|
|
|
|
def test_debug_fields_show_when_normalized_token_would_decode(monkeypatch: pytest.MonkeyPatch) -> None:
    """A valid token behind a ``Bearer`` prefix is reported as decodable and unexpired.

    The token is minted first. Then ``time.time`` (reached through
    ``org_auth_service.time``) is pinned to 0, so the expiry check in the debug
    helper cannot treat the token as expired.
    """
    token = create_access_token("o_test")
    # Patch after minting so token creation is not affected by the fake clock.
    # NOTE(review): assuming org_auth_service does `import time`, this patches
    # the stdlib time.time process-wide; monkeypatch restores it at teardown.
    monkeypatch.setattr(org_auth_service.time, "time", lambda: 0)

    raw_api_key = f"Bearer {token}"
    normalized, flags = _normalize_api_key_with_flags(raw_api_key)
    debug_fields = _get_api_key_debug_fields(raw_api_key, normalized, flags)

    # Stripping the prefix must give back exactly the minted token. Otherwise the
    # "decodes" assertions below would be validating some other string.
    assert normalized == token
    assert debug_fields["api_key_had_bearer_prefix"] is True
    assert debug_fields["normalized_api_key_decodes"] is True
    assert debug_fields["normalized_api_key_would_be_expired"] is False
    assert debug_fields["normalized_api_key_error_type"] is None
|
|
|
|
|
|
def test_debug_fields_show_when_normalized_token_still_fails() -> None:
    """Unwrapping a non-JWT value still yields a decode failure with PyJWT's reason text."""
    garbage = '"Bearer definitely-not-a-jwt"'
    after, after_flags = _normalize_api_key_with_flags(garbage)
    diag = _get_api_key_debug_fields(garbage, after, after_flags)

    assert after == "definitely-not-a-jwt"
    assert diag["normalized_api_key_decodes"] is False
    # A value with no dots cannot be split into header/payload/signature segments.
    assert diag["normalized_api_key_error_type"] == "DecodeError"
    assert diag["normalized_api_key_error_reason"] == "Not enough segments"
|
|
|
|
|
|
def test_normalize_api_key_handles_empty_string() -> None:
    """An empty key stays empty, has zero segments, and gets no shadow-decode result."""
    empty_key = ""
    result, result_flags = _normalize_api_key_with_flags(empty_key)
    fields = _get_api_key_debug_fields(empty_key, result, result_flags)

    assert result == ""
    assert fields["api_key_raw_segment_count"] == 0
    assert fields["normalized_api_key_decodes"] is None
|
|
|
|
|
|
def test_normalize_api_key_handles_single_character() -> None:
    """A lone ``"`` is not mistaken for a pair of outer quotes and is left unchanged."""
    lone_quote = '"'
    kept, kept_flags = _normalize_api_key_with_flags(lone_quote)
    report = _get_api_key_debug_fields(lone_quote, kept, kept_flags)

    assert kept == '"'
    # One character cannot be both an opening and a closing quote.
    assert report["api_key_had_outer_quotes"] is False
    assert report["normalized_api_key_decodes"] is None
|
|
|
|
|
|
def test_debug_fields_reports_validation_error_for_missing_claims() -> None:
    """A properly signed JWT with an empty payload is reported as missing ``sub`` and ``exp``."""
    # Signed with the real secret, so only the payload contents can fail.
    claimless_token = jwt.encode({}, settings.SECRET_KEY, algorithm="HS256")
    raw_api_key = f"Bearer {claimless_token}"

    stripped, strip_flags = _normalize_api_key_with_flags(raw_api_key)
    diag = _get_api_key_debug_fields(raw_api_key, stripped, strip_flags)

    assert diag["normalized_api_key_decodes"] is False
    assert diag["normalized_api_key_error_type"] == "ValidationError"
    assert diag["normalized_api_key_error_reason"] == "2 validation error(s): [('sub',), ('exp',)]"
|
|
|
|
|
|
def test_debug_fields_handles_none_inputs() -> None:
    """With every argument None, the length and shadow-decode fields are all None."""
    fields = _get_api_key_debug_fields(None, None, None)

    none_valued_keys = (
        "api_key_original_length",
        "normalized_api_key_decodes",
        "normalized_api_key_error_type",
        "normalized_api_key_error_reason",
    )
    for key in none_valued_keys:
        assert fields[key] is None, key
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_org_from_api_key_logs_decode_error_reason(monkeypatch) -> None:
    """A malformed key is rejected with 403 and the JWT decode failure is logged."""
    captured_kwargs: dict[str, object] = {}

    def record_error(_message: str, **kwargs: object) -> None:
        # Only the structured keyword fields are checked; the message text is ignored.
        captured_kwargs.update(kwargs)

    monkeypatch.setattr(org_auth_service.LOG, "error", record_error)

    # NOTE(review): SimpleNamespace() and () look like placeholder arguments for
    # this failing path; confirm against resolve_org_from_api_key's signature.
    with pytest.raises(HTTPException) as raised:
        await org_auth_service.resolve_org_from_api_key("definitely-not-a-jwt", SimpleNamespace(), ())

    assert raised.value.status_code == 403
    assert captured_kwargs["error_type"] == "DecodeError"
    assert captured_kwargs["error_reason"] == "Not enough segments"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_org_from_api_key_returns_403_when_diagnostic_helper_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A crash inside the diagnostics helper must not change the 403 outcome.

    ``_get_api_key_debug_fields`` is replaced by a function that raises. The
    test checks that the request is still rejected with 403 and that the
    helper's failure is reported through ``LOG.warning``.
    """
    # Deliberately not named ``warnings``, which would shadow the stdlib module.
    warning_fields: dict[str, object] = {}

    def fake_warning(_message: str, **kwargs: object) -> None:
        warning_fields.update(kwargs)

    monkeypatch.setattr(org_auth_service.LOG, "warning", fake_warning)

    def fail_helper(*_args: object, **_kwargs: object) -> object:
        raise RuntimeError("boom")

    monkeypatch.setattr(org_auth_service, "_get_api_key_debug_fields", fail_helper)

    with pytest.raises(HTTPException) as exc_info:
        await org_auth_service.resolve_org_from_api_key("definitely-not-a-jwt", SimpleNamespace(), ())

    assert exc_info.value.status_code == 403
    # .get() turns a missing warning into a readable assertion failure instead of a KeyError.
    assert warning_fields.get("diagnostic_error_type") == "RuntimeError"
|