mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
891 lines
32 KiB
Python
891 lines
32 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
import base64
import json
from collections.abc import Iterator
from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock

import httpx
import pytest
from fastapi import HTTPException
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route

from skyvern.cli.core import client as client_mod
from skyvern.cli.core import mcp_http_auth
|
|
|
|
# Base URL for the in-process httpx client used by `_request`; the host is never resolved over the network.
_TEST_BASE_URL: str = "http://testserver"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_auth_context(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
    """Give every test a clean MCP auth state and undo global tweaks afterwards.

    Module-level tunables are patched through ``monkeypatch``. Their test-only
    values (e.g. a zero retry delay) are therefore restored on teardown instead
    of leaking into other test modules in the same pytest session. The
    validation cache is emptied both before and after each test.
    """
    client_mod._api_key_override.set(None)
    monkeypatch.setattr(mcp_http_auth, "_auth_db", None)
    mcp_http_auth._api_key_validation_cache.clear()
    monkeypatch.setattr(mcp_http_auth, "_API_KEY_CACHE_TTL_SECONDS", 30.0)
    monkeypatch.setattr(mcp_http_auth, "_API_KEY_CACHE_MAX_SIZE", 1024)
    monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", 2)
    monkeypatch.setattr(mcp_http_auth, "_RETRY_DELAY_SECONDS", 0.0)  # no delay in tests
    yield
    # Drop anything the test cached so it cannot influence later tests.
    mcp_http_auth._api_key_validation_cache.clear()
|
|
|
|
|
|
async def _echo_request_context(request: Request) -> JSONResponse:
    """Echo back the auth context the middleware attached to this request.

    The body holds the request-scoped API key (from
    ``client_mod.get_active_api_key``) and the ``organization_id`` stored on
    ``request.state``, or ``None`` when the middleware did not set one.
    """
    active_key = client_mod.get_active_api_key()
    org_id = getattr(request.state, "organization_id", None)
    body = {"api_key": active_key, "organization_id": org_id}
    return JSONResponse(body)
|
|
|
|
|
|
def _build_validation(
    organization_id: str,
) -> mcp_http_auth.MCPAPIKeyValidation:
    """Return an API-token ``MCPAPIKeyValidation`` bound to ``organization_id``."""
    api_token_type = mcp_http_auth.OrganizationAuthTokenType.api
    return mcp_http_auth.MCPAPIKeyValidation(organization_id=organization_id, token_type=api_token_type)
|
|
|
|
|
|
def _build_resolved_validation(
    organization_id: str,
) -> SimpleNamespace:
    """Fake a ``resolve_org_from_api_key`` result.

    The result exposes ``.organization.organization_id`` and ``.token.token_type``
    (an API token type).
    """
    org = SimpleNamespace(organization_id=organization_id)
    token = SimpleNamespace(token_type=mcp_http_auth.OrganizationAuthTokenType.api)
    return SimpleNamespace(organization=org, token=token)
|
|
|
|
|
|
def _build_test_app() -> Starlette:
    """Build a minimal Starlette app: an ``/mcp`` echo route behind ``MCPAPIKeyMiddleware``."""
    echo_route = Route("/mcp", endpoint=_echo_request_context, methods=["GET", "HEAD", "POST"])
    auth_middleware = Middleware(mcp_http_auth.MCPAPIKeyMiddleware)
    return Starlette(routes=[echo_route], middleware=[auth_middleware])
|
|
|
|
|
|
def _jwtish_token(
|
|
*,
|
|
header: dict[str, object] | None = None,
|
|
payload: dict[str, object] | None = None,
|
|
) -> str:
|
|
def _encode(segment: dict[str, object]) -> str:
|
|
return base64.urlsafe_b64encode(json.dumps(segment, separators=(",", ":")).encode()).rstrip(b"=").decode()
|
|
|
|
encoded_header = _encode(header or {"alg": "RS256", "typ": "JWT"})
|
|
encoded_payload = _encode(payload or {"sub": "user_123"})
|
|
return f"{encoded_header}.{encoded_payload}.signature"
|
|
|
|
|
|
async def _request(
    app: Starlette,
    method: str,
    path: str,
    **kwargs: object,
) -> httpx.Response:
    """Send one in-process request to ``app`` through httpx's ASGI transport.

    Extra ``kwargs`` (headers, json, ...) are forwarded unchanged to
    ``httpx.AsyncClient.request``.
    """
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url=_TEST_BASE_URL) as http_client:
        response = await http_client.request(method, path, **kwargs)
    return response
|
|
|
|
|
|
def _stub_auth_db(monkeypatch: pytest.MonkeyPatch, db: object) -> None:
    """Make ``mcp_http_auth.get_auth_db()`` return ``db`` for the current test."""

    def _fake_get_auth_db() -> object:
        return db

    monkeypatch.setattr(mcp_http_auth, "get_auth_db", _fake_get_auth_db)
|
|
|
|
|
|
def _expected_oauth_challenge(monkeypatch: pytest.MonkeyPatch) -> str:
    """Pin ``SKYVERN_BASE_URL`` to production and return the matching challenge header value."""
    auth_settings = mcp_http_auth.settings
    monkeypatch.setattr(auth_settings, "SKYVERN_BASE_URL", "https://api.skyvern.com")
    challenge = mcp_http_auth._oauth_challenge_header()
    return challenge
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_missing_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A request with no credentials gets a 401 plus an exposed OAuth challenge header."""
    app = _build_test_app()
    expected_challenge = _expected_oauth_challenge(monkeypatch)

    response = await _request(app, "POST", "/mcp", json={})

    assert response.status_code == 401
    error = response.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert "x-api-key" in error["message"]
    headers = response.headers
    assert headers["www-authenticate"] == expected_challenge
    assert headers["access-control-expose-headers"] == "WWW-Authenticate"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_head_request_exposes_oauth_challenge(monkeypatch: pytest.MonkeyPatch) -> None:
    """HEAD probes without credentials also receive the 401 OAuth challenge headers."""
    app = _build_test_app()
    expected_challenge = _expected_oauth_challenge(monkeypatch)

    response = await _request(app, "HEAD", "/mcp")

    assert response.status_code == 401
    expected_headers = (
        ("www-authenticate", expected_challenge),
        ("access-control-expose-headers", "WWW-Authenticate"),
    )
    for header_name, expected_value in expected_headers:
        assert response.headers[header_name] == expected_value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_allows_health_checks_without_api_key() -> None:
    """``/healthz`` succeeds without credentials.

    The test app registers no ``/healthz`` route, so the middleware must answer it.
    """
    response = await _request(_build_test_app(), "GET", "/healthz")

    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_rejects_invalid_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A key rejected by validation (403) is reported as a 401 without an OAuth challenge."""
    rejecting_validator = AsyncMock(side_effect=HTTPException(status_code=403, detail="Invalid credentials"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", rejecting_validator)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"x-api-key": "bad-key"}, json={})

    assert response.status_code == 401
    error = response.json()["error"]
    assert error["code"] == "UNAUTHORIZED"
    assert error["message"] == "Invalid API key"
    assert "www-authenticate" not in response.headers
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_non_auth_http_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-auth HTTPException (500) from validation maps to an INTERNAL_ERROR response."""
    failing_validator = AsyncMock(side_effect=HTTPException(status_code=500, detail="db down"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", failing_validator)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert response.status_code == 500
    assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_503_on_transient_validation_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 503 from validation is surfaced as SERVICE_UNAVAILABLE with its detail preserved."""
    detail = "API key validation temporarily unavailable"
    unavailable_validator = AsyncMock(side_effect=HTTPException(status_code=503, detail=detail))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", unavailable_validator)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert response.status_code == 503
    error = response.json()["error"]
    assert error["code"] == "SERVICE_UNAVAILABLE"
    assert error["message"] == "API key validation temporarily unavailable"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_on_unexpected_validation_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-HTTP exception escaping validation becomes an INTERNAL_ERROR 500."""
    crashing_validator = AsyncMock(side_effect=RuntimeError("boom"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", crashing_validator)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"x-api-key": "sk_live_abc"}, json={})

    assert response.status_code == 500
    assert response.json()["error"]["code"] == "INTERNAL_ERROR"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_sets_request_scoped_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    """A valid key is visible to the handler (key + org) and does not outlive the request."""
    api_key = "sk_live_abc"
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(return_value=_build_validation("org_123")),
    )
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"x-api-key": api_key}, json={})

    assert response.status_code == 200
    assert response.json() == {"api_key": api_key, "organization_id": "org_123"}
    # The override is request-scoped, so the test's own context must not see it.
    assert client_mod.get_active_api_key() != api_key
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_uses_ttl_cache(monkeypatch: pytest.MonkeyPatch) -> None:
    """Two lookups of the same key within the TTL hit the resolver only once."""
    resolved_keys: list[str] = []

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        resolved_keys.append(_api_key)
        return _build_resolved_validation("org_cached")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    first = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")
    second = await mcp_http_auth.validate_mcp_api_key("sk_test_cache")

    assert [first.organization_id, second.organization_id] == ["org_cached", "org_cached"]
    assert len(resolved_keys) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_cache_expires(monkeypatch: pytest.MonkeyPatch) -> None:
    """An expired cache entry is ignored and the key is resolved again."""
    api_key = "sk_test_cache_expire"
    resolve_count = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal resolve_count
        resolve_count += 1
        return _build_resolved_validation(f"org_{resolve_count}")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    first = await mcp_http_auth.validate_mcp_api_key(api_key)
    # Overwrite the entry's time field with 0.0 so the cache treats it as expired.
    mcp_http_auth._api_key_validation_cache[mcp_http_auth.cache_key(api_key)] = (first, 0.0)
    second = await mcp_http_auth.validate_mcp_api_key(api_key)

    assert first.organization_id == "org_1"
    assert second.organization_id == "org_2"
    assert resolve_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_negative_caches_auth_failures(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 401 from the resolver is negative-cached, so the repeat call never reaches the DB.

    The first call surfaces the resolver's own detail. The cached replay uses
    the generic "Invalid API key" detail.
    """
    resolve_count = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal resolve_count
        resolve_count += 1
        raise HTTPException(status_code=401, detail="Invalid credentials")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    for expected_detail in ("Invalid credentials", "Invalid API key"):
        with pytest.raises(HTTPException, match=expected_detail):
            await mcp_http_auth.validate_mcp_api_key("sk_test_auth_failure")

    assert resolve_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_retries_transient_failure_without_negative_cache(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A transient resolver error is retried, and the later success is what gets cached."""
    api_key = "sk_test_transient"
    attempts = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal attempts
        attempts += 1
        if attempts > 1:
            return _build_resolved_validation("org_recovered")
        raise RuntimeError("transient db error")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    recovered = await mcp_http_auth.validate_mcp_api_key(api_key)

    cached_entry = mcp_http_auth._api_key_validation_cache[mcp_http_auth.cache_key(api_key)]
    assert cached_entry[0].organization_id == "org_recovered"
    assert recovered.organization_id == "org_recovered"
    assert attempts == 2
|
|
|
|
|
|
def test_profile_to_mcp_url_normalizes_base_variants() -> None:
    """Every slash/suffix variant of the base URL maps to one slashless ``/mcp`` URL.

    Exercises ``_canonical_mcp_url``. The canonical form has no trailing slash,
    so the advertised MCP resource URI matches what clients send during
    RFC 8707 audience / RFC 9728 protected-resource comparison.
    """
    canonical = "https://api.skyvern.com/mcp"
    variants = (
        "https://api.skyvern.com",
        "https://api.skyvern.com/",
        "https://api.skyvern.com/mcp",
        "https://api.skyvern.com/mcp/",
    )
    for base_url in variants:
        assert mcp_http_auth._canonical_mcp_url(base_url) == canonical, base_url
|
|
|
|
|
|
def test_resource_metadata_url_normalizes_base_variants() -> None:
    """Bare and ``/mcp/`` base URLs yield the same protected-resource metadata URL."""
    expected = "https://api.skyvern.com/.well-known/oauth-protected-resource/mcp"
    for base_url in ("https://api.skyvern.com", "https://api.skyvern.com/mcp/"):
        assert mcp_http_auth._canonical_resource_metadata_url(base_url) == expected, base_url
|
|
|
|
|
|
def test_validate_token_audience_rejects_wrong_resource() -> None:
    """An ``aud`` naming a different host is rejected."""
    claims = {"aud": ["https://some-other-resource.example.com/mcp/"]}
    with pytest.raises(HTTPException, match="Token audience is not valid for this MCP resource"):
        mcp_http_auth._validate_token_audience(claims, "https://api.skyvern.com/mcp/")
|
|
|
|
|
|
def test_validate_token_audience_accepts_matching_url() -> None:
    """An ``aud`` identical to the expected resource passes (no exception)."""
    resource = "https://api.skyvern.com/mcp/"
    mcp_http_auth._validate_token_audience({"aud": [resource]}, resource)
|
|
|
|
|
|
def test_validate_token_audience_tolerates_trailing_slash_mismatch() -> None:
    """Slashed and slashless forms of the resource match in either direction.

    A token minted against the slashed form must validate against the canonical
    (slashless) expected resource, and vice versa.
    """
    slashed = "https://api.skyvern.com/mcp/"
    slashless = "https://api.skyvern.com/mcp"
    for token_audience, expected_resource in ((slashed, slashless), (slashless, slashed)):
        mcp_http_auth._validate_token_audience({"aud": [token_audience]}, expected_resource)
|
|
|
|
|
|
def test_validate_token_resource_claim_tolerates_trailing_slash_mismatch() -> None:
    """The RFC 8707 ``resource`` claim gets the same trailing-slash tolerance as ``aud``."""
    slashed = "https://api.skyvern.com/mcp/"
    slashless = "https://api.skyvern.com/mcp"
    for token_resource, expected_resource in ((slashed, slashless), (slashless, slashed)):
        mcp_http_auth._validate_token_resource_claims({"resource": token_resource}, expected_resource)
|
|
|
|
|
|
def test_validate_token_audience_rejects_missing_aud() -> None:
    """A payload with no ``aud`` key at all is rejected.

    An absent audience leaves nothing to compare, so it can never match the
    expected resource.
    """
    expected_error = "Token audience is not valid for this MCP resource"
    with pytest.raises(HTTPException, match=expected_error):
        mcp_http_auth._validate_token_audience({}, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_audience_rejects_none_aud() -> None:
    """An explicit ``aud: None`` is rejected."""
    claims = {"aud": None}
    with pytest.raises(HTTPException, match="Token audience is not valid for this MCP resource"):
        mcp_http_auth._validate_token_audience(claims, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_audience_rejects_empty_list_aud() -> None:
    """An empty ``aud`` list is rejected."""
    claims: dict[str, object] = {"aud": []}
    with pytest.raises(HTTPException, match="Token audience is not valid for this MCP resource"):
        mcp_http_auth._validate_token_audience(claims, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_audience_filters_non_string_list_items() -> None:
    """An ``aud`` list containing only non-string items cannot match.

    Non-string items inside the ``aud`` array are silently dropped (per the
    asymmetry documented in ``_validate_token_audience``). With only garbage in
    the list there is nothing left to compare with the expected resource.
    """
    garbage_only = {"aud": [42, None, {}]}
    with pytest.raises(HTTPException, match="Token audience is not valid for this MCP resource"):
        mcp_http_auth._validate_token_audience(garbage_only, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_audience_rejects_different_path_despite_normalization() -> None:
    """``/mcp-other/`` is not a slash-variant of ``/mcp`` and must be rejected.

    This guards against a future refactor that broadens the trailing-slash
    normalization into a prefix / startswith check.
    """
    lookalike_claims = {"aud": ["https://api.skyvern.com/mcp-other/"]}
    with pytest.raises(HTTPException, match="Token audience is not valid for this MCP resource"):
        mcp_http_auth._validate_token_audience(lookalike_claims, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_resource_claim_rejects_different_path_despite_normalization() -> None:
    """The same prefix-boundary guard, applied to the ``resource`` claim."""
    lookalike_claims = {"resource": "https://api.skyvern.com/mcp-other/"}
    with pytest.raises(HTTPException, match="Token resource is not valid for this MCP resource"):
        mcp_http_auth._validate_token_resource_claims(lookalike_claims, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_validate_token_resource_claim_rejects_non_string_claim() -> None:
    """A non-string ``resource`` claim fails with its own, type-specific error detail.

    Such a claim is a malformed token, not a slash-variant of the expected
    value. The distinct detail makes the cause obvious in logs.
    """
    malformed_claims = {"resource": 42}
    with pytest.raises(HTTPException, match="Token resource claim must be a string"):
        mcp_http_auth._validate_token_resource_claims(malformed_claims, "https://api.skyvern.com/mcp")
|
|
|
|
|
|
def test_looks_like_jwt_rejects_dotted_opaque_token() -> None:
    """Dot-separated segments alone do not qualify a token as a JWT."""
    verdict = mcp_http_auth._looks_like_jwt("opaque.with.dots")
    assert verdict is False
|
|
|
|
|
|
def test_looks_like_jwt_accepts_jwt_header() -> None:
    """A token whose first segment is a base64url JWT header is recognized as a JWT."""
    token = _jwtish_token()
    assert mcp_http_auth._looks_like_jwt(token) is True
|
|
|
|
|
|
def test_validate_oauth_token_contract_rejects_invalid_issuer() -> None:
    """A token from an unexpected issuer is rejected even when its audience is correct."""
    resource = "https://api.skyvern.com/mcp/"
    claims = {"iss": "https://wrong-issuer.example.com", "aud": [resource]}
    with pytest.raises(HTTPException, match="Token issuer is not valid for this MCP resource"):
        mcp_http_auth._validate_oauth_token_contract(
            claims,
            expected_resource=resource,
            expected_issuer="https://clerk.example.com",
        )
|
|
|
|
|
|
def test_validate_oauth_token_contract_rejects_mismatched_resource_claim() -> None:
    """A valid issuer and audience still fail when the ``resource`` claim points elsewhere."""
    resource = "https://api.skyvern.com/mcp/"
    issuer = "https://clerk.example.com"
    claims = {"iss": issuer, "aud": [resource], "resource": "https://api.skyvern.com/other/"}
    with pytest.raises(HTTPException, match="Token resource is not valid for this MCP resource"):
        mcp_http_auth._validate_oauth_token_contract(
            claims,
            expected_resource=resource,
            expected_issuer=issuer,
        )
|
|
|
|
|
|
def test_validate_oauth_token_contract_accepts_valid_jwt_claims() -> None:
    """Matching claims pass, including an issuer that differs only by a trailing slash."""
    resource = "https://api.skyvern.com/mcp/"
    claims = {
        "iss": "https://clerk.example.com/",
        "aud": [resource],
        "resource": resource,
    }
    mcp_http_auth._validate_oauth_token_contract(
        claims,
        expected_resource=resource,
        expected_issuer="https://clerk.example.com",
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fetch_oauth_userinfo_returns_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_fetch_oauth_userinfo`` returns the response payload of the userinfo call.

    The call is a GET to ``<issuer>/oauth/userinfo`` that sends the token as a
    Bearer ``Authorization`` header.
    """
    http_request = AsyncMock(return_value=(200, {}, {"sub": "user_123", "email": "user@example.com"}))
    monkeypatch.setattr(mcp_http_auth, "aiohttp_request", http_request)

    payload = await mcp_http_auth._fetch_oauth_userinfo("opaque-token", "https://clerk.example.com")

    assert payload == {"sub": "user_123", "email": "user@example.com"}
    http_request.assert_awaited_once_with(
        "GET",
        "https://clerk.example.com/oauth/userinfo",
        headers={"Authorization": "Bearer opaque-token"},
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_rejects_opaque_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
    """Opaque (non-JWT) bearer tokens cannot be audience-validated; we reject with 401."""
    issuer_url = "https://clerk.example.com"
    fetch_userinfo = AsyncMock(return_value={"sub": "user_123"})
    monkeypatch.setattr(mcp_http_auth, "_fetch_oauth_userinfo", fetch_userinfo)
    monkeypatch.setattr(mcp_http_auth, "_get_oauth_issuer_url", lambda: issuer_url)
    monkeypatch.setattr(mcp_http_auth.settings, "SKYVERN_BASE_URL", "https://api.skyvern.com")

    with pytest.raises(HTTPException) as exc_info:
        await mcp_http_auth.validate_mcp_oauth_token("opaque-token")

    rejection = exc_info.value
    assert rejection.status_code == 401
    assert "Opaque Bearer tokens" in rejection.detail
    # Rejection is decided purely on token shape, so userinfo is never consulted.
    fetch_userinfo.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_rejects_dotted_opaque_tokens(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A token with dots but not a valid JWT header is still opaque and must be rejected."""
    issuer_url = "https://clerk.example.com"
    fetch_userinfo = AsyncMock(return_value={"sub": "user_123"})
    monkeypatch.setattr(mcp_http_auth, "_fetch_oauth_userinfo", fetch_userinfo)
    monkeypatch.setattr(mcp_http_auth, "_get_oauth_issuer_url", lambda: issuer_url)
    monkeypatch.setattr(mcp_http_auth.settings, "SKYVERN_BASE_URL", "https://api.skyvern.com")

    with pytest.raises(HTTPException) as exc_info:
        await mcp_http_auth.validate_mcp_oauth_token("opaque.with.dots")

    assert exc_info.value.status_code == 401
    fetch_userinfo.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_does_not_negative_cache_503_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """A 503 from Clerk JWKS fetch must not be cached — the next call should retry.

    Both halves of that contract are checked:
    - no cache entry is written for the token;
    - a second validation reaches the JWK fetcher again and surfaces the same
      503, rather than replaying a cached failure.
    """
    monkeypatch.setattr(
        mcp_http_auth,
        "_get_oauth_issuer_url",
        lambda: "https://clerk.example.com",
    )
    _stub_auth_db(monkeypatch, object())
    monkeypatch.setattr(mcp_http_auth, "_looks_like_jwt", lambda _token: True)
    get_jwt_key = AsyncMock(side_effect=RuntimeError("clerk down"))
    monkeypatch.setattr(
        mcp_http_auth,
        "app",
        SimpleNamespace(
            AGENT_FUNCTION=SimpleNamespace(
                get_mcp_oauth_jwt_key=get_jwt_key,
            )
        ),
    )

    jwt_token = _jwtish_token()
    with pytest.raises(HTTPException) as first_exc:
        await mcp_http_auth.validate_mcp_oauth_token(jwt_token)

    assert first_exc.value.status_code == 503
    assert mcp_http_auth._oauth_cache_key(jwt_token) not in mcp_http_auth._api_key_validation_cache

    # The follow-up call must go back to the key fetcher, not short-circuit on a cached failure.
    fetches_before_retry = get_jwt_key.await_count
    with pytest.raises(HTTPException) as second_exc:
        await mcp_http_auth.validate_mcp_oauth_token(jwt_token)

    assert second_exc.value.status_code == 503
    assert get_jwt_key.await_count > fetches_before_retry
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_negative_caches_401_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """An invalid-signature 401 should be negative-cached so repeated bad tokens are cheap."""
    import jwt
    from jwt.exceptions import InvalidSignatureError

    def _reject_signature(*_args: object, **_kwargs: object) -> dict[str, object]:
        raise InvalidSignatureError("bad signature")

    # Identity PyJWK plus a decode that always fails signature verification.
    monkeypatch.setattr(jwt, "PyJWK", lambda key: key)
    monkeypatch.setattr(jwt, "decode", _reject_signature)
    monkeypatch.setattr(mcp_http_auth, "_get_oauth_issuer_url", lambda: "https://clerk.example.com")
    _stub_auth_db(monkeypatch, object())
    monkeypatch.setattr(mcp_http_auth, "_looks_like_jwt", lambda _token: True)
    fake_agent_function = SimpleNamespace(get_mcp_oauth_jwt_key=AsyncMock(return_value="jwk"))
    monkeypatch.setattr(mcp_http_auth, "app", SimpleNamespace(AGENT_FUNCTION=fake_agent_function))

    jwt_token = _jwtish_token()
    with pytest.raises(HTTPException) as exc_info:
        await mcp_http_auth.validate_mcp_oauth_token(jwt_token)

    assert exc_info.value.status_code == 401
    assert mcp_http_auth._oauth_cache_key(jwt_token) in mcp_http_auth._api_key_validation_cache
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_returns_401_when_cloud_jwk_is_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A JWK fetcher that returns ``None`` yields a 401 "requires cloud deployment" error."""
    monkeypatch.setattr(mcp_http_auth, "_get_oauth_issuer_url", lambda: "https://clerk.example.com")
    _stub_auth_db(monkeypatch, object())
    monkeypatch.setattr(mcp_http_auth, "_looks_like_jwt", lambda _token: True)
    no_key_agent_function = SimpleNamespace(get_mcp_oauth_jwt_key=AsyncMock(return_value=None))
    monkeypatch.setattr(mcp_http_auth, "app", SimpleNamespace(AGENT_FUNCTION=no_key_agent_function))

    with pytest.raises(HTTPException) as exc_info:
        await mcp_http_auth.validate_mcp_oauth_token("header.payload.signature")

    rejection = exc_info.value
    assert rejection.status_code == 401
    assert rejection.detail == "OAuth authentication requires cloud deployment"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_oauth_token_passes_clock_skew_leeway_to_pyjwt(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """PyJWT receives the incoming token and the module's clock-skew leeway.

    No ``options`` override may be passed to PyJWT alongside the leeway. A
    successful decode resolves the subject to the org's stored API key.
    """
    import jwt

    # Only state that is actually asserted below is captured.
    captured: dict[str, object] = {}

    def _fake_decode(token: str, signing_key: object, **kwargs: object) -> dict[str, object]:
        captured["token"] = token
        captured["kwargs"] = kwargs
        return {
            "iss": "https://clerk.example.com",
            "aud": ["https://api.skyvern.com/mcp/"],
            "resource": "https://api.skyvern.com/mcp/",
            "sub": "user_123",
        }

    fake_db = SimpleNamespace(
        get_organization_entities=AsyncMock(return_value=[SimpleNamespace(organization_id="org_jwt")]),
        get_valid_org_auth_token=AsyncMock(return_value=SimpleNamespace(token="sk_live_from_jwt")),
    )
    monkeypatch.setattr(jwt, "PyJWK", lambda key: key)
    monkeypatch.setattr(jwt, "decode", _fake_decode)
    _stub_auth_db(monkeypatch, fake_db)
    monkeypatch.setattr(mcp_http_auth, "_looks_like_jwt", lambda _token: True)
    monkeypatch.setattr(mcp_http_auth, "_get_oauth_issuer_url", lambda: "https://clerk.example.com")
    monkeypatch.setattr(mcp_http_auth.settings, "SKYVERN_BASE_URL", "https://api.skyvern.com")
    monkeypatch.setattr(
        mcp_http_auth,
        "app",
        SimpleNamespace(
            AGENT_FUNCTION=SimpleNamespace(
                get_mcp_oauth_jwt_key=AsyncMock(return_value="jwk"),
            )
        ),
    )

    jwt_token = _jwtish_token()
    resolution = await mcp_http_auth.validate_mcp_oauth_token(jwt_token)

    assert resolution.api_key == "sk_live_from_jwt"
    assert captured["token"] == jwt_token
    assert captured["kwargs"]["leeway"] == mcp_http_auth._TOKEN_CLOCK_SKEW_SECONDS
    assert "options" not in captured["kwargs"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_accepts_opaque_bearer_tokens(monkeypatch: pytest.MonkeyPatch) -> None:
    """The middleware applies the OAuth resolution returned for a Bearer token.

    ``validate_mcp_oauth_token`` is stubbed here. The test covers only the
    middleware wiring: the resolved API key and org reach the handler. It does
    not cover whether the real validator accepts opaque tokens.
    """
    resolution = mcp_http_auth._OAuthResolution(
        api_key="sk_live_opaque",
        validation=_build_validation("org_opaque"),
    )
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_oauth_token", AsyncMock(return_value=resolution))
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"authorization": "Bearer opaque-token"}, json={})

    assert response.status_code == 200
    assert response.json() == {"api_key": "sk_live_opaque", "organization_id": "org_opaque"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_falls_back_to_api_key_after_invalid_oauth_token(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A Bearer value rejected by OAuth (401) is retried as a raw API key."""
    proxy_token = "sk_live_proxy_token"
    oauth_rejects = AsyncMock(side_effect=HTTPException(status_code=401, detail="Invalid Bearer token"))
    api_key_accepts = AsyncMock(return_value=_build_validation("org_from_api_key"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_oauth_token", oauth_rejects)
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", api_key_accepts)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"authorization": f"Bearer {proxy_token}"}, json={})

    assert response.status_code == 200
    assert response.json() == {"api_key": proxy_token, "organization_id": "org_from_api_key"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_503_when_clerk_is_unavailable(monkeypatch: pytest.MonkeyPatch) -> None:
    """An OAuth 503 is surfaced when the API-key fallback also rejects the token.

    The OAuth path is the authoritative validator for a JWT-shaped token, so its
    outage wins over the fallback's 401.
    """
    outage_detail = "Authentication service temporarily unavailable"
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_oauth_token",
        AsyncMock(side_effect=HTTPException(status_code=503, detail=outage_detail)),
    )
    monkeypatch.setattr(
        mcp_http_auth,
        "validate_mcp_api_key",
        AsyncMock(side_effect=HTTPException(status_code=401, detail="Invalid API key")),
    )
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"authorization": "Bearer a.b.c"}, json={})

    assert response.status_code == 503
    error = response.json()["error"]
    assert error["code"] == "SERVICE_UNAVAILABLE"
    assert error["message"] == "Authentication service temporarily unavailable"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_falls_back_to_api_key_after_oauth_service_outage(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A raw API key in the Bearer slot must still authenticate when Clerk is degraded."""
    proxy_token = "sk_live_proxy_token"
    oauth_outage = AsyncMock(
        side_effect=HTTPException(status_code=503, detail="Authentication service temporarily unavailable")
    )
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_oauth_token", oauth_outage)
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", AsyncMock(return_value=_build_validation("org_recovered")))
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"authorization": f"Bearer {proxy_token}"}, json={})

    assert response.status_code == 200
    assert response.json() == {"api_key": proxy_token, "organization_id": "org_recovered"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_mcp_http_auth_returns_500_when_oauth_validation_crashes(monkeypatch: pytest.MonkeyPatch) -> None:
    """An unexpected OAuth crash returns 500 and does not fall through to API-key validation."""
    api_key_validator = AsyncMock(return_value=_build_validation("org_should_not_run"))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_oauth_token", AsyncMock(side_effect=RuntimeError("boom")))
    monkeypatch.setattr(mcp_http_auth, "validate_mcp_api_key", api_key_validator)
    app = _build_test_app()

    response = await _request(app, "POST", "/mcp", headers={"authorization": "Bearer a.b.c"}, json={})

    assert response.status_code == 500
    assert response.json()["error"]["code"] == "INTERNAL_ERROR"
    api_key_validator.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_oauth_subject_to_org_logs_missing_db_methods(monkeypatch: pytest.MonkeyPatch) -> None:
    """A bare DB object with no org-lookup methods triggers the cloud-only error.

    It raises the "requires cloud deployment" HTTPException and emits exactly
    one debug log.
    """
    debug_log = Mock()
    monkeypatch.setattr(mcp_http_auth.LOG, "debug", debug_log)
    bare_db = object()

    with pytest.raises(HTTPException, match="OAuth authentication requires cloud deployment"):
        await mcp_http_auth._resolve_oauth_subject_to_org({"sub": "user_123"}, bare_db)

    debug_log.assert_called_once()
|
|
|
|
|
|
def test_get_auth_db_uses_agent_function_builder(monkeypatch: pytest.MonkeyPatch) -> None:
    """``get_auth_db`` builds its handle via ``AGENT_FUNCTION.build_mcp_auth_db``.

    The builder is called with the configured database string and debug flag.
    """
    built_db = object()
    builder = Mock(return_value=built_db)
    fake_app = SimpleNamespace(AGENT_FUNCTION=SimpleNamespace(build_mcp_auth_db=builder))
    monkeypatch.setattr(mcp_http_auth, "app", fake_app)
    monkeypatch.setattr(mcp_http_auth, "_auth_db", None)

    assert mcp_http_auth.get_auth_db() is built_db

    settings = mcp_http_auth.settings
    builder.assert_called_once_with(settings.DATABASE_STRING, debug_enabled=settings.DEBUG_MODE)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_concurrent_callers_all_succeed(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Concurrent callers for the same key all succeed and leave one cached result.

    Each caller resolves against the DB at most once, since the resolver never
    fails and so there are no retries. Once the first result is cached, later
    callers may skip the DB entirely.
    """
    api_key = "test-key-concurrent"
    caller_count = 5
    calls = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal calls
        calls += 1
        return _build_resolved_validation("org_concurrent")

    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    results = await asyncio.gather(*[mcp_http_auth.validate_mcp_api_key(api_key) for _ in range(caller_count)])

    assert len(results) == caller_count
    assert all(r.organization_id == "org_concurrent" for r in results)
    # How many callers miss the cache depends on scheduling, but it is bounded by the caller count.
    assert 1 <= calls <= caller_count
    cached_entry = mcp_http_auth._api_key_validation_cache[mcp_http_auth.cache_key(api_key)]
    assert cached_entry[0].organization_id == "org_concurrent"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_validate_mcp_api_key_returns_503_after_retry_exhaustion(monkeypatch: pytest.MonkeyPatch) -> None:
    """A resolver that always fails is tried once plus each retry, then yields a 503."""
    retries = 2
    attempts = 0

    async def _resolve(_api_key: str, _db: object, **_: object) -> object:
        nonlocal attempts
        attempts += 1
        raise RuntimeError("persistent db outage")

    monkeypatch.setattr(mcp_http_auth, "_MAX_VALIDATION_RETRIES", retries)
    monkeypatch.setattr(mcp_http_auth, "resolve_org_from_api_key", _resolve)
    _stub_auth_db(monkeypatch, object())

    with pytest.raises(HTTPException, match="temporarily unavailable") as exc_info:
        await mcp_http_auth.validate_mcp_api_key("sk_test_transient_exhausted")

    assert exc_info.value.status_code == 503
    assert attempts == retries + 1  # initial attempt + each retry
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_close_auth_db_disposes_engine() -> None:
    """``close_auth_db`` disposes the engine, drops the DB handle and empties the cache."""
    dispose = AsyncMock()
    engine = SimpleNamespace(dispose=dispose)
    mcp_http_auth._auth_db = SimpleNamespace(engine=engine)
    mcp_http_auth._api_key_validation_cache["k"] = ("org", 123.0)

    await mcp_http_auth.close_auth_db()

    dispose.assert_awaited_once()
    assert mcp_http_auth._auth_db is None
    assert mcp_http_auth._api_key_validation_cache == {}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_close_auth_db_noop_when_uninitialized() -> None:
    """Closing when no DB was ever built is a harmless no-op."""
    mcp_http_auth._auth_db = None

    await mcp_http_auth.close_auth_db()

    assert mcp_http_auth._auth_db is None
|