diff --git a/lib/src/http_downstream.rs b/lib/src/http_downstream.rs index f215629..006039c 100644 --- a/lib/src/http_downstream.rs +++ b/lib/src/http_downstream.rs @@ -22,7 +22,6 @@ const HEALTH_CHECK_AUTHORITY: &str = "_check"; const UDP_AUTHORITY: &str = "_udp2"; const ICMP_AUTHORITY: &str = "_icmp"; -const AUTHORIZATION_FAILURE_STATUS_CODE: StatusCode = StatusCode::PROXY_AUTHENTICATION_REQUIRED; const AUTHORIZATION_FAILURE_EXTRA_HEADER: (&str, &str) = ("proxy-authenticate", "Basic realm=Authorization Required"); @@ -40,11 +39,13 @@ pub(crate) struct HttpDownstream { struct TcpConnection { stream: Box, id: log_utils::IdChain, + auth_failure_status_code: StatusCode, } struct DatagramMultiplexer { stream: Box, id: log_utils::IdChain, + auth_failure_status_code: StatusCode, } struct DatagramEncoder { @@ -61,6 +62,7 @@ struct DatagramDecoder { struct PendingRequest { stream: Box, id: log_utils::IdChain, + auth_failure_status_code: StatusCode, } impl HttpDownstream { @@ -112,9 +114,13 @@ impl Downstream for HttpDownstream { match channel { net_utils::Channel::Tunnel => { log_id!(trace, stream_id, "HTTP downstream: tunnel request"); + let auth_failure_status_code = + StatusCode::from_u16(self.context.settings.auth_failure_status_code) + .unwrap_or(StatusCode::PROXY_AUTHENTICATION_REQUIRED); return Ok(Some(Box::new(PendingRequest { stream, id: stream_id, + auth_failure_status_code, }))); } net_utils::Channel::Ping => { @@ -202,7 +208,7 @@ impl downstream::PendingRequest for TcpConnection { } fn fail_request(self: Box, error: tunnel::ConnectionError) { - fail_request_with_error(self.stream, error); + fail_request_with_error(self.stream, error, self.auth_failure_status_code); } } @@ -262,6 +268,7 @@ impl downstream::PendingRequest for PendingRequest { DatagramMultiplexer { stream: self.stream, id: self.id, + auth_failure_status_code: self.auth_failure_status_code, }, )), )) @@ -275,13 +282,14 @@ impl downstream::PendingRequest for PendingRequest { Box::new(TcpConnection { stream: self.stream, id: self.id, + auth_failure_status_code: self.auth_failure_status_code, }), ))), } } fn fail_request(self: Box, error: tunnel::ConnectionError) { - fail_request_with_error(self.stream, error); + fail_request_with_error(self.stream, error, self.auth_failure_status_code); } } @@ -325,7 +333,7 @@ impl downstream::PendingRequest for DatagramMultiplexer { } fn fail_request(self: Box, error: tunnel::ConnectionError) { - fail_request_with_error(self.stream, error); + fail_request_with_error(self.stream, error, self.auth_failure_status_code); } } @@ -395,9 +403,12 @@ impl datagram_pipe::Sink for DatagramEncoder { } } -fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode { +fn tunnel_error_to_status_code( + error: &tunnel::ConnectionError, + auth_failure_status_code: StatusCode, +) -> StatusCode { match error { - tunnel::ConnectionError::Authentication(_) => AUTHORIZATION_FAILURE_STATUS_CODE, + tunnel::ConnectionError::Authentication(_) => auth_failure_status_code, _ => BAD_STATUS_CODE, } } @@ -405,16 +416,23 @@ fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode { fn tunnel_error_to_warn_header( error: &tunnel::ConnectionError, hostname: &str, + auth_failure_status_code: StatusCode, ) -> Vec<(String, String)> { match error { tunnel::ConnectionError::Io(_) => vec![( WARNING_HEADER_NAME.to_string(), "300 - Connection failed for some reason".to_string(), )], - tunnel::ConnectionError::Authentication(_) => vec![( - AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(), - AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(), - )], + tunnel::ConnectionError::Authentication(_) => { + if auth_failure_status_code == StatusCode::PROXY_AUTHENTICATION_REQUIRED { + vec![( + AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(), + AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(), + )] + } else { + vec![] + } + } tunnel::ConnectionError::Timeout => { vec![(WARNING_HEADER_NAME.to_string(), format!("302 - {}", error))] } @@ -447,9 +465,21 @@ fn fail_request( } } -fn fail_request_with_error(stream: Box, error: tunnel::ConnectionError) { - let extra_headers = tunnel_error_to_warn_header(&error, request_hostname(stream.request())); - fail_request(stream, tunnel_error_to_status_code(&error), extra_headers); +fn fail_request_with_error( + stream: Box, + error: tunnel::ConnectionError, + auth_failure_status_code: StatusCode, +) { + let extra_headers = tunnel_error_to_warn_header( + &error, + request_hostname(stream.request()), + auth_failure_status_code, + ); + fail_request( + stream, + tunnel_error_to_status_code(&error, auth_failure_status_code), + extra_headers, + ); } fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str { @@ -458,3 +488,54 @@ fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str { .map(http::uri::Authority::as_str) .unwrap_or_default() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn tunnel_error_to_status_code_default_407() { + let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED; + let error = tunnel::ConnectionError::Authentication("bad creds".into()); + assert_eq!( + tunnel_error_to_status_code(&error, status_code), + StatusCode::PROXY_AUTHENTICATION_REQUIRED + ); + } + + #[test] + fn tunnel_error_to_status_code_configured_405() { + let status_code = StatusCode::METHOD_NOT_ALLOWED; + let error = tunnel::ConnectionError::Authentication("bad creds".into()); + assert_eq!( + tunnel_error_to_status_code(&error, status_code), + StatusCode::METHOD_NOT_ALLOWED + ); + } + + #[test] + fn tunnel_error_to_status_code_non_auth_error_unaffected() { + let error = tunnel::ConnectionError::Timeout; + assert_eq!( + tunnel_error_to_status_code(&error, StatusCode::METHOD_NOT_ALLOWED), + BAD_STATUS_CODE + ); + } + + #[test] + fn warn_header_includes_proxy_authenticate_for_407() { + let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED; + let error = tunnel::ConnectionError::Authentication("bad creds".into()); + let headers = tunnel_error_to_warn_header(&error, "example.com", status_code); + assert_eq!(headers.len(), 1); + assert_eq!(headers[0].0, "proxy-authenticate"); + } + + #[test] + fn warn_header_empty_for_405() { + let status_code = StatusCode::METHOD_NOT_ALLOWED; + let error = tunnel::ConnectionError::Authentication("bad creds".into()); + let headers = tunnel_error_to_warn_header(&error, "example.com", status_code); + assert!(headers.is_empty()); + } +} diff --git a/lib/src/settings.rs b/lib/src/settings.rs index 91f93d7..1430351 100644 --- a/lib/src/settings.rs +++ b/lib/src/settings.rs @@ -34,6 +34,8 @@ pub enum ValidationError { RulesFile(String), /// No credentials configured while listening on a public address NoCredentialsOnPublicAddress, + /// Invalid auth failure status code + InvalidAuthFailureStatusCode(u16), } impl Debug for ValidationError { @@ -52,6 +54,11 @@ impl Debug for ValidationError { "No credentials configured (credentials_file is missing) while listening on a public address. \ This is a security risk. Either configure credentials or use a loopback address (127.0.0.1 or ::1)" ), + Self::InvalidAuthFailureStatusCode(code) => write!( + f, + "Invalid auth_failure_status_code: {}. Supported values: 407, 405", + code + ), } } } @@ -190,6 +197,12 @@ pub struct Settings { /// Optional path prefix for speedtest requests on main hosts. #[serde(default = "Settings::default_speedtest_path")] pub(crate) speedtest_path: Option, + + /// HTTP status code returned on authentication failure. + /// Supported values: 407 (Proxy Authentication Required) or 405 (Method Not Allowed). + #[serde(default = "Settings::default_auth_failure_status_code")] + pub(crate) auth_failure_status_code: u16, + /// Default maximum number of simultaneous HTTP/1 and HTTP/2 connections per client credentials. /// TrustTunnel clients open 8 HTTP/2 connections by default, so set this to /// `8 * ` to limit the number of simultaneously connected devices. @@ -527,6 +540,12 @@ impl Settings { return Err(ValidationError::NoCredentialsOnPublicAddress); } + if self.auth_failure_status_code != 407 && self.auth_failure_status_code != 405 { + return Err(ValidationError::InvalidAuthFailureStatusCode( + self.auth_failure_status_code, + )); + } + Ok(()) } @@ -578,6 +597,10 @@ impl Settings { Some("/ping".to_string()) } + pub fn default_auth_failure_status_code() -> u16 { + 407 + } + fn validate_request_path(name: &str, path: &Option) -> Result<(), ValidationError> { if let Some(path) = path { if path.is_empty() || !path.starts_with('/') { @@ -632,6 +655,7 @@ impl Default for Settings { speedtest_path: None, default_max_http2_conns_per_client: None, default_max_http3_conns_per_client: None, + auth_failure_status_code: Settings::default_auth_failure_status_code(), built: false, } } @@ -891,6 +915,7 @@ impl SettingsBuilder { speedtest_path: Settings::default_speedtest_path(), default_max_http2_conns_per_client: None, default_max_http3_conns_per_client: None, + auth_failure_status_code: Settings::default_auth_failure_status_code(), built: true, }, } @@ -1024,6 +1049,12 @@ impl SettingsBuilder { self } + /// Set the HTTP status code for authentication failures (407 or 405) + pub fn auth_failure_status_code(mut self, x: u16) -> Self { + self.settings.auth_failure_status_code = x; + self + } + /// Set whether ping is available pub fn ping_enable(mut self, x: bool) -> Self { self.settings.ping_enable = x; @@ -1638,3 +1669,42 @@ where fn demangle_toml_string(x: String) -> String { x.replace('"', "").trim().to_string() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_auth_failure_status_code_is_407() { + let settings = Settings::default(); + assert_eq!(settings.auth_failure_status_code, 407); + } + + #[test] + fn auth_failure_status_code_407_valid() { + let mut settings = Settings::default(); + settings.auth_failure_status_code = 407; + settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into(); + assert!(settings.validate().is_ok()); + } + + #[test] + fn auth_failure_status_code_405_valid() { + let mut settings = Settings::default(); + settings.auth_failure_status_code = 405; + settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into(); + assert!(settings.validate().is_ok()); + } + + #[test] + fn auth_failure_status_code_200_invalid() { + let mut settings = Settings::default(); + settings.auth_failure_status_code = 200; + settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into(); + let err = settings.validate().unwrap_err(); + assert!(matches!( + err, + ValidationError::InvalidAuthFailureStatusCode(200) + )); + } +} diff --git a/tools/setup_wizard/composer.rs b/tools/setup_wizard/composer.rs index 81a44d8..ec21f92 100644 --- a/tools/setup_wizard/composer.rs +++ b/tools/setup_wizard/composer.rs @@ -51,6 +51,7 @@ fn compose_main_table(settings: &Settings, credentials_path: &str, rules_path: & } else { doc.remove("ping_path"); } + doc["auth_failure_status_code"] = value(*settings.get_auth_failure_status_code() as i64); doc.to_string() } diff --git a/tools/setup_wizard/template_settings.rs b/tools/setup_wizard/template_settings.rs index 5834d64..b94d4de 100644 --- a/tools/setup_wizard/template_settings.rs +++ b/tools/setup_wizard/template_settings.rs @@ -76,6 +76,9 @@ ping_enable = {} {} ping_path = "{}" + +{} +auth_failure_status_code = {} "#, Settings::doc_listen_address().to_toml_comment(), crate::library_settings::DEFAULT_CREDENTIALS_PATH, @@ -106,6 +109,8 @@ ping_path = "{}" Settings::default_ping_enable(), Settings::doc_ping_path().to_toml_comment(), Settings::default_ping_path().unwrap(), + Settings::doc_auth_failure_status_code().to_toml_comment(), + Settings::default_auth_failure_status_code(), ) });