feat(mcp): [Required for Anthropic Claude Connectors Listing] add tool titles, Origin validation, and response size cap (#5608)

This commit is contained in:
Marc Kelechava 2026-04-22 17:23:59 -07:00 committed by GitHub
parent d12e773b9b
commit 4187a4a6ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1101 additions and 95 deletions

View file

@ -0,0 +1,210 @@
"""Unit tests for the MCP Origin-validation middleware."""
from __future__ import annotations
import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from skyvern.cli.mcp_tools.origin_middleware import (
OriginValidationMiddleware,
_sanitize_origin_for_log,
is_allowed_origin,
)
async def _ok_handler(request: Request) -> JSONResponse:
return JSONResponse({"ok": True})
@pytest.fixture()
def client() -> TestClient:
# Starlette app with OriginValidationMiddleware wrapped around it.
async def app_factory(scope, receive, send): # type: ignore[no-untyped-def]
inner = Starlette(routes=[Route("/mcp/", _ok_handler, methods=["GET", "POST"])])
middleware = OriginValidationMiddleware(inner)
await middleware(scope, receive, send)
return TestClient(app_factory)
# -- is_allowed_origin ------------------------------------------------------
def test_is_allowed_origin_absent_is_allowed() -> None:
assert is_allowed_origin(None) is True
assert is_allowed_origin("") is True
def test_is_allowed_origin_claude_ai_allowed() -> None:
assert is_allowed_origin("https://claude.ai") is True
assert is_allowed_origin("https://www.claude.ai") is True
def test_is_allowed_origin_claude_com_allowed() -> None:
# The Connectors Directory is served from claude.com; submissions must not
# 403 out of the gate when the user installs the connector from there.
assert is_allowed_origin("https://claude.com") is True
assert is_allowed_origin("https://www.claude.com") is True
def test_is_allowed_origin_anthropic_rejected() -> None:
# anthropic.com is marketing / docs, not an MCP client surface. Admitting
# it would widen CSRF surface without adding a legitimate flow.
assert is_allowed_origin("https://anthropic.com") is False
assert is_allowed_origin("https://www.anthropic.com") is False
assert is_allowed_origin("https://api.anthropic.com") is False
def test_is_allowed_origin_loopback_allowed() -> None:
assert is_allowed_origin("http://localhost:5173") is True
assert is_allowed_origin("http://127.0.0.1:9000") is True
assert is_allowed_origin("http://[::1]:3000") is True
def test_is_allowed_origin_random_is_rejected() -> None:
assert is_allowed_origin("https://evil.example") is False
assert is_allowed_origin("https://attacker.com") is False
def test_is_allowed_origin_prefix_spoofing_rejected() -> None:
# Subdomain claiming claude.ai must not pass (only hostname exact matches).
assert is_allowed_origin("https://claude.ai.attacker.com") is False
assert is_allowed_origin("https://not-claude.ai") is False
def test_is_allowed_origin_missing_host_rejected() -> None:
# Malformed origin with no host.
assert is_allowed_origin("https://") is False
def test_sanitize_origin_for_log_escapes_control_chars() -> None:
assert _sanitize_origin_for_log("https://evil.example\r\nforged") == "https://evil.example\\r\\nforged"
def test_sanitize_origin_for_log_truncates_long_values() -> None:
value = "https://evil.example/" + ("x" * 500)
sanitized = _sanitize_origin_for_log(value)
assert sanitized is not None
assert sanitized.endswith("... [truncated]")
assert len(sanitized) < len(value)
# -- middleware integration -------------------------------------------------
def test_middleware_allows_missing_origin(client: TestClient) -> None:
response = client.get("/mcp/")
assert response.status_code == 200
assert response.json() == {"ok": True}
def test_middleware_allows_claude_ai(client: TestClient) -> None:
response = client.post("/mcp/", headers={"origin": "https://claude.ai"})
assert response.status_code == 200
def test_middleware_allows_loopback(client: TestClient) -> None:
response = client.post("/mcp/", headers={"origin": "http://127.0.0.1:12345"})
assert response.status_code == 200
def test_middleware_rejects_unknown_origin(client: TestClient) -> None:
response = client.post("/mcp/", headers={"origin": "https://evil.example"})
assert response.status_code == 403
body = response.json()
assert body["error"] == "forbidden_origin"
# Static message — rejected origin is in the structured log, not reflected
# back in the response body.
assert body["detail"] == "Origin not allowed"
assert "evil.example" not in body["detail"]
def test_middleware_rejects_subdomain_spoof(client: TestClient) -> None:
response = client.post(
"/mcp/",
headers={"origin": "https://claude.ai.attacker.com"},
)
assert response.status_code == 403
# -- websocket scope handling ----------------------------------------------
@pytest.mark.asyncio
async def test_middleware_rejects_websocket_with_unknown_origin() -> None:
# FastMCP currently only mounts streamable-HTTP, but the middleware must
# also gate WebSocket handshakes in case that transport is ever enabled
# at `/mcp`. Rejection happens via `websocket.close` with code 1008
# (policy violation) before the app sees `websocket.accept`.
sent: list[dict] = []
async def _should_not_be_called(scope, receive, send): # type: ignore[no-untyped-def]
raise AssertionError("inner app must not receive rejected websocket scope")
middleware = OriginValidationMiddleware(_should_not_be_called)
scope = {
"type": "websocket",
"path": "/mcp/",
"headers": [(b"origin", b"https://evil.example")],
}
async def _receive() -> dict:
return {"type": "websocket.connect"}
async def _send(message: dict) -> None:
sent.append(message)
await middleware(scope, _receive, _send)
assert sent == [{"type": "websocket.close", "code": 1008}]
@pytest.mark.asyncio
async def test_middleware_allows_websocket_with_claude_ai_origin() -> None:
# An allowlisted Origin on a WebSocket scope passes through to the inner
# app so it can run the `websocket.accept` handshake itself.
called = False
async def _inner(scope, receive, send): # type: ignore[no-untyped-def]
nonlocal called
called = True
middleware = OriginValidationMiddleware(_inner)
scope = {
"type": "websocket",
"path": "/mcp/",
"headers": [(b"origin", b"https://claude.ai")],
}
async def _receive() -> dict:
return {"type": "websocket.connect"}
async def _send(message: dict) -> None: # pragma: no cover — inner no-op
raise AssertionError("inner send should not fire in this fake")
await middleware(scope, _receive, _send)
assert called is True
@pytest.mark.asyncio
async def test_middleware_passes_lifespan_scope_through() -> None:
# Lifespan scopes have no Origin header; the middleware must not try to
# 403 them or it breaks app startup/shutdown.
reached = False
async def _inner(scope, receive, send): # type: ignore[no-untyped-def]
nonlocal reached
reached = True
middleware = OriginValidationMiddleware(_inner)
await middleware({"type": "lifespan"}, lambda: None, lambda _m: None) # type: ignore[arg-type]
assert reached is True

View file

@ -0,0 +1,196 @@
"""Unit tests for MCP response size cap."""
from __future__ import annotations
import json
from typing import Any
import pytest
from skyvern.cli.mcp_tools.response import (
MCP_MAX_RESPONSE_CHARS,
size_capped,
truncate_response,
)
def test_truncate_response_passes_small_payload_unchanged() -> None:
small = {"ok": True, "data": {"items": list(range(10))}}
assert truncate_response(small) is small
def test_truncate_response_wraps_large_payload_with_envelope() -> None:
# Construct a payload larger than the default cap.
big_payload = "x" * (MCP_MAX_RESPONSE_CHARS + 100)
large = {"ok": True, "data": {"body": big_payload}}
result = truncate_response(large)
assert result is not large
assert result["_truncated"] is True
assert result["_max_chars"] == MCP_MAX_RESPONSE_CHARS
assert result["_original_chars"] > MCP_MAX_RESPONSE_CHARS
assert "Narrow the query" in result["_hint"]
# Original top-level `ok` is preserved so callers reading .ok still work.
assert result["ok"] is True
# The oversized payload itself is dropped.
assert "data" not in result
def test_truncate_response_preserves_top_level_error_on_overflow() -> None:
large = {
"ok": False,
"error": {"code": "TIMEOUT", "message": "page did not load"},
# Padding to push the total over the cap.
"debug": "y" * (MCP_MAX_RESPONSE_CHARS + 50),
}
result = truncate_response(large)
assert result["_truncated"] is True
assert result["ok"] is False
assert result["error"] == {"code": "TIMEOUT", "message": "page did not load"}
def test_truncate_response_preserves_identifier_fields_on_overflow() -> None:
# A tool that returns identifier fields alongside a bulky payload should
# retain those identifiers in the envelope so the caller can re-query.
large = {
"ok": True,
"workflow_id": "wpid_abc123",
"run_id": "wr_xyz789",
"session_id": "pbs_qqq000",
"timestamp": "ignored",
"count": 12345,
"data": {"blob": "z" * (MCP_MAX_RESPONSE_CHARS + 500)},
}
result = truncate_response(large)
assert result["_truncated"] is True
assert result["ok"] is True
assert result["workflow_id"] == "wpid_abc123"
assert result["run_id"] == "wr_xyz789"
assert result["session_id"] == "pbs_qqq000"
# Keys that do not end with `_id` are not preserved.
assert "timestamp" not in result
assert "count" not in result
# The oversized payload itself is dropped.
assert "data" not in result
def test_truncate_response_caps_oversize_error_field() -> None:
# Pathological input: the `error` field itself is bigger than the cap
# (e.g. a full HTML dump or stack trace serialized into `error.message`).
# Without bounding, copying it verbatim into the envelope would blow the
# envelope past max_chars and break the "under cap" contract.
large_error_message = "x" * (MCP_MAX_RESPONSE_CHARS + 500)
large = {
"ok": False,
"error": {"code": "INTERNAL", "message": large_error_message},
"data": {"n": 1},
}
result = truncate_response(large)
assert result["_truncated"] is True
assert result["ok"] is False
# The oversized error payload is replaced with a structured placeholder,
# not copied verbatim.
assert result["error"] != large["error"]
assert isinstance(result["error"], dict)
assert "_original_error_chars" in result["error"]
assert result["error"]["_error_preview"].endswith("... [truncated]")
# Envelope itself stays under the cap (module contract).
assert len(json.dumps(result, ensure_ascii=False)) <= MCP_MAX_RESPONSE_CHARS
def test_truncate_response_drops_oversize_identifier_values() -> None:
# An identifier value that itself exceeds the per-value cap is dropped so
# the envelope cannot be re-inflated past the overall limit.
large = {
"ok": True,
"short_id": "abc",
"huge_id": "x" * 10_000,
"data": "y" * (MCP_MAX_RESPONSE_CHARS + 100),
}
result = truncate_response(large)
assert result["_truncated"] is True
assert result["short_id"] == "abc"
assert "huge_id" not in result
def test_truncate_response_accepts_custom_max() -> None:
payload = {"data": "z" * 200}
# payload JSON is ~213 chars; cap at 100 forces truncation.
result = truncate_response(payload, max_chars=100)
assert result["_truncated"] is True
assert result["_max_chars"] == 100
def test_truncate_response_non_dict_overflow_wraps_into_envelope() -> None:
# A tool that returns a raw list (unusual but legal) should still be guarded.
big_list = ["x" * 100] * 2000
result = truncate_response(big_list)
assert isinstance(result, dict)
assert result["_truncated"] is True
assert "ok" not in result
def test_truncate_response_unserializable_input_returned_as_is() -> None:
# object() is not JSON-serializable; json.dumps(..., default=str) stringifies
# it, so the helper returns the payload unchanged (size is small).
sentinel: dict[str, Any] = {"x": object()}
result = truncate_response(sentinel)
assert result is sentinel
def test_truncate_response_serialization_failure_is_fail_closed() -> None:
# Circular references make json.dumps raise ValueError. A size cap that
# can't measure a payload must fail CLOSED (wrap in the truncation
# envelope) rather than passing the unmeasurable payload through.
import sys
circular: dict[str, Any] = {"ok": True, "error": None}
circular["self"] = circular
result = truncate_response(circular)
assert result is not circular
assert result["_truncated"] is True
# Sentinel: unmeasurable payloads report `sys.maxsize` for
# `_original_chars`. Locks in the fail-closed contract so an accidental
# change (e.g. returning 0 or None on serialization error) trips here.
assert result["_original_chars"] == sys.maxsize
# Top-level `ok` / `error` are still preserved from the original dict so
# callers reading those fields continue to work.
assert result["ok"] is True
assert result["error"] is None
@pytest.mark.asyncio
async def test_size_capped_decorator_no_op_for_small_result() -> None:
@size_capped
async def small_tool() -> dict[str, Any]:
return {"ok": True, "data": {"n": 1}}
result = await small_tool()
assert result == {"ok": True, "data": {"n": 1}}
@pytest.mark.asyncio
async def test_size_capped_decorator_wraps_oversize_result() -> None:
@size_capped
async def big_tool() -> dict[str, Any]:
return {"ok": True, "data": {"blob": "q" * (MCP_MAX_RESPONSE_CHARS + 500)}}
result = await big_tool()
assert result["_truncated"] is True
assert result["ok"] is True
# Re-serializing the wrapped envelope must be under the cap.
assert len(json.dumps(result, ensure_ascii=False)) <= MCP_MAX_RESPONSE_CHARS
@pytest.mark.asyncio
async def test_size_capped_decorator_preserves_signature() -> None:
@size_capped
async def typed_tool(x: int, y: str = "default") -> dict[str, Any]:
return {"x": x, "y": y}
result = await typed_tool(1, y="override")
assert result == {"x": 1, "y": "override"}
assert typed_tool.__name__ == "typed_tool"

View file

@ -0,0 +1,79 @@
"""Registry-level invariant: every registered MCP tool must carry a human-readable title.
The Claude Connectors Directory submission form rejects servers whose tools
are missing `title` in `ToolAnnotations` (the raw snake_case function name is
not user-facing).
"""
from __future__ import annotations
import pytest
from skyvern.cli.mcp_tools import mcp
@pytest.mark.asyncio
async def test_every_tool_has_a_title() -> None:
tools = await mcp.list_tools()
assert tools, "MCP server registered zero tools"
missing = [t.name for t in tools if t.annotations is None or not t.annotations.title]
assert not missing, f"Tools missing title annotation: {missing}"
@pytest.mark.asyncio
async def test_destructive_tools_flagged() -> None:
"""Tools that delete / close / cancel must carry destructiveHint=True."""
tools = await mcp.list_tools()
by_name = {t.name: t for t in tools}
# A representative sample — extending this set is fine, but none of the
# listed tools should silently lose their destructive annotation. The
# three AI-driven / eval tools (`skyvern_act`, `skyvern_run_task`,
# `skyvern_evaluate`) are included because a user-supplied prompt or
# JavaScript expression can mutate the page destructively; the
# `destructiveHint` tells the client's consent surface so.
expected_destructive = {
"skyvern_browser_session_close",
"skyvern_tab_close",
"skyvern_clear_session_storage",
"skyvern_clear_local_storage",
"skyvern_credential_delete",
"skyvern_folder_delete",
"skyvern_workflow_delete",
"skyvern_workflow_cancel",
"skyvern_act",
"skyvern_run_task",
"skyvern_evaluate",
}
for name in expected_destructive:
tool = by_name.get(name)
assert tool is not None, f"Expected tool not registered: {name}"
assert tool.annotations is not None, f"Tool missing annotations: {name}"
assert tool.annotations.destructiveHint is True, f"Tool {name} expected destructiveHint=True"
@pytest.mark.asyncio
async def test_read_only_sampling_marked_read_only() -> None:
"""Sanity check that known read-only tools keep readOnlyHint=True."""
tools = await mcp.list_tools()
by_name = {t.name: t for t in tools}
expected_ro = {
"skyvern_browser_session_list",
"skyvern_browser_session_get",
"skyvern_extract",
"skyvern_validate",
"skyvern_screenshot",
"skyvern_find",
"skyvern_get_html",
"skyvern_workflow_list",
"skyvern_workflow_get",
}
for name in expected_ro:
tool = by_name.get(name)
assert tool is not None, f"Expected tool not registered: {name}"
assert tool.annotations is not None, f"Tool missing annotations: {name}"
assert tool.annotations.readOnlyHint is True, f"Tool {name} expected readOnlyHint=True"

View file

@ -28,6 +28,15 @@ def test_mcp_instructions_guide_text_prompt_defaults() -> None:
assert "skyvern_observe" in mcp.instructions
@pytest.mark.asyncio
async def test_expected_prompts_registered() -> None:
prompts = await mcp.list_prompts()
prompt_names = {prompt.name for prompt in prompts}
# Deliberately additive-only: this guards that core prompts remain
# registered without breaking when new prompts are introduced later.
assert {"build_workflow", "debug_automation", "extract_data", "qa_test"} <= prompt_names
@pytest.mark.asyncio
async def test_text_prompt_block_schema_example_omits_raw_llm_key() -> None:
result = await skyvern_block_schema(block_type="text_prompt")

View file

@ -6,6 +6,7 @@ from pathlib import Path
import pytest
from skyvern.cli.core.session_manager import set_stateless_http_mode
from skyvern.cli.mcp_tools.prompts import QA_TEST_CONTENT, qa_test
from tests.unit.skill_test_helpers import first_nonempty_line_after_h1
@ -83,6 +84,20 @@ def test_qa_test_prompt_includes_target_url_and_focus_area() -> None:
assert "choose the correct validation mode" in rendered
def test_qa_test_prompt_stateless_http_omits_local_shell_and_filesystem_steps() -> None:
set_stateless_http_mode(True)
try:
rendered = qa_test()
finally:
set_stateless_http_mode(False)
assert ".qa/latest-report.md" not in rendered
assert "gh pr comment" not in rendered
assert "git diff --name-only HEAD~1" not in rendered
assert "local shell, git,\nfilesystem, or `gh` access" in rendered
assert "writing a local report file" in rendered
def test_qa_pr_evidence_markers_present() -> None:
"""Assert the PR evidence posting instructions are present in all /qa surfaces."""
skill_text = BUNDLED_QA_SKILL.read_text(encoding="utf-8")