mirror of
https://github.com/TrustTunnel/TrustTunnel.git
synced 2026-04-28 03:39:53 +00:00
541 lines
18 KiB
Rust
541 lines
18 KiB
Rust
use crate::downstream::Downstream;
|
|
use crate::http_codec::HttpCodec;
|
|
use crate::net_utils::TcpDestination;
|
|
use crate::tls_demultiplexer::Protocol;
|
|
use crate::{
|
|
authentication, core, datagram_pipe, downstream, http_codec, http_datagram_codec,
|
|
http_demultiplexer, http_forwarded_stream, http_icmp_codec, http_ping_handler,
|
|
http_speedtest_handler, http_udp_codec, log_id, log_utils, net_utils, pipe, reverse_proxy,
|
|
tunnel,
|
|
};
|
|
use async_trait::async_trait;
|
|
use bytes::Bytes;
|
|
use http::StatusCode;
|
|
use http_demultiplexer::HttpDemux;
|
|
use std::collections::LinkedList;
|
|
use std::io;
|
|
use std::io::ErrorKind;
|
|
use std::net::IpAddr;
|
|
use std::sync::Arc;
|
|
|
|
/// Pseudo-authority of a `CONNECT` health-check request; it is answered with an
/// OK response and nothing else is done with the stream.
const HEALTH_CHECK_AUTHORITY: &str = "_check";
/// Pseudo-authority of a `CONNECT` request that opens a UDP datagram multiplexer.
const UDP_AUTHORITY: &str = "_udp2";
/// Pseudo-authority of a `CONNECT` request that opens an ICMP datagram multiplexer.
const ICMP_AUTHORITY: &str = "_icmp";

/// Header (name, value) added to an authentication-failure response when the
/// configured failure status is `407 Proxy Authentication Required`.
const AUTHORIZATION_FAILURE_EXTRA_HEADER: (&str, &str) =
    ("proxy-authenticate", "Basic realm=Authorization Required");

/// Status used to reject requests for failures other than authentication, and
/// for pseudo-authorities used with a non-`CONNECT` method.
const BAD_STATUS_CODE: StatusCode = StatusCode::BAD_GATEWAY;
/// Header carrying a numeric warning code plus a description of the failure.
const WARNING_HEADER_NAME: &str = "X-Warning";
/// Header carrying the requested hostname when the connection fails with a
/// DNS non-routable or DNS loopback error.
const DNS_WARNING_HEADER_NAME: &str = "X-Adguard-Vpn-Error";
|
|
|
|
/// HTTP-based downstream: accepts request streams from an [`HttpCodec`] and
/// routes each one to the tunnel, ping, speedtest or reverse-proxy handling.
pub(crate) struct HttpDownstream {
    /// Shared core context; cloned into the tasks serving non-tunnel channels.
    context: Arc<core::Context>,
    /// Source of incoming request streams.
    codec: Box<dyn HttpCodec>,
    /// TLS domain of this downstream; handed to the reverse proxy as the SNI.
    tls_domain: String,
    /// Decides which channel an incoming request belongs to.
    request_demux: HttpDemux,
}
|
|
|
|
/// A tunnel request whose authority is not one of the reserved
/// pseudo-authorities, i.e. a request for a TCP connection.
struct TcpConnection {
    stream: Box<dyn http_codec::Stream>,
    id: log_utils::IdChain<u64>,
    /// Status sent back when the request fails with an authentication error.
    auth_failure_status_code: StatusCode,
}

/// A `CONNECT` request to the UDP or ICMP pseudo-authority, before its
/// datagram pipe halves are set up.
struct DatagramMultiplexer {
    stream: Box<dyn http_codec::Stream>,
    id: log_utils::IdChain<u64>,
    /// Status sent back when the request fails with an authentication error.
    auth_failure_status_code: StatusCode,
}

/// Outgoing half of a datagram pipe: encodes each datagram with `encoder` and
/// writes the bytes to `sink`; datagrams the encoder rejects are dropped.
struct DatagramEncoder<D> {
    encoder: Box<dyn http_datagram_codec::Encoder<Datagram = D>>,
    sink: Box<dyn http_codec::DroppingSink>,
}

/// Incoming half of a datagram pipe: reads bytes from `source` and decodes
/// them into datagrams.
struct DatagramDecoder<D> {
    source: Box<dyn pipe::Source>,
    decoder: Box<dyn http_datagram_codec::Decoder<Datagram = D>>,
    // Bytes left over after a decoded datagram (the chunk tail); they are fed
    // to the decoder before anything new is read from `source`.
    pending_bytes: LinkedList<Bytes>,
}

/// A tunnel request as returned by `listen`, before its authority has been
/// inspected to tell a health check, a datagram multiplexer and a TCP
/// connection apart.
struct PendingRequest {
    stream: Box<dyn http_codec::Stream>,
    id: log_utils::IdChain<u64>,
    /// Status sent back when the request fails with an authentication error.
    auth_failure_status_code: StatusCode,
}
|
|
|
|
impl HttpDownstream {
|
|
pub fn new(context: Arc<core::Context>, codec: Box<dyn HttpCodec>, tls_domain: String) -> Self {
|
|
Self {
|
|
request_demux: HttpDemux::new(context.settings.clone()),
|
|
context,
|
|
codec,
|
|
tls_domain,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Downstream for HttpDownstream {
|
|
async fn listen(
|
|
&mut self,
|
|
) -> io::Result<Option<Box<dyn downstream::PendingMultiplexedRequest>>> {
|
|
loop {
|
|
let stream = match self.codec.listen().await? {
|
|
None => return Ok(None),
|
|
Some(s) => s,
|
|
};
|
|
let request = stream.request().request();
|
|
let stream_id = stream.id();
|
|
log_id!(
|
|
trace,
|
|
stream_id,
|
|
"HTTP downstream received request: {} {}",
|
|
request.method,
|
|
request.uri
|
|
);
|
|
log_id!(
|
|
debug,
|
|
stream_id,
|
|
"Received request: {:?}",
|
|
net_utils::scrub_request(request)
|
|
);
|
|
|
|
let protocol = self.protocol();
|
|
let context = self.context.clone();
|
|
let channel = self.request_demux.select(self.protocol(), request);
|
|
log_id!(
|
|
trace,
|
|
stream_id,
|
|
"HTTP downstream routing to channel: {:?}",
|
|
channel
|
|
);
|
|
match channel {
|
|
net_utils::Channel::Tunnel => {
|
|
log_id!(trace, stream_id, "HTTP downstream: tunnel request");
|
|
let auth_failure_status_code =
|
|
StatusCode::from_u16(self.context.settings.auth_failure_status_code)
|
|
.unwrap_or(StatusCode::PROXY_AUTHENTICATION_REQUIRED);
|
|
return Ok(Some(Box::new(PendingRequest {
|
|
stream,
|
|
id: stream_id,
|
|
auth_failure_status_code,
|
|
})));
|
|
}
|
|
net_utils::Channel::Ping => {
|
|
log_id!(trace, stream_id, "HTTP downstream: ping request");
|
|
tokio::spawn(async move {
|
|
http_ping_handler::listen(
|
|
context.shutdown.clone(),
|
|
Box::new(http_codec::stream_into_codec(stream, protocol)),
|
|
context.settings.tls_handshake_timeout,
|
|
stream_id,
|
|
)
|
|
.await
|
|
});
|
|
}
|
|
net_utils::Channel::Speedtest => {
|
|
log_id!(trace, stream_id, "HTTP downstream: speedtest request");
|
|
tokio::spawn(async move {
|
|
http_speedtest_handler::listen(
|
|
context.shutdown.clone(),
|
|
Box::new(http_codec::stream_into_codec(stream, protocol)),
|
|
context.settings.tls_handshake_timeout,
|
|
context.settings.speedtest_path.clone(),
|
|
stream_id,
|
|
)
|
|
.await
|
|
});
|
|
}
|
|
net_utils::Channel::ReverseProxy => {
|
|
log_id!(trace, stream_id, "HTTP downstream: reverse proxy request");
|
|
tokio::spawn({
|
|
let sni = self.tls_domain.clone();
|
|
async move {
|
|
reverse_proxy::listen(
|
|
context,
|
|
Box::new(http_codec::stream_into_codec(stream, protocol)),
|
|
sni,
|
|
stream_id,
|
|
)
|
|
.await
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn graceful_shutdown(&mut self) -> io::Result<()> {
|
|
self.codec.graceful_shutdown().await
|
|
}
|
|
|
|
fn protocol(&self) -> Protocol {
|
|
self.codec.protocol()
|
|
}
|
|
|
|
fn tls_domain(&self) -> &str {
|
|
&self.tls_domain
|
|
}
|
|
}
|
|
|
|
macro_rules! impl_stream_id {
|
|
(for $($t:ty),+) => {
|
|
$(impl downstream::StreamId for $t {
|
|
fn id(&self) -> log_utils::IdChain<u64> {
|
|
self.id.clone()
|
|
}
|
|
})*
|
|
}
|
|
}
|
|
|
|
impl_stream_id!(for PendingRequest, TcpConnection, DatagramMultiplexer);
|
|
|
|
impl downstream::PendingRequest for TcpConnection {
|
|
type NextState = (Box<dyn pipe::Source>, Box<dyn pipe::Sink>);
|
|
|
|
fn promote_to_next_state(self: Box<Self>) -> io::Result<Self::NextState> {
|
|
if self.stream.request().request().method == http::Method::CONNECT {
|
|
let (source, sink) = self.stream.split();
|
|
return Ok((
|
|
source.finalize(),
|
|
sink.send_ok_response(false)?.into_pipe_sink(),
|
|
));
|
|
}
|
|
|
|
http_forwarded_stream::into_forwarded(self.stream)
|
|
}
|
|
|
|
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
|
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
|
}
|
|
}
|
|
|
|
impl downstream::PendingTcpConnectRequest for TcpConnection {
|
|
fn client_address(&self) -> io::Result<IpAddr> {
|
|
self.stream.request().client_address()
|
|
}
|
|
|
|
fn destination(&self) -> io::Result<TcpDestination> {
|
|
let request = self.stream.request();
|
|
let authority = request.authority()?;
|
|
|
|
Ok(match authority.as_str().parse() {
|
|
Ok(a) => TcpDestination::Address(a),
|
|
Err(_) => {
|
|
let port = if request.request().method == http::Method::CONNECT {
|
|
authority.port_u16().ok_or_else(|| {
|
|
io::Error::new(
|
|
ErrorKind::Other,
|
|
format!(
|
|
"Unexpected authority port: request={:?}",
|
|
net_utils::scrub_request(request.request())
|
|
),
|
|
)
|
|
})?
|
|
} else {
|
|
authority
|
|
.port_u16()
|
|
.unwrap_or(net_utils::PLAIN_HTTP_PORT_NUMBER)
|
|
};
|
|
|
|
TcpDestination::HostName((authority.host().to_string(), port))
|
|
}
|
|
})
|
|
}
|
|
|
|
fn user_agent(&self) -> Option<String> {
|
|
self.stream.request().user_agent()
|
|
}
|
|
}
|
|
|
|
impl downstream::PendingRequest for PendingRequest {
|
|
type NextState = Option<downstream::PendingDemultiplexedRequest>;
|
|
|
|
fn promote_to_next_state(self: Box<Self>) -> io::Result<Self::NextState> {
|
|
let request = self.stream.request().request();
|
|
|
|
match request.uri.authority().map(http::uri::Authority::as_str) {
|
|
Some(HEALTH_CHECK_AUTHORITY) if request.method == http::Method::CONNECT => {
|
|
self.stream.split().1.send_ok_response(true).map(|_| None)
|
|
}
|
|
Some(UDP_AUTHORITY) | Some(ICMP_AUTHORITY)
|
|
if request.method == http::Method::CONNECT =>
|
|
{
|
|
Ok(Some(
|
|
downstream::PendingDemultiplexedRequest::DatagramMultiplexer(Box::new(
|
|
DatagramMultiplexer {
|
|
stream: self.stream,
|
|
id: self.id,
|
|
auth_failure_status_code: self.auth_failure_status_code,
|
|
},
|
|
)),
|
|
))
|
|
}
|
|
Some(HEALTH_CHECK_AUTHORITY) | Some(UDP_AUTHORITY) | Some(ICMP_AUTHORITY) => {
|
|
log_id!(debug, self.id, "Unexpected request method: {:?}", request);
|
|
fail_request(self.stream, BAD_STATUS_CODE, vec![]);
|
|
Ok(None)
|
|
}
|
|
_ => Ok(Some(downstream::PendingDemultiplexedRequest::TcpConnect(
|
|
Box::new(TcpConnection {
|
|
stream: self.stream,
|
|
id: self.id,
|
|
auth_failure_status_code: self.auth_failure_status_code,
|
|
}),
|
|
))),
|
|
}
|
|
}
|
|
|
|
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
|
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
|
}
|
|
}
|
|
|
|
impl downstream::PendingMultiplexedRequest for PendingRequest {
|
|
fn auth_info(&self) -> io::Result<Option<authentication::Source>> {
|
|
self.stream.request().auth_info()
|
|
}
|
|
}
|
|
|
|
impl downstream::PendingRequest for DatagramMultiplexer {
|
|
type NextState = downstream::DatagramPipeHalves;
|
|
|
|
fn promote_to_next_state(self: Box<Self>) -> io::Result<Self::NextState> {
|
|
let authority = self.stream.request().authority()?.to_string();
|
|
let (source, sink) = self.stream.split();
|
|
match authority.as_str() {
|
|
UDP_AUTHORITY => Ok(downstream::DatagramPipeHalves::Udp(
|
|
Box::new(DatagramDecoder {
|
|
source: source.finalize(),
|
|
decoder: Box::new(http_udp_codec::Decoder::new(self.id.clone())),
|
|
pending_bytes: Default::default(),
|
|
}),
|
|
Box::new(DatagramEncoder {
|
|
sink: sink.send_ok_response(false)?.into_datagram_sink(),
|
|
encoder: Box::<http_udp_codec::Encoder>::default(),
|
|
}),
|
|
)),
|
|
ICMP_AUTHORITY => Ok(downstream::DatagramPipeHalves::Icmp(
|
|
Box::new(DatagramDecoder {
|
|
source: source.finalize(),
|
|
decoder: Box::new(http_icmp_codec::Decoder::new()),
|
|
pending_bytes: Default::default(),
|
|
}),
|
|
Box::new(DatagramEncoder {
|
|
sink: sink.send_ok_response(false)?.into_datagram_sink(),
|
|
encoder: Box::<http_icmp_codec::Encoder>::default(),
|
|
}),
|
|
)),
|
|
_ => unreachable!(),
|
|
}
|
|
}
|
|
|
|
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
|
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
|
}
|
|
}
|
|
|
|
impl downstream::PendingDatagramMultiplexerRequest for DatagramMultiplexer {
|
|
fn client_address(&self) -> io::Result<IpAddr> {
|
|
self.stream.request().client_address()
|
|
}
|
|
|
|
fn user_agent(&self) -> Option<String> {
|
|
self.stream.request().user_agent()
|
|
}
|
|
}
|
|
|
|
impl<D> downstream::StreamId for DatagramDecoder<D> {
|
|
fn id(&self) -> log_utils::IdChain<u64> {
|
|
self.source.id()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl<D> datagram_pipe::Source for DatagramDecoder<D> {
|
|
type Output = D;
|
|
|
|
fn id(&self) -> log_utils::IdChain<u64> {
|
|
self.source.id()
|
|
}
|
|
|
|
async fn read(&mut self) -> io::Result<D> {
|
|
loop {
|
|
let chunk = match self.pending_bytes.pop_front() {
|
|
None => match self.source.read().await? {
|
|
pipe::Data::Chunk(bytes) => {
|
|
self.source.consume(bytes.len())?;
|
|
bytes
|
|
}
|
|
pipe::Data::Eof => return Err(io::Error::from(ErrorKind::UnexpectedEof)),
|
|
},
|
|
Some(bytes) => bytes,
|
|
};
|
|
|
|
match self.decoder.decode_chunk(chunk) {
|
|
http_datagram_codec::DecodeResult::WantMore => (),
|
|
http_datagram_codec::DecodeResult::Complete(datagram, tail) => {
|
|
if !tail.is_empty() {
|
|
self.pending_bytes.push_front(tail);
|
|
}
|
|
|
|
return Ok(datagram);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl<D: Send> datagram_pipe::Sink for DatagramEncoder<D> {
|
|
type Input = D;
|
|
|
|
async fn write(&mut self, datagram: D) -> io::Result<datagram_pipe::SendStatus> {
|
|
match self.encoder.encode_packet(&datagram) {
|
|
None => {
|
|
debug!("Failed to encode datagram");
|
|
Ok(datagram_pipe::SendStatus::Dropped)
|
|
}
|
|
Some(encoded) => self.sink.write(encoded),
|
|
}
|
|
}
|
|
}
|
|
|
|
fn tunnel_error_to_status_code(
|
|
error: &tunnel::ConnectionError,
|
|
auth_failure_status_code: StatusCode,
|
|
) -> StatusCode {
|
|
match error {
|
|
tunnel::ConnectionError::Authentication(_) => auth_failure_status_code,
|
|
_ => BAD_STATUS_CODE,
|
|
}
|
|
}
|
|
|
|
fn tunnel_error_to_warn_header(
|
|
error: &tunnel::ConnectionError,
|
|
hostname: &str,
|
|
auth_failure_status_code: StatusCode,
|
|
) -> Vec<(String, String)> {
|
|
match error {
|
|
tunnel::ConnectionError::Io(_) => vec![(
|
|
WARNING_HEADER_NAME.to_string(),
|
|
"300 - Connection failed for some reason".to_string(),
|
|
)],
|
|
tunnel::ConnectionError::Authentication(_) => {
|
|
if auth_failure_status_code == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
|
|
vec![(
|
|
AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(),
|
|
AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(),
|
|
)]
|
|
} else {
|
|
vec![]
|
|
}
|
|
}
|
|
tunnel::ConnectionError::Timeout => {
|
|
vec![(WARNING_HEADER_NAME.to_string(), format!("302 - {}", error))]
|
|
}
|
|
tunnel::ConnectionError::HostUnreachable => {
|
|
vec![(WARNING_HEADER_NAME.to_string(), format!("301 - {}", error))]
|
|
}
|
|
tunnel::ConnectionError::DnsNonroutable => vec![
|
|
(DNS_WARNING_HEADER_NAME.to_string(), hostname.to_string()),
|
|
(WARNING_HEADER_NAME.to_string(), format!("310 - {}", error)),
|
|
],
|
|
tunnel::ConnectionError::DnsLoopback => vec![
|
|
(DNS_WARNING_HEADER_NAME.to_string(), hostname.to_string()),
|
|
(WARNING_HEADER_NAME.to_string(), format!("311 - {}", error)),
|
|
],
|
|
tunnel::ConnectionError::Other(_) => vec![(
|
|
WARNING_HEADER_NAME.to_string(),
|
|
"300 - Connection failed for some reason".to_string(),
|
|
)],
|
|
}
|
|
}
|
|
|
|
fn fail_request(
|
|
stream: Box<dyn http_codec::Stream>,
|
|
status: StatusCode,
|
|
extra_headers: Vec<(String, String)>,
|
|
) {
|
|
let id = stream.id();
|
|
if let Err(e) = stream.split().1.send_bad_response(status, extra_headers) {
|
|
log_id!(debug, id, "Failed to send bad response: {}", e);
|
|
}
|
|
}
|
|
|
|
fn fail_request_with_error(
|
|
stream: Box<dyn http_codec::Stream>,
|
|
error: tunnel::ConnectionError,
|
|
auth_failure_status_code: StatusCode,
|
|
) {
|
|
let extra_headers = tunnel_error_to_warn_header(
|
|
&error,
|
|
request_hostname(stream.request()),
|
|
auth_failure_status_code,
|
|
);
|
|
fail_request(
|
|
stream,
|
|
tunnel_error_to_status_code(&error, auth_failure_status_code),
|
|
extra_headers,
|
|
);
|
|
}
|
|
|
|
fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str {
|
|
request
|
|
.authority()
|
|
.map(http::uri::Authority::as_str)
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tunnel_error_to_status_code_default_407() {
        let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
        let error = tunnel::ConnectionError::Authentication("bad creds".into());
        assert_eq!(
            tunnel_error_to_status_code(&error, status_code),
            StatusCode::PROXY_AUTHENTICATION_REQUIRED
        );
    }

    #[test]
    fn tunnel_error_to_status_code_configured_405() {
        let status_code = StatusCode::METHOD_NOT_ALLOWED;
        let error = tunnel::ConnectionError::Authentication("bad creds".into());
        assert_eq!(
            tunnel_error_to_status_code(&error, status_code),
            StatusCode::METHOD_NOT_ALLOWED
        );
    }

    #[test]
    fn tunnel_error_to_status_code_non_auth_error_unaffected() {
        let error = tunnel::ConnectionError::Timeout;
        assert_eq!(
            tunnel_error_to_status_code(&error, StatusCode::METHOD_NOT_ALLOWED),
            BAD_STATUS_CODE
        );
    }

    #[test]
    fn warn_header_includes_proxy_authenticate_for_407() {
        let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
        let error = tunnel::ConnectionError::Authentication("bad creds".into());
        let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "proxy-authenticate");
    }

    #[test]
    fn warn_header_empty_for_405() {
        let status_code = StatusCode::METHOD_NOT_ALLOWED;
        let error = tunnel::ConnectionError::Authentication("bad creds".into());
        let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
        assert!(headers.is_empty());
    }

    /// Asserts `headers` is exactly one `X-Warning` header whose value starts with `prefix`.
    fn assert_single_warning(headers: &[(String, String)], prefix: &str) {
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, WARNING_HEADER_NAME);
        assert!(
            headers[0].1.starts_with(prefix),
            "unexpected warning value: {}",
            headers[0].1
        );
    }

    /// Asserts `headers` carries the hostname header first, then a warning with `prefix`.
    fn assert_dns_warning(headers: &[(String, String)], hostname: &str, prefix: &str) {
        assert_eq!(headers.len(), 2);
        assert_eq!(headers[0].0, DNS_WARNING_HEADER_NAME);
        assert_eq!(headers[0].1, hostname);
        assert_eq!(headers[1].0, WARNING_HEADER_NAME);
        assert!(
            headers[1].1.starts_with(prefix),
            "unexpected warning value: {}",
            headers[1].1
        );
    }

    #[test]
    fn warn_header_timeout_uses_code_302() {
        let headers = tunnel_error_to_warn_header(
            &tunnel::ConnectionError::Timeout,
            "example.com",
            StatusCode::PROXY_AUTHENTICATION_REQUIRED,
        );
        assert_single_warning(&headers, "302 - ");
    }

    #[test]
    fn warn_header_host_unreachable_uses_code_301() {
        let headers = tunnel_error_to_warn_header(
            &tunnel::ConnectionError::HostUnreachable,
            "example.com",
            StatusCode::PROXY_AUTHENTICATION_REQUIRED,
        );
        assert_single_warning(&headers, "301 - ");
    }

    #[test]
    fn warn_header_dns_nonroutable_reports_hostname_then_code_310() {
        let headers = tunnel_error_to_warn_header(
            &tunnel::ConnectionError::DnsNonroutable,
            "example.com",
            StatusCode::PROXY_AUTHENTICATION_REQUIRED,
        );
        assert_dns_warning(&headers, "example.com", "310 - ");
    }

    #[test]
    fn warn_header_dns_loopback_reports_hostname_then_code_311() {
        let headers = tunnel_error_to_warn_header(
            &tunnel::ConnectionError::DnsLoopback,
            "example.com",
            StatusCode::METHOD_NOT_ALLOWED,
        );
        assert_dns_warning(&headers, "example.com", "311 - ");
    }
}
|