mirror of
https://github.com/block/goose.git
synced 2026-04-28 03:29:36 +00:00
Use RMCP for StreamableHTTP OAuth support (#3845)
This commit is contained in:
parent
ee450254b6
commit
6b93260fd0
13 changed files with 221 additions and 553 deletions
10
Cargo.lock
generated
10
Cargo.lock
generated
|
|
@ -6988,14 +6988,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rmcp"
|
||||
version = "0.3.1"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "824daba0a34f8c5c5392295d381e0800f88fd986ba291699f8785f05fa344c1e"
|
||||
checksum = "9a04b2d9da174d72bc0511410242eb3e8f47a02d9e23868e4b076c7c70208eb4"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"futures",
|
||||
"http 1.2.0",
|
||||
"oauth2",
|
||||
"paste",
|
||||
"pin-project-lite",
|
||||
"process-wrap",
|
||||
|
|
@ -7010,13 +7011,14 @@ dependencies = [
|
|||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rmcp-macros"
|
||||
version = "0.3.1"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad6543c0572a4dbc125c23e6f54963ea9ba002294fd81dd4012c204219b0dcaa"
|
||||
checksum = "651b4292f292d4a4ba191e64f0ce553aef5455b8ddf831dc6a8cd328dd68d721"
|
||||
dependencies = [
|
||||
"darling 0.21.0",
|
||||
"proc-macro2",
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ description = "An AI agent"
|
|||
uninlined_format_args = "allow"
|
||||
|
||||
[workspace.dependencies]
|
||||
rmcp = { version = "0.3.1", features = ["schemars"] }
|
||||
rmcp = { version = "0.4.0", features = ["schemars", "auth"] }
|
||||
|
||||
# Patch for Windows cross-compilation issue with crunchy
|
||||
[patch.crates-io]
|
||||
|
|
|
|||
|
|
@ -96,8 +96,9 @@ impl McpClientTrait for MockClient {
|
|||
if let Some(handler) = self.handlers.get(name) {
|
||||
match handler(&arguments) {
|
||||
Ok(content) => Ok(CallToolResult {
|
||||
content,
|
||||
content: Some(content),
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
}),
|
||||
Err(e) => Err(Error::UnexpectedResponse),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ use anyhow::Result;
|
|||
use base64::Engine;
|
||||
use etcetera::{choose_app_strategy, AppStrategy};
|
||||
use indoc::formatdoc;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
|
|
@ -33,7 +34,7 @@ use mcp_server::Router;
|
|||
|
||||
use rmcp::model::{
|
||||
Content, JsonRpcMessage, JsonRpcNotification, JsonRpcVersion2_0, Notification, Prompt,
|
||||
PromptArgument, PromptTemplate, Resource, Role, Tool, ToolAnnotations,
|
||||
PromptArgument, Resource, Role, Tool, ToolAnnotations,
|
||||
};
|
||||
use rmcp::object;
|
||||
|
||||
|
|
@ -46,6 +47,20 @@ use xcap::{Monitor, Window};
|
|||
|
||||
use ignore::gitignore::{Gitignore, GitignoreBuilder};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PromptTemplate {
|
||||
pub id: String,
|
||||
pub template: String,
|
||||
pub arguments: Vec<PromptArgumentTemplate>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PromptArgumentTemplate {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub required: Option<bool>,
|
||||
}
|
||||
|
||||
// Embeds the prompts directory to the build
|
||||
static PROMPTS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/src/developer/prompts");
|
||||
const LINE_READ_LIMIT: usize = 2000;
|
||||
|
|
|
|||
|
|
@ -27,9 +27,11 @@ use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, Extension
|
|||
use super::tool_execution::ToolCallResult;
|
||||
use crate::agents::extension::{Envs, ProcessExit};
|
||||
use crate::config::{Config, ExtensionConfigManager};
|
||||
use crate::oauth::oauth_flow;
|
||||
use crate::prompt_template;
|
||||
use mcp_client::client::{McpClient, McpClientTrait};
|
||||
use rmcp::model::{Content, GetPromptResult, Prompt, ResourceContents, Tool};
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use serde_json::Value;
|
||||
|
||||
type McpClientBox = Arc<Mutex<Box<dyn McpClientTrait>>>;
|
||||
|
|
@ -205,6 +207,7 @@ impl ExtensionManager {
|
|||
uri,
|
||||
timeout,
|
||||
headers,
|
||||
name,
|
||||
..
|
||||
} => {
|
||||
let mut default_headers = HeaderMap::new();
|
||||
|
|
@ -231,13 +234,38 @@ impl ExtensionManager {
|
|||
..Default::default()
|
||||
},
|
||||
);
|
||||
let client = McpClient::connect(
|
||||
let client_res = McpClient::connect(
|
||||
transport,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
.await;
|
||||
let client = if let Err(e) = client_res {
|
||||
// make an attempt at oauth, but failing that, return the original error,
|
||||
// because this might not have been an auth error at all
|
||||
let am = match oauth_flow(uri, name).await {
|
||||
Ok(am) => am,
|
||||
Err(_) => return Err(e.into()),
|
||||
};
|
||||
let client = AuthClient::new(reqwest::Client::default(), am);
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
client,
|
||||
StreamableHttpClientTransportConfig {
|
||||
uri: uri.clone().into(),
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
McpClient::connect(
|
||||
transport,
|
||||
Duration::from_secs(
|
||||
timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT),
|
||||
),
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
client_res?
|
||||
};
|
||||
Box::new(client)
|
||||
}
|
||||
ExtensionConfig::Stdio {
|
||||
|
|
@ -463,6 +491,7 @@ impl ExtensionManager {
|
|||
description: tool.description,
|
||||
input_schema: tool.input_schema,
|
||||
annotations: tool.annotations,
|
||||
output_schema: tool.output_schema,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -719,7 +748,7 @@ impl ExtensionManager {
|
|||
client_guard
|
||||
.call_tool(&tool_name, arguments, cancellation_token)
|
||||
.await
|
||||
.map(|call| call.content)
|
||||
.map(|call| call.content.unwrap_or_default())
|
||||
.map_err(|e| ToolError::ExecutionError(e.to_string()))
|
||||
};
|
||||
|
||||
|
|
@ -947,8 +976,9 @@ mod tests {
|
|||
) -> Result<CallToolResult, Error> {
|
||||
match name {
|
||||
"tool" | "test__tool" => Ok(CallToolResult {
|
||||
content: vec![],
|
||||
content: Some(vec![]),
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
}),
|
||||
_ => Err(Error::TransportClosed),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ pub mod context_mgmt;
|
|||
mod conversation_fixer;
|
||||
pub mod message;
|
||||
pub mod model;
|
||||
pub mod oauth;
|
||||
pub mod permission;
|
||||
pub mod project;
|
||||
pub mod prompt_template;
|
||||
|
|
|
|||
81
crates/goose/src/oauth.rs
Normal file
81
crates/goose/src/oauth.rs
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
use axum::extract::{Query, State};
|
||||
use axum::response::Html;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use minijinja::render;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::AuthorizationManager;
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
|
||||
const CALLBACK_TEMPLATE: &str = include_str!("oauth_callback.html");
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AppState {
|
||||
code_receiver: Arc<Mutex<Option<oneshot::Sender<String>>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CallbackParams {
|
||||
code: String,
|
||||
#[allow(dead_code)]
|
||||
state: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn oauth_flow(
|
||||
mcp_server_url: &String,
|
||||
name: &String,
|
||||
) -> Result<AuthorizationManager, anyhow::Error> {
|
||||
let (code_sender, code_receiver) = oneshot::channel::<String>();
|
||||
let app_state = AppState {
|
||||
code_receiver: Arc::new(Mutex::new(Some(code_sender))),
|
||||
};
|
||||
|
||||
let rendered = render!(CALLBACK_TEMPLATE, name => name);
|
||||
let handler = move |Query(params): Query<CallbackParams>, State(state): State<AppState>| {
|
||||
let rendered = rendered.clone();
|
||||
async move {
|
||||
if let Some(sender) = state.code_receiver.lock().await.take() {
|
||||
let _ = sender.send(params.code);
|
||||
}
|
||||
Html(rendered)
|
||||
}
|
||||
};
|
||||
let app = Router::new()
|
||||
.route("/oauth_callback", get(handler))
|
||||
.with_state(app_state);
|
||||
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
let used_addr = listener.local_addr()?;
|
||||
tokio::spawn(async move {
|
||||
let result = axum::serve(listener, app).await;
|
||||
|
||||
if let Err(e) = result {
|
||||
eprintln!("Callback server error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let mut oauth_state = OAuthState::new(mcp_server_url, None).await?;
|
||||
let redirect_uri = format!("http://localhost:{}/oauth_callback", used_addr.port());
|
||||
oauth_state
|
||||
.start_authorization(&[], redirect_uri.as_str())
|
||||
.await?;
|
||||
|
||||
let authorization_url = oauth_state.get_authorization_url().await?;
|
||||
if webbrowser::open(authorization_url.as_str()).is_err() {
|
||||
eprintln!("Open the following URL to authorize {}:", name);
|
||||
eprintln!(" {}", authorization_url);
|
||||
}
|
||||
|
||||
let auth_code = code_receiver.await?;
|
||||
oauth_state.handle_callback(&auth_code).await?;
|
||||
|
||||
let am = oauth_state
|
||||
.into_authorization_manager()
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to get authorization manager"))?;
|
||||
|
||||
Ok(am)
|
||||
}
|
||||
73
crates/goose/src/oauth_callback.html
Normal file
73
crates/goose/src/oauth_callback.html
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>{{ name }} OAuth Success</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: "Cash Sans", -apple-system, BlinkMacSystemFont, "Segoe UI",
|
||||
Roboto, sans-serif;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background-color: #f4f6f7;
|
||||
color: #3f434b;
|
||||
}
|
||||
|
||||
.container {
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
background: #ffffff;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0px 12px 32px 0px rgba(0, 0, 0, 0.04),
|
||||
0px 8px 16px 0px rgba(0, 0, 0, 0.02),
|
||||
0px 2px 4px 0px rgba(0, 0, 0, 0.04),
|
||||
0px 0px 1px 0px rgba(0, 0, 0, 0.2);
|
||||
max-width: 400px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #32353b;
|
||||
margin-bottom: 1rem;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.client-name {
|
||||
font-weight: 700;
|
||||
color: #22252a;
|
||||
}
|
||||
|
||||
button {
|
||||
background-color: #32353b;
|
||||
color: #ffffff;
|
||||
border: none;
|
||||
padding: 0.75rem 1.5rem;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
font-family: "Cash Sans", sans-serif;
|
||||
font-weight: 500;
|
||||
margin-top: 1rem;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
background-color: #22252a;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Authorization Success</h1>
|
||||
<p>
|
||||
You have successfully authorized
|
||||
<span class="client-name">{{ name }}</span>. You can now close this
|
||||
window and return to Goose.
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
|
|
@ -1,8 +1,3 @@
|
|||
pub mod client;
|
||||
pub mod oauth;
|
||||
|
||||
#[cfg(test)]
|
||||
mod oauth_tests;
|
||||
|
||||
pub use client::{Error, McpClient, McpClientTrait};
|
||||
pub use oauth::{authenticate_service, ServiceConfig};
|
||||
|
|
|
|||
|
|
@ -1,456 +0,0 @@
|
|||
use anyhow::Result;
|
||||
use axum::{extract::Query, response::Html, routing::get, Router};
|
||||
use base64::Engine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sha2::Digest;
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use tokio::sync::{oneshot, Mutex as TokioMutex};
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct OidcEndpoints {
|
||||
authorization_endpoint: String,
|
||||
token_endpoint: String,
|
||||
registration_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct TokenData {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct ClientRegistrationRequest {
|
||||
redirect_uris: Vec<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
grant_types: Vec<String>,
|
||||
response_types: Vec<String>,
|
||||
client_name: String,
|
||||
client_uri: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct ClientRegistrationResponse {
|
||||
client_id: String,
|
||||
client_id_issued_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
client_secret: Option<String>,
|
||||
}
|
||||
|
||||
/// OAuth configuration for any service
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceConfig {
|
||||
pub oauth_host: String,
|
||||
pub redirect_uri: String,
|
||||
pub client_name: String,
|
||||
pub client_uri: String,
|
||||
pub discovery_path: Option<String>,
|
||||
}
|
||||
|
||||
impl ServiceConfig {
|
||||
/// Create a generic OAuth configuration from an MCP endpoint URL
|
||||
/// Extracts the base URL for OAuth discovery
|
||||
pub fn from_mcp_endpoint(mcp_url: &str) -> Result<Self> {
|
||||
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||
let oauth_host = format!(
|
||||
"{}://{}{}",
|
||||
parsed_url.scheme(),
|
||||
parsed_url.host_str().ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||
})?,
|
||||
if let Some(port) = parsed_url.port() {
|
||||
format!(":{}", port)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
oauth_host,
|
||||
redirect_uri: "http://localhost:8020".to_string(),
|
||||
client_name: "Goose MCP Client".to_string(),
|
||||
client_uri: "https://github.com/block/goose".to_string(),
|
||||
discovery_path: None, // Use standard discovery
|
||||
})
|
||||
}
|
||||
|
||||
/// Create configuration with custom discovery path for non-standard services
|
||||
pub fn with_custom_discovery(mut self, discovery_path: String) -> Self {
|
||||
self.discovery_path = Some(discovery_path);
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the canonical resource URI for the MCP server
|
||||
/// This is used as the resource parameter in OAuth requests (RFC 8707)
|
||||
pub fn get_canonical_resource_uri(&self, mcp_url: &str) -> Result<String> {
|
||||
let parsed_url = Url::parse(mcp_url.trim())?;
|
||||
|
||||
// Build canonical URI: scheme://host[:port][/path]
|
||||
let mut canonical = format!(
|
||||
"{}://{}",
|
||||
parsed_url.scheme().to_lowercase(),
|
||||
parsed_url
|
||||
.host_str()
|
||||
.ok_or_else(|| {
|
||||
anyhow::anyhow!("Invalid MCP URL: no host found in {}", mcp_url)
|
||||
})?
|
||||
.to_lowercase()
|
||||
);
|
||||
|
||||
// Add port if not default
|
||||
if let Some(port) = parsed_url.port() {
|
||||
canonical.push_str(&format!(":{}", port));
|
||||
}
|
||||
|
||||
// Add path if present and not just "/"
|
||||
let path = parsed_url.path();
|
||||
if !path.is_empty() && path != "/" {
|
||||
canonical.push_str(path);
|
||||
}
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
}
|
||||
|
||||
struct OAuthFlow {
|
||||
endpoints: OidcEndpoints,
|
||||
client_id: String,
|
||||
redirect_url: String,
|
||||
state: String,
|
||||
verifier: String,
|
||||
}
|
||||
|
||||
impl OAuthFlow {
|
||||
fn new(endpoints: OidcEndpoints, client_id: String, redirect_url: String) -> Self {
|
||||
Self {
|
||||
endpoints,
|
||||
client_id,
|
||||
redirect_url,
|
||||
state: nanoid::nanoid!(16),
|
||||
verifier: nanoid::nanoid!(64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a dynamic client and return the client_id
|
||||
async fn register_client(endpoints: &OidcEndpoints, config: &ServiceConfig) -> Result<String> {
|
||||
let Some(registration_endpoint) = &endpoints.registration_endpoint else {
|
||||
return Err(anyhow::anyhow!("No registration endpoint available"));
|
||||
};
|
||||
|
||||
let registration_request = ClientRegistrationRequest {
|
||||
redirect_uris: vec![config.redirect_uri.clone()],
|
||||
token_endpoint_auth_method: "none".to_string(),
|
||||
grant_types: vec![
|
||||
"authorization_code".to_string(),
|
||||
"refresh_token".to_string(),
|
||||
],
|
||||
response_types: vec!["code".to_string()],
|
||||
client_name: config.client_name.clone(),
|
||||
client_uri: config.client_uri.clone(),
|
||||
};
|
||||
|
||||
tracing::info!("Registering dynamic client with OAuth server...");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(registration_endpoint)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(®istration_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to register client: {} - {}",
|
||||
status,
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let registration_response: ClientRegistrationResponse = resp.json().await?;
|
||||
|
||||
tracing::info!(
|
||||
"Client registered successfully with ID: {}",
|
||||
registration_response.client_id
|
||||
);
|
||||
Ok(registration_response.client_id)
|
||||
}
|
||||
|
||||
fn get_authorization_url(&self, resource: &str) -> String {
|
||||
let challenge = {
|
||||
let digest = sha2::Sha256::digest(self.verifier.as_bytes());
|
||||
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
|
||||
};
|
||||
|
||||
let params = [
|
||||
("response_type", "code"),
|
||||
("client_id", &self.client_id),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("state", &self.state),
|
||||
("code_challenge", &challenge),
|
||||
("code_challenge_method", "S256"),
|
||||
("resource", resource), // RFC 8707 Resource Parameter
|
||||
];
|
||||
|
||||
format!(
|
||||
"{}?{}",
|
||||
self.endpoints.authorization_endpoint,
|
||||
serde_urlencoded::to_string(params).unwrap()
|
||||
)
|
||||
}
|
||||
|
||||
async fn exchange_code_for_token(&self, code: &str, resource: &str) -> Result<TokenData> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", &self.redirect_url),
|
||||
("code_verifier", &self.verifier),
|
||||
("client_id", &self.client_id),
|
||||
("resource", resource), // RFC 8707 Resource Parameter
|
||||
];
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&self.endpoints.token_endpoint)
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err_text = resp.text().await?;
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to exchange code for token: {}",
|
||||
err_text
|
||||
));
|
||||
}
|
||||
|
||||
let token_response: Value = resp.json().await?;
|
||||
|
||||
let access_token = token_response
|
||||
.get("access_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("access_token not found in token response"))?
|
||||
.to_string();
|
||||
|
||||
let refresh_token = token_response
|
||||
.get("refresh_token")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Ok(TokenData {
|
||||
access_token,
|
||||
refresh_token,
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, resource: &str) -> Result<TokenData> {
|
||||
// Create a channel that will send the auth code from the callback
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let state = self.state.clone();
|
||||
let tx = Arc::new(TokioMutex::new(Some(tx)));
|
||||
|
||||
// Setup a server that will receive the redirect and capture the code
|
||||
let app = Router::new().route(
|
||||
"/",
|
||||
get(move |Query(params): Query<HashMap<String, String>>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let state = state.clone();
|
||||
async move {
|
||||
let code = params.get("code").cloned();
|
||||
let received_state = params.get("state").cloned();
|
||||
|
||||
if let (Some(code), Some(received_state)) = (code, received_state) {
|
||||
if received_state == state {
|
||||
if let Some(sender) = tx.lock().await.take() {
|
||||
if sender.send(code).is_ok() {
|
||||
return Html(
|
||||
"<h2>Authentication Successful!</h2><p>You can close this window and return to the application.</p>",
|
||||
);
|
||||
}
|
||||
}
|
||||
Html("<h2>Error</h2><p>Authentication already completed.</p>")
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>State mismatch - possible security issue.</p>")
|
||||
}
|
||||
} else {
|
||||
Html("<h2>Error</h2><p>Authentication failed - missing parameters.</p>")
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
// Start the callback server
|
||||
let redirect_url = Url::parse(&self.redirect_url)?;
|
||||
let port = redirect_url.port().unwrap_or(8020);
|
||||
let addr = SocketAddr::from(([127, 0, 0, 1], port));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let server = axum::serve(listener, app);
|
||||
server.await.unwrap();
|
||||
});
|
||||
|
||||
// Open the browser for OAuth
|
||||
let authorization_url = self.get_authorization_url(resource);
|
||||
tracing::info!("Opening browser for OAuth authentication...");
|
||||
|
||||
if webbrowser::open(&authorization_url).is_err() {
|
||||
tracing::warn!("Could not open browser automatically. Please open this URL manually:");
|
||||
tracing::warn!("{}", authorization_url);
|
||||
}
|
||||
|
||||
// Wait for the authorization code with a timeout
|
||||
let code = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120), // 2 minute timeout
|
||||
rx,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Authentication timed out after 2 minutes"))??;
|
||||
|
||||
// Stop the callback server
|
||||
server_handle.abort();
|
||||
|
||||
// Exchange the code for a token
|
||||
self.exchange_code_for_token(&code, resource).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_oauth_endpoints(
|
||||
host: &str,
|
||||
custom_discovery_path: Option<&str>,
|
||||
) -> Result<OidcEndpoints> {
|
||||
let base_url = Url::parse(host)?;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Define discovery paths to try, with custom path first if provided
|
||||
let mut discovery_paths = Vec::new();
|
||||
if let Some(custom_path) = custom_discovery_path {
|
||||
discovery_paths.push(custom_path);
|
||||
}
|
||||
discovery_paths.extend([
|
||||
"/.well-known/oauth-authorization-server",
|
||||
"/.well-known/openid_configuration",
|
||||
"/oauth/.well-known/oauth-authorization-server",
|
||||
"/.well-known/oauth_authorization_server", // Some services use underscore
|
||||
]);
|
||||
|
||||
let discovery_paths_for_error = discovery_paths.clone(); // Clone for error message
|
||||
let mut last_error = None;
|
||||
|
||||
// Try each discovery path until one works
|
||||
for path in discovery_paths {
|
||||
match base_url.join(path) {
|
||||
Ok(discovery_url) => {
|
||||
tracing::debug!("Trying OAuth discovery at: {}", discovery_url);
|
||||
|
||||
match client.get(discovery_url.clone()).send().await {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
match resp.json::<Value>().await {
|
||||
Ok(oidc_config) => {
|
||||
// Try to parse the OAuth configuration
|
||||
match parse_oauth_config(oidc_config) {
|
||||
Ok(endpoints) => {
|
||||
tracing::info!(
|
||||
"Successfully discovered OAuth endpoints at: {}",
|
||||
discovery_url
|
||||
);
|
||||
return Ok(endpoints);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
"Invalid OAuth config at {}: {}",
|
||||
discovery_url,
|
||||
e
|
||||
);
|
||||
last_error = Some(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!(
|
||||
"Failed to parse JSON from {}: {}",
|
||||
discovery_url,
|
||||
e
|
||||
);
|
||||
last_error = Some(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(resp) => {
|
||||
tracing::debug!("HTTP {} from {}", resp.status(), discovery_url);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("Request failed to {}: {}", discovery_url, e);
|
||||
last_error = Some(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("Invalid discovery URL {}{}: {}", host, path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"No OAuth discovery endpoint found at {}. Tried paths: {:?}",
|
||||
host,
|
||||
discovery_paths_for_error
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_oauth_config(oidc_config: Value) -> Result<OidcEndpoints> {
|
||||
let authorization_endpoint = oidc_config
|
||||
.get("authorization_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("authorization_endpoint not found in OAuth configuration"))?
|
||||
.to_string();
|
||||
|
||||
let token_endpoint = oidc_config
|
||||
.get("token_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("token_endpoint not found in OAuth configuration"))?
|
||||
.to_string();
|
||||
|
||||
let registration_endpoint = oidc_config
|
||||
.get("registration_endpoint")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
Ok(OidcEndpoints {
|
||||
authorization_endpoint,
|
||||
token_endpoint,
|
||||
registration_endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform OAuth flow for a service
|
||||
pub async fn authenticate_service(config: ServiceConfig, mcp_url: &str) -> Result<String> {
|
||||
tracing::info!("Starting OAuth authentication for service...");
|
||||
|
||||
// Get the canonical resource URI for the MCP server
|
||||
let resource_uri = config.get_canonical_resource_uri(mcp_url)?;
|
||||
tracing::info!("Using resource URI: {}", resource_uri);
|
||||
|
||||
// Get OAuth endpoints using flexible discovery
|
||||
let endpoints =
|
||||
get_oauth_endpoints(&config.oauth_host, config.discovery_path.as_deref()).await?;
|
||||
|
||||
// Register dynamic client to get client_id
|
||||
let client_id = OAuthFlow::register_client(&endpoints, &config).await?;
|
||||
|
||||
// Create and execute OAuth flow with the dynamic client_id
|
||||
let flow = OAuthFlow::new(endpoints, client_id, config.redirect_uri);
|
||||
|
||||
let token_data = flow.execute(&resource_uri).await?;
|
||||
|
||||
tracing::info!("OAuth authentication successful!");
|
||||
Ok(token_data.access_token)
|
||||
}
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::oauth::ServiceConfig;
|
||||
|
||||
#[test]
|
||||
fn test_canonical_resource_uri_generation() {
|
||||
let config = ServiceConfig {
|
||||
oauth_host: "https://example.com".to_string(),
|
||||
redirect_uri: "http://localhost:8020".to_string(),
|
||||
client_name: "Test Client".to_string(),
|
||||
client_uri: "https://test.com".to_string(),
|
||||
discovery_path: None,
|
||||
};
|
||||
|
||||
// Test basic URL
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||
|
||||
// Test URL with port
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com:8443/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com:8443/mcp");
|
||||
|
||||
// Test URL without path
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com");
|
||||
|
||||
// Test URL with root path
|
||||
let result = config
|
||||
.get_canonical_resource_uri("https://mcp.example.com/")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com");
|
||||
|
||||
// Test case normalization
|
||||
let result = config
|
||||
.get_canonical_resource_uri("HTTPS://MCP.EXAMPLE.COM/mcp")
|
||||
.unwrap();
|
||||
assert_eq!(result, "https://mcp.example.com/mcp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_from_mcp_endpoint() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/api/mcp").unwrap();
|
||||
|
||||
assert_eq!(config.oauth_host, "https://mcp.example.com");
|
||||
assert_eq!(config.redirect_uri, "http://localhost:8020");
|
||||
assert_eq!(config.client_name, "Goose MCP Client");
|
||||
assert_eq!(config.client_uri, "https://github.com/block/goose");
|
||||
assert!(config.discovery_path.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_with_port() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com:8443/mcp").unwrap();
|
||||
|
||||
assert_eq!(config.oauth_host, "https://mcp.example.com:8443");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_service_config_invalid_url() {
|
||||
let result = ServiceConfig::from_mcp_endpoint("invalid-url");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_discovery_path() {
|
||||
let config = ServiceConfig::from_mcp_endpoint("https://mcp.example.com/mcp")
|
||||
.unwrap()
|
||||
.with_custom_discovery("/custom/oauth/discovery".to_string());
|
||||
|
||||
assert_eq!(
|
||||
config.discovery_path,
|
||||
Some("/custom/oauth/discovery".to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -2961,6 +2961,10 @@
|
|||
},
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"outputSchema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -703,6 +703,9 @@ export type Tool = {
|
|||
[key: string]: unknown;
|
||||
};
|
||||
name: string;
|
||||
outputSchema?: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
export type ToolAnnotations = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue