mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
Add browser-based CLI signup flow (skyvern signup) (#4925)
This commit is contained in:
parent
d9b1c6a0ce
commit
2e663beec0
5 changed files with 353 additions and 17 deletions
178
skyvern/cli/auth_command.py
Normal file
178
skyvern/cli/auth_command.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import secrets
|
||||
import socket
|
||||
import threading
|
||||
import urllib.parse
|
||||
import webbrowser
|
||||
|
||||
import typer
|
||||
|
||||
from .console import console
|
||||
|
||||
_DEFAULT_FRONTEND_URL = "https://app.skyvern.com"
|
||||
_CALLBACK_TIMEOUT = 300
|
||||
|
||||
_SUCCESS_HTML = """\
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<body style="font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #0a0a0a; color: #fafafa;">
|
||||
<div style="text-align: center;">
|
||||
<h1>Signup Successful</h1>
|
||||
<p>You can close this tab and return to your terminal.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
def _find_free_port() -> socket.socket:
|
||||
"""Bind to a free port and return the socket (kept open to prevent TOCTOU race)."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s
|
||||
|
||||
|
||||
def _derive_api_base_url(frontend_url: str) -> str:
|
||||
"""Derive the API base URL from the frontend URL.
|
||||
|
||||
app.skyvern.com -> https://api.skyvern.com
|
||||
localhost:8080 -> http://localhost:8000
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(frontend_url)
|
||||
hostname = parsed.hostname or ""
|
||||
if hostname in ("localhost", "127.0.0.1"):
|
||||
return "http://localhost:8000"
|
||||
if hostname.startswith("app."):
|
||||
new_host = "api." + hostname[4:]
|
||||
if parsed.port:
|
||||
new_host = f"{new_host}:{parsed.port}"
|
||||
return urllib.parse.urlunparse(parsed._replace(netloc=new_host))
|
||||
console.print(
|
||||
f"[yellow]Could not derive API base URL from '{frontend_url}'. "
|
||||
f"You may need to set SKYVERN_BASE_URL manually in .env.[/yellow]"
|
||||
)
|
||||
return frontend_url
|
||||
|
||||
|
||||
class _CallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
"""HTTP handler that captures the auth callback via form POST.
|
||||
|
||||
The frontend submits a hidden HTML form (browser navigation, not fetch),
|
||||
which avoids CORS / Private Network Access issues entirely.
|
||||
State is stored on the server instance (self.server).
|
||||
"""
|
||||
|
||||
def do_POST(self) -> None:
|
||||
parsed = urllib.parse.urlparse(self.path)
|
||||
if parsed.path != "/callback":
|
||||
self.send_error(404)
|
||||
return
|
||||
|
||||
content_length = int(self.headers.get("Content-Length", 0))
|
||||
body = self.rfile.read(content_length).decode("utf-8")
|
||||
data = urllib.parse.parse_qs(body)
|
||||
|
||||
# Validate state nonce to prevent CSRF
|
||||
expected_state = getattr(self.server, "expected_state", None)
|
||||
state_values = data.get("state", [])
|
||||
if not state_values or state_values[0] != expected_state:
|
||||
self.send_error(403, "Invalid state parameter")
|
||||
return
|
||||
|
||||
api_key_values = data.get("api_key", [])
|
||||
if not api_key_values or not api_key_values[0]:
|
||||
self.send_error(400, "Missing api_key")
|
||||
return
|
||||
|
||||
org_id_values = data.get("organization_id", [])
|
||||
email_values = data.get("email", [])
|
||||
|
||||
self.server.auth_result = { # type: ignore[attr-defined]
|
||||
"api_key": api_key_values[0],
|
||||
"organization_id": org_id_values[0] if org_id_values else None,
|
||||
"email": email_values[0] if email_values else None,
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(_SUCCESS_HTML.encode())
|
||||
|
||||
self.server.received_event.set() # type: ignore[attr-defined]
|
||||
|
||||
def log_message(self, format: str, *args: object) -> None:
|
||||
pass # suppress default HTTP server logs
|
||||
|
||||
|
||||
def run_signup(
|
||||
base_url: str = _DEFAULT_FRONTEND_URL,
|
||||
timeout: int = _CALLBACK_TIMEOUT,
|
||||
) -> None:
|
||||
"""Core signup logic. Called by both the Typer command and init_command."""
|
||||
from .llm_setup import update_or_add_env_var
|
||||
|
||||
bound_socket = _find_free_port()
|
||||
port = bound_socket.getsockname()[1]
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
server = http.server.HTTPServer(("127.0.0.1", port), _CallbackHandler, bind_and_activate=False)
|
||||
server.socket = bound_socket
|
||||
server.server_activate()
|
||||
server.auth_result = {"api_key": None, "organization_id": None, "email": None} # type: ignore[attr-defined]
|
||||
server.received_event = threading.Event() # type: ignore[attr-defined]
|
||||
server.expected_state = state # type: ignore[attr-defined]
|
||||
|
||||
server_thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
try:
|
||||
auth_url = f"{base_url.rstrip('/')}/cli-auth?port={port}&state={state}"
|
||||
console.print("Opening browser for Skyvern signup...")
|
||||
console.print(f"If the browser doesn't open, visit: [link]{auth_url}[/link]")
|
||||
webbrowser.open(auth_url)
|
||||
|
||||
if not server.received_event.wait(timeout=timeout): # type: ignore[attr-defined]
|
||||
console.print("[red]Signup timed out. Please try again.[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
result = server.auth_result # type: ignore[attr-defined]
|
||||
api_key = result["api_key"]
|
||||
organization_id = result["organization_id"]
|
||||
email = result["email"]
|
||||
|
||||
if not api_key:
|
||||
console.print("[red]Failed to receive API key. Please try again.[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
api_base_url = _derive_api_base_url(base_url)
|
||||
|
||||
update_or_add_env_var("SKYVERN_API_KEY", api_key)
|
||||
update_or_add_env_var("SKYVERN_BASE_URL", api_base_url)
|
||||
|
||||
console.print("\n[bold green]Signup successful![/bold green]")
|
||||
if email:
|
||||
console.print(f"Email: {email}")
|
||||
if organization_id:
|
||||
console.print(f"Organization: {organization_id}")
|
||||
console.print("API key saved to .env")
|
||||
console.print(f"Base URL: {api_base_url}")
|
||||
|
||||
|
||||
def signup(
|
||||
base_url: str = typer.Option(
|
||||
_DEFAULT_FRONTEND_URL,
|
||||
"--base-url",
|
||||
help="Frontend URL (e.g. http://localhost:8080 for local dev)",
|
||||
),
|
||||
timeout: int = typer.Option(
|
||||
_CALLBACK_TIMEOUT,
|
||||
"--timeout",
|
||||
help="Timeout in seconds waiting for browser signup",
|
||||
),
|
||||
) -> None:
|
||||
"""Sign up for Skyvern Cloud and save your API key."""
|
||||
run_signup(base_url=base_url, timeout=timeout)
|
||||
|
|
@ -6,6 +6,7 @@ from dotenv import load_dotenv
|
|||
from skyvern.forge.sdk.forge_log import setup_logger as _setup_logger
|
||||
from skyvern.utils.env_paths import resolve_backend_env_path
|
||||
|
||||
from ..auth_command import signup as signup_command
|
||||
from ..block import block_app
|
||||
from ..credential import credential_app
|
||||
from ..credentials import credentials_app
|
||||
|
|
@ -82,6 +83,8 @@ cli_app.add_typer(
|
|||
quickstart_app, name="quickstart", help="One-command setup and start for Skyvern (combines init and run)."
|
||||
)
|
||||
|
||||
cli_app.command(name="signup")(signup_command)
|
||||
|
||||
# Browser automation commands
|
||||
cli_app.add_typer(browser_app, name="browser", help="Browser automation commands.")
|
||||
cli_app.add_typer(skill_app, name="skill", help="Manage bundled skill reference files.")
|
||||
|
|
|
|||
|
|
@ -102,27 +102,47 @@ def init_env(
|
|||
)
|
||||
else:
|
||||
console.print(Panel("[bold purple]Cloud Deployment Setup[/bold purple]", border_style="purple"))
|
||||
base_url = Prompt.ask("Enter Skyvern base URL", default="https://api.skyvern.com", show_default=True)
|
||||
if not base_url:
|
||||
base_url = "https://api.skyvern.com"
|
||||
api_key = None
|
||||
|
||||
console.print("\n[bold]To get your API key:[/bold]")
|
||||
console.print("1. Create an account at [link]https://app.skyvern.com[/link]")
|
||||
console.print("2. Go to [bold cyan]Settings[/bold cyan]")
|
||||
console.print("3. [bold green]Copy your API key[/bold green]")
|
||||
api_key = Prompt.ask("Enter your Skyvern API key", password=True)
|
||||
if not api_key:
|
||||
console.print("[red]API key is required.[/red]")
|
||||
api_key = Prompt.ask("Please re-enter your Skyvern API key", password=True)
|
||||
auth_method = Prompt.ask(
|
||||
"Authenticate via [bold blue]browser[/bold blue] (recommended) or paste an [bold yellow]api-key[/bold yellow] manually?",
|
||||
choices=["browser", "api-key"],
|
||||
default="browser",
|
||||
)
|
||||
|
||||
if auth_method == "browser":
|
||||
from .auth_command import run_signup
|
||||
|
||||
frontend_url = Prompt.ask(
|
||||
"Frontend URL",
|
||||
default="https://app.skyvern.com",
|
||||
show_default=True,
|
||||
)
|
||||
run_signup(base_url=frontend_url)
|
||||
api_key = None # already saved by browser_auth
|
||||
else:
|
||||
base_url = Prompt.ask("Enter Skyvern base URL", default="https://api.skyvern.com", show_default=True)
|
||||
if not base_url:
|
||||
base_url = "https://api.skyvern.com"
|
||||
|
||||
console.print("\n[bold]To get your API key:[/bold]")
|
||||
console.print("1. Create an account at [link]https://app.skyvern.com[/link]")
|
||||
console.print("2. Go to [bold cyan]Settings[/bold cyan]")
|
||||
console.print("3. [bold green]Copy your API key[/bold green]")
|
||||
api_key = Prompt.ask("Enter your Skyvern API key", password=True)
|
||||
if not api_key:
|
||||
console.print("[bold red]Error: API key cannot be empty. Aborting initialization.[/bold red]")
|
||||
return False
|
||||
update_or_add_env_var("SKYVERN_BASE_URL", base_url)
|
||||
console.print("[red]API key is required.[/red]")
|
||||
api_key = Prompt.ask("Please re-enter your Skyvern API key", password=True)
|
||||
if not api_key:
|
||||
console.print("[bold red]Error: API key cannot be empty. Aborting initialization.[/bold red]")
|
||||
return False
|
||||
update_or_add_env_var("SKYVERN_BASE_URL", base_url)
|
||||
|
||||
analytics_id_input = Prompt.ask("Please enter your email for analytics (press enter to skip)", default="")
|
||||
analytics_id = analytics_id_input if analytics_id_input else str(uuid.uuid4())
|
||||
update_or_add_env_var("ANALYTICS_ID", analytics_id)
|
||||
update_or_add_env_var("SKYVERN_API_KEY", api_key)
|
||||
if api_key:
|
||||
update_or_add_env_var("SKYVERN_API_KEY", api_key)
|
||||
console.print(f"✅ [green]{resolve_backend_env_path()} file has been initialized.[/green]")
|
||||
|
||||
if Confirm.ask("\nWould you like to [bold yellow]configure the MCP server[/bold yellow]?", default=True):
|
||||
|
|
|
|||
|
|
@ -103,7 +103,13 @@ async def get_current_org_with_authentication(
|
|||
|
||||
|
||||
async def _authenticate_helper(authorization: str) -> Organization:
|
||||
token = authorization.split(" ")[1]
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) < 2 or not parts[1]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
token = parts[1]
|
||||
if not app.authentication_function:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
@ -158,7 +164,13 @@ async def get_current_user_id_with_authentication(
|
|||
|
||||
|
||||
async def _authenticate_user_helper(authorization: str) -> str:
|
||||
token = authorization.split(" ")[1]
|
||||
parts = authorization.split(" ", 1)
|
||||
if len(parts) < 2 or not parts[1]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Invalid credentials",
|
||||
)
|
||||
token = parts[1]
|
||||
if not app.authenticate_user_function:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
|
|
|||
123
tests/unit/test_cli_auth.py
Normal file
123
tests/unit/test_cli_auth.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
"""Tests for skyvern/cli/auth_command.py"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import threading
|
||||
import urllib.parse
|
||||
|
||||
from skyvern.cli.auth_command import _CallbackHandler, _derive_api_base_url, _find_free_port
|
||||
|
||||
|
||||
class TestDeriveApiBaseUrl:
|
||||
def test_localhost(self) -> None:
|
||||
assert _derive_api_base_url("http://localhost:8080") == "http://localhost:8000"
|
||||
|
||||
def test_localhost_no_port(self) -> None:
|
||||
assert _derive_api_base_url("http://localhost") == "http://localhost:8000"
|
||||
|
||||
def test_127_0_0_1(self) -> None:
|
||||
assert _derive_api_base_url("http://127.0.0.1:5173") == "http://localhost:8000"
|
||||
|
||||
def test_app_skyvern(self) -> None:
|
||||
assert _derive_api_base_url("https://app.skyvern.com") == "https://api.skyvern.com"
|
||||
|
||||
def test_app_skyvern_with_port(self) -> None:
|
||||
assert _derive_api_base_url("https://app.skyvern.com:8443") == "https://api.skyvern.com:8443"
|
||||
|
||||
def test_unknown_hostname_returns_input(self) -> None:
|
||||
result = _derive_api_base_url("https://staging.skyvern.com")
|
||||
assert result == "https://staging.skyvern.com"
|
||||
|
||||
|
||||
class TestFindFreePort:
|
||||
def test_returns_bound_socket(self) -> None:
|
||||
sock = _find_free_port()
|
||||
try:
|
||||
port = sock.getsockname()[1]
|
||||
assert 1024 <= port <= 65535
|
||||
# Socket should still be open (bound)
|
||||
assert sock.fileno() != -1
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
class TestCallbackHandlerStateValidation:
|
||||
def _make_server(self, state: str) -> http.server.HTTPServer:
|
||||
sock = _find_free_port()
|
||||
port = sock.getsockname()[1]
|
||||
server = http.server.HTTPServer(("127.0.0.1", port), _CallbackHandler, bind_and_activate=False)
|
||||
server.socket = sock
|
||||
server.server_activate()
|
||||
server.auth_result = {"api_key": None, "organization_id": None, "email": None} # type: ignore[attr-defined]
|
||||
server.received_event = threading.Event() # type: ignore[attr-defined]
|
||||
server.expected_state = state # type: ignore[attr-defined]
|
||||
return server
|
||||
|
||||
def test_valid_state_accepted(self) -> None:
|
||||
server = self._make_server("test-nonce-123")
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
port = server.server_address[1]
|
||||
try:
|
||||
import http.client
|
||||
|
||||
conn = http.client.HTTPConnection("127.0.0.1", port)
|
||||
body = urllib.parse.urlencode(
|
||||
{
|
||||
"api_key": "sk_test_key",
|
||||
"organization_id": "o_123",
|
||||
"email": "test@example.com",
|
||||
"state": "test-nonce-123",
|
||||
}
|
||||
)
|
||||
conn.request("POST", "/callback", body=body, headers={"Content-Type": "application/x-www-form-urlencoded"})
|
||||
resp = conn.getresponse()
|
||||
assert resp.status == 200
|
||||
assert server.auth_result["api_key"] == "sk_test_key" # type: ignore[attr-defined]
|
||||
assert server.auth_result["email"] == "test@example.com" # type: ignore[attr-defined]
|
||||
assert server.received_event.wait(timeout=5) # type: ignore[attr-defined]
|
||||
conn.close()
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
def test_invalid_state_rejected(self) -> None:
|
||||
server = self._make_server("correct-nonce")
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
port = server.server_address[1]
|
||||
try:
|
||||
import http.client
|
||||
|
||||
conn = http.client.HTTPConnection("127.0.0.1", port)
|
||||
body = urllib.parse.urlencode(
|
||||
{
|
||||
"api_key": "sk_test_key",
|
||||
"state": "wrong-nonce",
|
||||
}
|
||||
)
|
||||
conn.request("POST", "/callback", body=body, headers={"Content-Type": "application/x-www-form-urlencoded"})
|
||||
resp = conn.getresponse()
|
||||
assert resp.status == 403
|
||||
assert server.auth_result["api_key"] is None # type: ignore[attr-defined]
|
||||
assert not server.received_event.is_set() # type: ignore[attr-defined]
|
||||
conn.close()
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
def test_missing_api_key_rejected(self) -> None:
|
||||
server = self._make_server("test-nonce")
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
port = server.server_address[1]
|
||||
try:
|
||||
import http.client
|
||||
|
||||
conn = http.client.HTTPConnection("127.0.0.1", port)
|
||||
body = urllib.parse.urlencode({"state": "test-nonce"})
|
||||
conn.request("POST", "/callback", body=body, headers={"Content-Type": "application/x-www-form-urlencoded"})
|
||||
resp = conn.getresponse()
|
||||
assert resp.status == 400
|
||||
conn.close()
|
||||
finally:
|
||||
server.shutdown()
|
||||
Loading…
Add table
Add a link
Reference in a new issue