Add auth_failure_status_code feature to change response code on auth failure

This commit is contained in:
Zhavoronkov Aleksei 2026-03-06 17:12:36 +03:00
parent 412490e20f
commit 2883fcf58e
4 changed files with 170 additions and 13 deletions

View file

@ -22,7 +22,6 @@ const HEALTH_CHECK_AUTHORITY: &str = "_check";
const UDP_AUTHORITY: &str = "_udp2";
const ICMP_AUTHORITY: &str = "_icmp";
const AUTHORIZATION_FAILURE_STATUS_CODE: StatusCode = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
const AUTHORIZATION_FAILURE_EXTRA_HEADER: (&str, &str) =
("proxy-authenticate", "Basic realm=Authorization Required");
@ -40,11 +39,13 @@ pub(crate) struct HttpDownstream {
struct TcpConnection {
stream: Box<dyn http_codec::Stream>,
id: log_utils::IdChain<u64>,
auth_failure_status_code: StatusCode,
}
struct DatagramMultiplexer {
stream: Box<dyn http_codec::Stream>,
id: log_utils::IdChain<u64>,
auth_failure_status_code: StatusCode,
}
struct DatagramEncoder<D> {
@ -61,6 +62,7 @@ struct DatagramDecoder<D> {
struct PendingRequest {
stream: Box<dyn http_codec::Stream>,
id: log_utils::IdChain<u64>,
auth_failure_status_code: StatusCode,
}
impl HttpDownstream {
@ -112,9 +114,13 @@ impl Downstream for HttpDownstream {
match channel {
net_utils::Channel::Tunnel => {
log_id!(trace, stream_id, "HTTP downstream: tunnel request");
let auth_failure_status_code =
StatusCode::from_u16(self.context.settings.auth_failure_status_code)
.unwrap_or(StatusCode::PROXY_AUTHENTICATION_REQUIRED);
return Ok(Some(Box::new(PendingRequest {
stream,
id: stream_id,
auth_failure_status_code,
})));
}
net_utils::Channel::Ping => {
@ -202,7 +208,7 @@ impl downstream::PendingRequest for TcpConnection {
}
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
fail_request_with_error(self.stream, error);
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
}
}
@ -262,6 +268,7 @@ impl downstream::PendingRequest for PendingRequest {
DatagramMultiplexer {
stream: self.stream,
id: self.id,
auth_failure_status_code: self.auth_failure_status_code,
},
)),
))
@ -275,13 +282,14 @@ impl downstream::PendingRequest for PendingRequest {
Box::new(TcpConnection {
stream: self.stream,
id: self.id,
auth_failure_status_code: self.auth_failure_status_code,
}),
))),
}
}
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
fail_request_with_error(self.stream, error);
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
}
}
@ -325,7 +333,7 @@ impl downstream::PendingRequest for DatagramMultiplexer {
}
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
fail_request_with_error(self.stream, error);
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
}
}
@ -395,9 +403,12 @@ impl<D: Send> datagram_pipe::Sink for DatagramEncoder<D> {
}
}
fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode {
fn tunnel_error_to_status_code(
error: &tunnel::ConnectionError,
auth_failure_status_code: StatusCode,
) -> StatusCode {
match error {
tunnel::ConnectionError::Authentication(_) => AUTHORIZATION_FAILURE_STATUS_CODE,
tunnel::ConnectionError::Authentication(_) => auth_failure_status_code,
_ => BAD_STATUS_CODE,
}
}
@ -405,16 +416,23 @@ fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode {
fn tunnel_error_to_warn_header(
error: &tunnel::ConnectionError,
hostname: &str,
auth_failure_status_code: StatusCode,
) -> Vec<(String, String)> {
match error {
tunnel::ConnectionError::Io(_) => vec![(
WARNING_HEADER_NAME.to_string(),
"300 - Connection failed for some reason".to_string(),
)],
tunnel::ConnectionError::Authentication(_) => vec![(
AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(),
AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(),
)],
tunnel::ConnectionError::Authentication(_) => {
if auth_failure_status_code == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
vec![(
AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(),
AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(),
)]
} else {
vec![]
}
}
tunnel::ConnectionError::Timeout => {
vec![(WARNING_HEADER_NAME.to_string(), format!("302 - {}", error))]
}
@ -447,9 +465,21 @@ fn fail_request(
}
}
fn fail_request_with_error(stream: Box<dyn http_codec::Stream>, error: tunnel::ConnectionError) {
let extra_headers = tunnel_error_to_warn_header(&error, request_hostname(stream.request()));
fail_request(stream, tunnel_error_to_status_code(&error), extra_headers);
fn fail_request_with_error(
stream: Box<dyn http_codec::Stream>,
error: tunnel::ConnectionError,
auth_failure_status_code: StatusCode,
) {
let extra_headers = tunnel_error_to_warn_header(
&error,
request_hostname(stream.request()),
auth_failure_status_code,
);
fail_request(
stream,
tunnel_error_to_status_code(&error, auth_failure_status_code),
extra_headers,
);
}
fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str {
@ -458,3 +488,54 @@ fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str {
.map(http::uri::Authority::as_str)
.unwrap_or_default()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn tunnel_error_to_status_code_default_407() {
let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
let error = tunnel::ConnectionError::Authentication("bad creds".into());
assert_eq!(
tunnel_error_to_status_code(&error, status_code),
StatusCode::PROXY_AUTHENTICATION_REQUIRED
);
}
#[test]
fn tunnel_error_to_status_code_configured_405() {
let status_code = StatusCode::METHOD_NOT_ALLOWED;
let error = tunnel::ConnectionError::Authentication("bad creds".into());
assert_eq!(
tunnel_error_to_status_code(&error, status_code),
StatusCode::METHOD_NOT_ALLOWED
);
}
#[test]
fn tunnel_error_to_status_code_non_auth_error_unaffected() {
let error = tunnel::ConnectionError::Timeout;
assert_eq!(
tunnel_error_to_status_code(&error, StatusCode::METHOD_NOT_ALLOWED),
BAD_STATUS_CODE
);
}
#[test]
fn warn_header_includes_proxy_authenticate_for_407() {
let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
let error = tunnel::ConnectionError::Authentication("bad creds".into());
let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
assert_eq!(headers.len(), 1);
assert_eq!(headers[0].0, "proxy-authenticate");
}
#[test]
fn warn_header_empty_for_405() {
let status_code = StatusCode::METHOD_NOT_ALLOWED;
let error = tunnel::ConnectionError::Authentication("bad creds".into());
let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
assert!(headers.is_empty());
}
}

View file

@ -34,6 +34,8 @@ pub enum ValidationError {
RulesFile(String),
/// No credentials configured while listening on a public address
NoCredentialsOnPublicAddress,
/// Invalid auth failure status code
InvalidAuthFailureStatusCode(u16),
}
impl Debug for ValidationError {
@ -52,6 +54,11 @@ impl Debug for ValidationError {
"No credentials configured (credentials_file is missing) while listening on a public address. \
This is a security risk. Either configure credentials or use a loopback address (127.0.0.1 or ::1)"
),
Self::InvalidAuthFailureStatusCode(code) => write!(
f,
"Invalid auth_failure_status_code: {}. Supported values: 407, 405",
code
),
}
}
}
@ -190,6 +197,12 @@ pub struct Settings {
/// Optional path prefix for speedtest requests on main hosts.
#[serde(default = "Settings::default_speedtest_path")]
pub(crate) speedtest_path: Option<String>,
/// HTTP status code returned on authentication failure.
/// Supported values: 407 (Proxy Authentication Required) or 405 (Method Not Allowed).
#[serde(default = "Settings::default_auth_failure_status_code")]
pub(crate) auth_failure_status_code: u16,
/// Default maximum number of simultaneous HTTP/1 and HTTP/2 connections per client credentials.
/// TrustTunnel clients open 8 HTTP/2 connections by default, so set this to
/// `8 * <max_devices>` to limit the number of simultaneously connected devices.
@ -527,6 +540,12 @@ impl Settings {
return Err(ValidationError::NoCredentialsOnPublicAddress);
}
if self.auth_failure_status_code != 407 && self.auth_failure_status_code != 405 {
return Err(ValidationError::InvalidAuthFailureStatusCode(
self.auth_failure_status_code,
));
}
Ok(())
}
@ -578,6 +597,10 @@ impl Settings {
Some("/ping".to_string())
}
pub fn default_auth_failure_status_code() -> u16 {
407
}
fn validate_request_path(name: &str, path: &Option<String>) -> Result<(), ValidationError> {
if let Some(path) = path {
if path.is_empty() || !path.starts_with('/') {
@ -632,6 +655,7 @@ impl Default for Settings {
speedtest_path: None,
default_max_http2_conns_per_client: None,
default_max_http3_conns_per_client: None,
auth_failure_status_code: Settings::default_auth_failure_status_code(),
built: false,
}
}
@ -891,6 +915,7 @@ impl SettingsBuilder {
speedtest_path: Settings::default_speedtest_path(),
default_max_http2_conns_per_client: None,
default_max_http3_conns_per_client: None,
auth_failure_status_code: Settings::default_auth_failure_status_code(),
built: true,
},
}
@ -1024,6 +1049,12 @@ impl SettingsBuilder {
self
}
/// Set the HTTP status code for authentication failures (407 or 405)
pub fn auth_failure_status_code(mut self, x: u16) -> Self {
self.settings.auth_failure_status_code = x;
self
}
/// Set whether ping is available
pub fn ping_enable(mut self, x: bool) -> Self {
self.settings.ping_enable = x;
@ -1638,3 +1669,42 @@ where
fn demangle_toml_string(x: String) -> String {
x.replace('"', "").trim().to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_auth_failure_status_code_is_407() {
let settings = Settings::default();
assert_eq!(settings.auth_failure_status_code, 407);
}
#[test]
fn auth_failure_status_code_407_valid() {
let mut settings = Settings::default();
settings.auth_failure_status_code = 407;
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
assert!(settings.validate().is_ok());
}
#[test]
fn auth_failure_status_code_405_valid() {
let mut settings = Settings::default();
settings.auth_failure_status_code = 405;
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
assert!(settings.validate().is_ok());
}
#[test]
fn auth_failure_status_code_200_invalid() {
let mut settings = Settings::default();
settings.auth_failure_status_code = 200;
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
let err = settings.validate().unwrap_err();
assert!(matches!(
err,
ValidationError::InvalidAuthFailureStatusCode(200)
));
}
}

View file

@ -51,6 +51,7 @@ fn compose_main_table(settings: &Settings, credentials_path: &str, rules_path: &
} else {
doc.remove("ping_path");
}
doc["auth_failure_status_code"] = value(*settings.get_auth_failure_status_code() as i64);
doc.to_string()
}

View file

@ -76,6 +76,9 @@ ping_enable = {}
{}
ping_path = "{}"
{}
auth_failure_status_code = {}
"#,
Settings::doc_listen_address().to_toml_comment(),
crate::library_settings::DEFAULT_CREDENTIALS_PATH,
@ -106,6 +109,8 @@ ping_path = "{}"
Settings::default_ping_enable(),
Settings::doc_ping_path().to_toml_comment(),
Settings::default_ping_path().unwrap(),
Settings::doc_auth_failure_status_code().to_toml_comment(),
Settings::default_auth_failure_status_code(),
)
});