diff --git a/Cargo.lock b/Cargo.lock index d6da689e..242c93ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -452,9 +452,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.52" +version = "1.2.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" +checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932" dependencies = [ "find-msvc-tools", "jobserver", @@ -752,9 +752,9 @@ checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "find-msvc-tools" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" +checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" [[package]] name = "fixedbitset" @@ -3177,9 +3177,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21e6f2ab2928ca4291b86736a8bd920a277a399bba1589409d72154ff87c1282" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ "web-time", "zeroize", @@ -3187,9 +3187,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.8" +version = "0.103.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" dependencies = [ "aws-lc-rs", "ring", @@ -3675,9 +3675,9 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasip2" -version = "1.0.1+wasi-0.2.4" +version = "1.0.2+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" dependencies = [ "wit-bindgen", ] @@ -3963,9 +3963,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "wit-bindgen" -version = "0.46.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" [[package]] name = "writeable" diff --git a/g3proxy/CHANGELOG b/g3proxy/CHANGELOG index 7c964600..779528bf 100644 --- a/g3proxy/CHANGELOG +++ b/g3proxy/CHANGELOG @@ -3,6 +3,8 @@ v1.13.0: - Feature: add facts user group type and enable it in the following servers: - tcp_tproxy - sni_proxy + - tcp_stream + - tls_stream - Feature: support set TCP max segment size in listen config - Compatibility: bump MSRV to 1.90.0 - Deprecated: the following config options are deprecated: diff --git a/g3proxy/src/config/server/tcp_stream.rs b/g3proxy/src/config/server/tcp_stream.rs index 6539eab8..d8a89588 100644 --- a/g3proxy/src/config/server/tcp_stream.rs +++ b/g3proxy/src/config/server/tcp_stream.rs @@ -13,6 +13,7 @@ use yaml_rust::{Yaml, yaml}; use g3_io_ext::StreamCopyConfig; use g3_types::acl::AclNetworkRuleBuilder; +use g3_types::auth::FactsMatchType; use g3_types::collection::SelectivePickPolicy; use g3_types::metrics::{MetricTagMap, NodeName}; use g3_types::net::{ @@ -34,6 +35,9 @@ pub(crate) struct TcpStreamServerConfig { position: Option, pub(crate) escaper: NodeName, pub(crate) auditor: NodeName, + pub(crate) user_group: NodeName, + auth_by_client_ip: bool, + pub(crate) auth_match: Option, pub(crate) shared_logger: Option, pub(crate) listen: Option, pub(crate) listen_in_worker: bool, @@ -60,6 +64,9 @@ impl TcpStreamServerConfig { position, escaper: NodeName::default(), auditor: NodeName::default(), + user_group: NodeName::default(), + auth_by_client_ip: false, + auth_match: None, shared_logger: None, listen: None, listen_in_worker: false, @@ -107,6 +114,14 @@ impl TcpStreamServerConfig { self.auditor = g3_yaml::value::as_metric_node_name(v)?; Ok(()) } + "user_group" => { + self.user_group = g3_yaml::value::as_metric_node_name(v)?; + Ok(()) + } + "auth_by_client_ip" => { + self.auth_by_client_ip = g3_yaml::value::as_bool(v)?; + Ok(()) + } "shared_logger" => { let name = g3_yaml::value::as_ascii(v)?; self.shared_logger = Some(name); @@ -240,6 +255,14 @@ impl TcpStreamServerConfig { if self.upstream.is_empty() { return Err(anyhow!("upstream is not set")); } + + if self.auth_by_client_ip { + self.auth_match = Some(FactsMatchType::ClientIp); + } + if self.auth_match.is_some() && self.user_group.is_empty() { + return Err(anyhow!("user group is not set but auth is enabled")); + } + if self.task_idle_check_interval > IDLE_CHECK_MAXIMUM_DURATION { self.task_idle_check_interval = IDLE_CHECK_MAXIMUM_DURATION; } @@ -272,7 +295,7 @@ impl ServerConfig for TcpStreamServerConfig { } fn user_group(&self) -> &NodeName { - Default::default() + &self.user_group } fn auditor(&self) -> &NodeName { diff --git a/g3proxy/src/config/server/tls_stream.rs b/g3proxy/src/config/server/tls_stream.rs index ca4d738c..2bdcecf3 100644 --- a/g3proxy/src/config/server/tls_stream.rs +++ b/g3proxy/src/config/server/tls_stream.rs @@ -14,6 +14,7 @@ use yaml_rust::{Yaml, yaml}; use g3_io_ext::StreamCopyConfig; use g3_tls_ticket::TlsTicketConfig; use g3_types::acl::AclNetworkRuleBuilder; +use g3_types::auth::FactsMatchType; use g3_types::collection::SelectivePickPolicy; use g3_types::metrics::{MetricTagMap, NodeName}; use g3_types::net::{ @@ -35,6 +36,9 @@ pub(crate) struct TlsStreamServerConfig { position: Option, pub(crate) escaper: NodeName, pub(crate) auditor: NodeName, + pub(crate) user_group: NodeName, + auth_by_client_ip: bool, + pub(crate) auth_match: Option, pub(crate) shared_logger: Option, pub(crate) listen: Option, pub(crate) listen_in_worker: bool, @@ -63,6 +67,9 @@ impl TlsStreamServerConfig { position, escaper: NodeName::default(), auditor: NodeName::default(), + user_group: NodeName::default(), + auth_by_client_ip: false, + auth_match: None, shared_logger: None, listen: None, listen_in_worker: false, @@ -112,6 +119,14 @@ impl TlsStreamServerConfig { self.auditor = g3_yaml::value::as_metric_node_name(v)?; Ok(()) } + "user_group" => { + self.user_group = g3_yaml::value::as_metric_node_name(v)?; + Ok(()) + } + "auth_by_client_ip" => { + self.auth_by_client_ip = g3_yaml::value::as_bool(v)?; + Ok(()) + } "shared_logger" => { let name = g3_yaml::value::as_ascii(v)?; self.shared_logger = Some(name); @@ -260,6 +275,13 @@ impl TlsStreamServerConfig { return Err(anyhow!("upstream is not set")); } + if self.auth_by_client_ip { + self.auth_match = Some(FactsMatchType::ClientIp); + } + if self.auth_match.is_some() && self.user_group.is_empty() { + return Err(anyhow!("user group is not set but auth is enabled")); + } + self.server_tls_config .check() .context("invalid server tls config")?; @@ -296,7 +318,7 @@ impl ServerConfig for TlsStreamServerConfig { } fn user_group(&self) -> &NodeName { - Default::default() + &self.user_group } fn auditor(&self) -> &NodeName { diff --git a/g3proxy/src/serve/tcp_stream/server.rs b/g3proxy/src/serve/tcp_stream/server.rs index 43595dae..f0deb686 100644 --- a/g3proxy/src/serve/tcp_stream/server.rs +++ b/g3proxy/src/serve/tcp_stream/server.rs @@ -5,10 +5,12 @@ use std::net::SocketAddr; use std::sync::Arc; +use std::time::Duration; use anyhow::{Context, anyhow}; use arc_swap::{ArcSwap, ArcSwapOption}; use async_trait::async_trait; +use log::warn; #[cfg(feature = "quic")] use quinn::Connection; use slog::Logger; @@ -22,6 +24,7 @@ use g3_daemon::server::{BaseServer, ClientConnectionInfo, ServerExt, ServerReloa use g3_io_ext::{AsyncStream, IdleWheel}; use g3_openssl::SslStream; use g3_types::acl::{AclAction, AclNetworkRule}; +use g3_types::auth::FactsMatchType; use g3_types::collection::{SelectiveVec, SelectiveVecBuilder}; use g3_types::metrics::NodeName; use g3_types::net::{OpensslClientConfig, UpstreamAddr, WeightedUpstreamAddr}; @@ -30,12 +33,13 @@ use super::common::CommonTaskContext; use super::stats::TcpStreamServerStats; use super::task::TcpStreamTask; use crate::audit::{AuditContext, AuditHandle}; +use crate::auth::{FactsUserGroup, UserContext, UserGroup}; use crate::config::server::tcp_stream::TcpStreamServerConfig; use crate::config::server::{AnyServerConfig, ServerConfig}; use crate::escape::ArcEscaper; use crate::serve::{ ArcServer, ArcServerInternal, ArcServerStats, Server, ServerInternal, ServerQuitPolicy, - ServerRegistry, ServerStats, WrapArcServer, + ServerRegistry, ServerStats, ServerTaskNotes, WrapArcServer, }; pub(crate) struct TcpStreamServer { @@ -49,6 +53,7 @@ pub(crate) struct TcpStreamServer { task_logger: Option, escaper: ArcSwap, + user_group: ArcSwapOption, audit_handle: ArcSwapOption, quit_policy: Arc, idle_wheel: Arc, @@ -104,11 +109,13 @@ impl TcpStreamServer { reload_sender, task_logger, escaper: ArcSwap::new(escaper), + user_group: ArcSwapOption::new(None), audit_handle: ArcSwapOption::new(audit_handle), quit_policy: Arc::new(ServerQuitPolicy::default()), idle_wheel, reload_version: version, }; + server._update_user_group_in_place(); Ok(server) } @@ -163,6 +170,40 @@ impl TcpStreamServer { AuditContext::new(self.audit_handle.load_full()) } + fn get_task_notes(&self, cc_info: &ClientConnectionInfo) -> Option { + let task_notes = if let Some(auth_match) = self.config.auth_match { + let ip = match auth_match { + FactsMatchType::ClientIp => cc_info.client_ip(), + FactsMatchType::ServerIp => cc_info.server_ip(), + FactsMatchType::ServerName => return None, + }; + let Some((user, user_type)) = self + .user_group + .load() + .as_ref() + .and_then(|g| g.get_user_by_ip(ip)) + else { + // TODO log + return None; + }; + let user_ctx = UserContext::new( + None, + user, + user_type, + self.config.name(), + self.server_stats.share_extra_tags(), + ); + if user_ctx.check_client_addr(cc_info.client_addr()).is_err() { + // TODO may be attack + return None; + } + ServerTaskNotes::new(cc_info.clone(), Some(user_ctx), Duration::ZERO) + } else { + ServerTaskNotes::new(cc_info.clone(), None, Duration::ZERO) + }; + Some(task_notes) + } + fn get_ctx_and_upstream( &self, cc_info: ClientConnectionInfo, @@ -190,10 +231,13 @@ impl TcpStreamServer { T::R: AsyncRead + Send + Sync + Unpin + 'static, T::W: AsyncWrite + Send + Sync + Unpin + 'static, { + let Some(task_notes) = self.get_task_notes(&cc_info) else { + return; + }; let (ctx, upstream) = self.get_ctx_and_upstream(cc_info); let (clt_r, clt_w) = stream.into_split(); - TcpStreamTask::new(ctx, upstream, self.audit_context()) + TcpStreamTask::new(ctx, upstream, self.audit_context(), task_notes) .into_running(clt_r, clt_w) .await; } @@ -205,10 +249,13 @@ impl TcpStreamServer { recv_stream: quinn::RecvStream, cc_info: ClientConnectionInfo, ) { + let Some(task_notes) = self.get_task_notes(&cc_info) else { + return; + }; let (ctx, upstream) = self.get_ctx_and_upstream(cc_info); tokio::spawn( - TcpStreamTask::new(ctx, upstream, self.audit_context()) + TcpStreamTask::new(ctx, upstream, self.audit_context(), task_notes) .into_running(recv_stream, send_stream), ); } @@ -235,7 +282,24 @@ impl ServerInternal for TcpStreamServer { self.escaper.store(Arc::new(escaper)); } - fn _update_user_group_in_place(&self) {} + fn _update_user_group_in_place(&self) { + let user_group = if let Some(g) = self.config.get_user_group() { + let g_type = g.r#type(); + if let UserGroup::Facts(g) = g { + Some(g) + } else { + warn!( + "server {}: user group {}(type {g_type}) ignored", + self.config.name(), + self.config.user_group + ); + None + } + } else { + None + }; + self.user_group.store(user_group); + } fn _update_audit_handle_in_place(&self) -> anyhow::Result<()> { let audit_handle = self.config.get_audit_handle()?; @@ -353,7 +417,7 @@ impl Server for TcpStreamServer { } fn user_group(&self) -> &NodeName { - Default::default() + self.config.user_group() } fn auditor(&self) -> &NodeName { diff --git a/g3proxy/src/serve/tcp_stream/task.rs b/g3proxy/src/serve/tcp_stream/task.rs index 018d9f3d..dde8a2fd 100644 --- a/g3proxy/src/serve/tcp_stream/task.rs +++ b/g3proxy/src/serve/tcp_stream/task.rs @@ -3,6 +3,7 @@ * Copyright 2023-2025 ByteDance and/or its affiliates. */ +use std::borrow::Cow; use std::sync::Arc; use std::time::Duration; @@ -11,16 +12,21 @@ use tokio::io::{AsyncRead, AsyncWrite}; use g3_daemon::server::ServerQuitPolicy; use g3_daemon::stat::task::TcpStreamTaskStats; use g3_io_ext::{IdleInterval, LimitedReader, LimitedWriter, StreamCopyConfig}; +use g3_types::acl::AclAction; use g3_types::net::UpstreamAddr; use super::common::CommonTaskContext; use super::stats::{TcpStreamServerAliveTaskGuard, TcpStreamTaskCltWrapperStats}; use crate::audit::AuditContext; use crate::auth::User; +use crate::config::server::ServerConfig; use crate::inspect::{StreamInspectContext, StreamTransitTask}; use crate::log::task::tcp_connect::TaskLogForTcpConnect; use crate::module::tcp_connect::{TcpConnectTaskConf, TcpConnectTaskNotes, TlsConnectTaskConf}; -use crate::serve::{ServerTaskError, ServerTaskNotes, ServerTaskResult, ServerTaskStage}; +use crate::serve::{ + ServerStats, ServerTaskError, ServerTaskForbiddenError, ServerTaskNotes, ServerTaskResult, + ServerTaskStage, +}; pub(super) struct TcpStreamTask { ctx: CommonTaskContext, @@ -29,16 +35,26 @@ pub(super) struct TcpStreamTask { task_notes: ServerTaskNotes, task_stats: Arc, audit_ctx: AuditContext, + started: bool, _alive_guard: Option, } +impl Drop for TcpStreamTask { + fn drop(&mut self) { + if self.started { + self.post_stop(); + self.started = false; + } + } +} + impl TcpStreamTask { pub(super) fn new( ctx: CommonTaskContext, upstream: &UpstreamAddr, audit_ctx: AuditContext, + task_notes: ServerTaskNotes, ) -> Self { - let task_notes = ServerTaskNotes::new(ctx.cc_info.clone(), None, Duration::ZERO); TcpStreamTask { ctx, upstream: upstream.clone(), @@ -46,6 +62,7 @@ impl TcpStreamTask { task_notes, task_stats: Arc::new(TcpStreamTaskStats::default()), audit_ctx, + started: false, _alive_guard: None, } } @@ -85,11 +102,54 @@ impl TcpStreamTask { fn pre_start(&mut self) { self._alive_guard = Some(self.ctx.server_stats.add_task()); + if let Some(user_ctx) = self.task_notes.user_ctx() { + user_ctx.foreach_req_stats(|s| { + s.req_total.add_tcp_connect(); + s.req_alive.add_tcp_connect(); + }); + } + if self.ctx.server_config.flush_task_log_on_created && let Some(log_ctx) = self.get_log_context() { log_ctx.log_created(); } + + self.started = true; + } + + fn post_stop(&mut self) { + if let Some(user_ctx) = self.task_notes.user_ctx() { + user_ctx.foreach_req_stats(|s| { + s.req_alive.del_tcp_connect(); + }); + + if let Some(user_req_alive_permit) = self.task_notes.user_req_alive_permit.take() { + drop(user_req_alive_permit); + } + } + } + + async fn handle_user_upstream_acl_action(&mut self, action: AclAction) -> ServerTaskResult<()> { + let forbid = match action { + AclAction::Permit => false, + AclAction::PermitAndLog => { + // TODO log permit + false + } + AclAction::Forbid => true, + AclAction::ForbidAndLog => { + // TODO log forbid + true + } + }; + if forbid { + Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::DestDenied, + )) + } else { + Ok(()) + } } async fn run(&mut self, clt_r: CR, clt_w: CW) -> ServerTaskResult<()> @@ -97,10 +157,42 @@ impl TcpStreamTask { CR: AsyncRead + Send + Sync + Unpin + 'static, CW: AsyncWrite + Send + Sync + Unpin + 'static, { + let tcp_client_misc_opts; + + if let Some(user_ctx) = self.task_notes.user_ctx() { + let user_ctx = user_ctx.clone(); + + if user_ctx.check_rate_limit().is_err() { + return Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::RateLimited, + )); + } + + match user_ctx.acquire_request_semaphore() { + Ok(permit) => self.task_notes.user_req_alive_permit = Some(permit), + Err(_) => { + return Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::FullyLoaded, + )); + } + } + + let action = user_ctx.check_upstream(&self.upstream); + self.handle_user_upstream_acl_action(action).await?; + + let user_config = user_ctx.user_config(); + + tcp_client_misc_opts = + user_config.tcp_client_misc_opts(&self.ctx.server_config.tcp_misc_opts); + // + } else { + tcp_client_misc_opts = Cow::Borrowed(&self.ctx.server_config.tcp_misc_opts); + } + // set client side socket options self.ctx .cc_info - .tcp_sock_set_raw_opts(&self.ctx.server_config.tcp_misc_opts, true) + .tcp_sock_set_raw_opts(&tcp_client_misc_opts, true) .map_err(|_| { ServerTaskError::InternalServerError("failed to set client socket options") })?; @@ -186,28 +278,42 @@ impl TcpStreamTask { UW: AsyncWrite + Send + Sync + Unpin + 'static, { if let Some(audit_handle) = self.audit_ctx.check_take_handle() { - let ctx = StreamInspectContext::new( - audit_handle, - self.ctx.server_config.clone(), - self.ctx.server_stats.clone(), - self.ctx.server_quit_policy.clone(), - self.ctx.idle_wheel.clone(), - &self.task_notes, - &self.tcp_notes, - ); - crate::inspect::stream::transit_with_inspection( - clt_r, - clt_w, - ups_r, - ups_w, - ctx, - self.upstream.clone(), - None, - ) - .await - } else { - self.transit_transparent(clt_r, clt_w, ups_r, ups_w).await + let audit_task = self + .task_notes + .user_ctx() + .map(|ctx| { + let user_config = &ctx.user_config().audit; + user_config.enable_protocol_inspection + && user_config + .do_task_audit() + .unwrap_or_else(|| audit_handle.do_task_audit()) + }) + .unwrap_or_else(|| audit_handle.do_task_audit()); + + if audit_task { + let ctx = StreamInspectContext::new( + audit_handle, + self.ctx.server_config.clone(), + self.ctx.server_stats.clone(), + self.ctx.server_quit_policy.clone(), + self.ctx.idle_wheel.clone(), + &self.task_notes, + &self.tcp_notes, + ); + return crate::inspect::stream::transit_with_inspection( + clt_r, + clt_w, + ups_r, + ups_w, + ctx, + self.upstream.clone(), + None, + ) + .await; + } } + + self.transit_transparent(clt_r, clt_w, ups_r, ups_w).await } fn setup_limit_and_stats( @@ -219,24 +325,47 @@ impl TcpStreamTask { CR: AsyncRead, CW: AsyncWrite, { - let wrapper_stats = + let mut wrapper_stats = TcpStreamTaskCltWrapperStats::new(&self.ctx.server_stats, &self.task_stats); - let wrapper_stats = Arc::new(wrapper_stats); - let clt_speed_limit = &self.ctx.server_config.tcp_sock_speed_limit; - let clt_r = LimitedReader::local_limited( + let limit_config = if let Some(user_ctx) = self.task_notes.user_ctx() { + wrapper_stats.push_user_io_stats(user_ctx.fetch_traffic_stats( + self.ctx.server_config.name(), + self.ctx.server_stats.share_extra_tags(), + )); + + user_ctx + .user_config() + .tcp_sock_speed_limit + .shrink_as_smaller(&self.ctx.server_config.tcp_sock_speed_limit) + } else { + self.ctx.server_config.tcp_sock_speed_limit + }; + + let wrapper_stats = Arc::new(wrapper_stats); + let mut clt_r = LimitedReader::local_limited( clt_r, - clt_speed_limit.shift_millis, - clt_speed_limit.max_north, + limit_config.shift_millis, + limit_config.max_north, wrapper_stats.clone(), ); - let clt_w = LimitedWriter::local_limited( + let mut clt_w = LimitedWriter::local_limited( clt_w, - clt_speed_limit.shift_millis, - clt_speed_limit.max_south, + limit_config.shift_millis, + limit_config.max_south, wrapper_stats, ); + if let Some(user_ctx) = self.task_notes.user_ctx() { + let user = user_ctx.user(); + if let Some(limiter) = user.tcp_all_upload_speed_limit() { + clt_r.add_global_limiter(limiter.clone()); + } + if let Some(limiter) = user.tcp_all_download_speed_limit() { + clt_w.add_global_limiter(limiter.clone()); + } + } + (clt_r, clt_w) } } @@ -281,6 +410,6 @@ impl StreamTransitTask for TcpStreamTask { } fn user(&self) -> Option<&User> { - None + self.task_notes.user_ctx().map(|ctx| ctx.user().as_ref()) } } diff --git a/g3proxy/src/serve/tls_stream/server.rs b/g3proxy/src/serve/tls_stream/server.rs index 7169a763..2ae4935c 100644 --- a/g3proxy/src/serve/tls_stream/server.rs +++ b/g3proxy/src/serve/tls_stream/server.rs @@ -10,7 +10,7 @@ use std::time::Duration; use anyhow::{Context, anyhow}; use arc_swap::{ArcSwap, ArcSwapOption}; use async_trait::async_trait; -use log::debug; +use log::{debug, warn}; #[cfg(feature = "quic")] use quinn::Connection; use slog::Logger; @@ -23,6 +23,7 @@ use g3_daemon::server::{BaseServer, ClientConnectionInfo, ServerExt, ServerReloa use g3_io_ext::IdleWheel; use g3_openssl::SslStream; use g3_types::acl::{AclAction, AclNetworkRule}; +use g3_types::auth::FactsMatchType; use g3_types::collection::{SelectiveVec, SelectiveVecBuilder}; use g3_types::metrics::NodeName; use g3_types::net::{ @@ -33,13 +34,14 @@ use g3_types::net::{ use super::common::CommonTaskContext; use super::task::TlsStreamTask; use crate::audit::{AuditContext, AuditHandle}; +use crate::auth::{FactsUserGroup, UserContext, UserGroup}; use crate::config::server::tls_stream::TlsStreamServerConfig; use crate::config::server::{AnyServerConfig, ServerConfig}; use crate::escape::ArcEscaper; use crate::serve::tcp_stream::TcpStreamServerStats; use crate::serve::{ ArcServer, ArcServerInternal, ArcServerStats, Server, ServerInternal, ServerQuitPolicy, - ServerRegistry, ServerStats, WrapArcServer, + ServerRegistry, ServerStats, ServerTaskNotes, WrapArcServer, }; pub(crate) struct TlsStreamServer { @@ -56,6 +58,7 @@ pub(crate) struct TlsStreamServer { task_logger: Option, escaper: ArcSwap, + user_group: ArcSwapOption, audit_handle: ArcSwapOption, quit_policy: Arc, idle_wheel: Arc, @@ -121,11 +124,13 @@ impl TlsStreamServer { reload_sender, task_logger, escaper: ArcSwap::new(escaper), + user_group: ArcSwapOption::new(None), audit_handle: ArcSwapOption::new(audit_handle), quit_policy: Arc::new(ServerQuitPolicy::default()), idle_wheel, reload_version: version, }; + server._update_user_group_in_place(); Ok(server) } @@ -207,6 +212,37 @@ impl TlsStreamServer { } async fn run_task(&self, stream: TlsStream, cc_info: ClientConnectionInfo) { + let task_notes = if let Some(auth_match) = self.config.auth_match { + let ip = match auth_match { + FactsMatchType::ClientIp => cc_info.client_ip(), + FactsMatchType::ServerIp => cc_info.server_ip(), + FactsMatchType::ServerName => return, + }; + let Some((user, user_type)) = self + .user_group + .load() + .as_ref() + .and_then(|g| g.get_user_by_ip(ip)) + else { + // TODO log + return; + }; + let user_ctx = UserContext::new( + None, + user, + user_type, + self.config.name(), + self.server_stats.share_extra_tags(), + ); + if user_ctx.check_client_addr(cc_info.client_addr()).is_err() { + // TODO may be attack + return; + } + ServerTaskNotes::new(cc_info.clone(), Some(user_ctx), Duration::ZERO) + } else { + ServerTaskNotes::new(cc_info.clone(), None, Duration::ZERO) + }; + let upstream = self.select_consistent(&self.upstream, self.config.upstream_pick_policy, &cc_info); @@ -221,7 +257,7 @@ impl TlsStreamServer { task_logger: self.task_logger.clone(), }; - TlsStreamTask::new(ctx, upstream.inner(), self.audit_context()) + TlsStreamTask::new(ctx, upstream.inner(), self.audit_context(), task_notes) .into_running(stream) .await; } @@ -248,7 +284,24 @@ impl ServerInternal for TlsStreamServer { self.escaper.store(Arc::new(escaper)); } - fn _update_user_group_in_place(&self) {} + fn _update_user_group_in_place(&self) { + let user_group = if let Some(g) = self.config.get_user_group() { + let g_type = g.r#type(); + if let UserGroup::Facts(g) = g { + Some(g) + } else { + warn!( + "server {}: user group {}(type {g_type}) ignored", + self.config.name(), + self.config.user_group + ); + None + } + } else { + None + }; + self.user_group.store(user_group); + } fn _update_audit_handle_in_place(&self) -> anyhow::Result<()> { let audit_handle = self.config.get_audit_handle()?; @@ -368,7 +421,7 @@ impl Server for TlsStreamServer { } fn user_group(&self) -> &NodeName { - Default::default() + self.config.user_group() } fn auditor(&self) -> &NodeName { diff --git a/g3proxy/src/serve/tls_stream/task.rs b/g3proxy/src/serve/tls_stream/task.rs index bbadc3c4..af3a9473 100644 --- a/g3proxy/src/serve/tls_stream/task.rs +++ b/g3proxy/src/serve/tls_stream/task.rs @@ -3,6 +3,7 @@ * Copyright 2023-2025 ByteDance and/or its affiliates. */ +use std::borrow::Cow; use std::sync::Arc; use std::time::Duration; @@ -13,16 +14,21 @@ use tokio_rustls::server::TlsStream; use g3_daemon::server::ServerQuitPolicy; use g3_daemon::stat::task::TcpStreamTaskStats; use g3_io_ext::{AsyncStream, IdleInterval, LimitedReader, LimitedWriter, StreamCopyConfig}; +use g3_types::acl::AclAction; use g3_types::net::UpstreamAddr; use super::common::CommonTaskContext; use crate::audit::AuditContext; use crate::auth::User; +use crate::config::server::ServerConfig; use crate::inspect::{StreamInspectContext, StreamTransitTask}; use crate::log::task::tcp_connect::TaskLogForTcpConnect; use crate::module::tcp_connect::{TcpConnectTaskConf, TcpConnectTaskNotes, TlsConnectTaskConf}; use crate::serve::tcp_stream::{TcpStreamServerAliveTaskGuard, TcpStreamTaskCltWrapperStats}; -use crate::serve::{ServerTaskError, ServerTaskNotes, ServerTaskResult, ServerTaskStage}; +use crate::serve::{ + ServerStats, ServerTaskError, ServerTaskForbiddenError, ServerTaskNotes, ServerTaskResult, + ServerTaskStage, +}; pub(super) struct TlsStreamTask { ctx: CommonTaskContext, @@ -31,16 +37,26 @@ pub(super) struct TlsStreamTask { task_notes: ServerTaskNotes, task_stats: Arc, audit_ctx: AuditContext, + started: bool, _alive_guard: Option, } +impl Drop for TlsStreamTask { + fn drop(&mut self) { + if self.started { + self.post_stop(); + self.started = false; + } + } +} + impl TlsStreamTask { pub(super) fn new( ctx: CommonTaskContext, upstream: &UpstreamAddr, audit_ctx: AuditContext, + task_notes: ServerTaskNotes, ) -> Self { - let task_notes = ServerTaskNotes::new(ctx.cc_info.clone(), None, Duration::ZERO); TlsStreamTask { ctx, upstream: upstream.clone(), @@ -48,6 +64,7 @@ impl TlsStreamTask { task_notes, task_stats: Arc::new(TcpStreamTaskStats::default()), audit_ctx, + started: false, _alive_guard: None, } } @@ -82,18 +99,93 @@ impl TlsStreamTask { fn pre_start(&mut self) { self._alive_guard = Some(self.ctx.server_stats.add_task()); + if let Some(user_ctx) = self.task_notes.user_ctx() { + user_ctx.foreach_req_stats(|s| { + s.req_total.add_tcp_connect(); + s.req_alive.add_tcp_connect(); + }); + } + if self.ctx.server_config.flush_task_log_on_created && let Some(log_ctx) = self.get_log_context() { log_ctx.log_created(); } + + self.started = true; + } + + fn post_stop(&mut self) { + if let Some(user_ctx) = self.task_notes.user_ctx() { + user_ctx.foreach_req_stats(|s| { + s.req_alive.del_tcp_connect(); + }); + + if let Some(user_req_alive_permit) = self.task_notes.user_req_alive_permit.take() { + drop(user_req_alive_permit); + } + } + } + + async fn handle_user_upstream_acl_action(&mut self, action: AclAction) -> ServerTaskResult<()> { + let forbid = match action { + AclAction::Permit => false, + AclAction::PermitAndLog => { + // TODO log permit + false + } + AclAction::Forbid => true, + AclAction::ForbidAndLog => { + // TODO log forbid + true + } + }; + if forbid { + Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::DestDenied, + )) + } else { + Ok(()) + } } async fn run(&mut self, clt_stream: TlsStream) -> ServerTaskResult<()> { + let tcp_client_misc_opts; + + if let Some(user_ctx) = self.task_notes.user_ctx() { + let user_ctx = user_ctx.clone(); + + if user_ctx.check_rate_limit().is_err() { + return Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::RateLimited, + )); + } + + match user_ctx.acquire_request_semaphore() { + Ok(permit) => self.task_notes.user_req_alive_permit = Some(permit), + Err(_) => { + return Err(ServerTaskError::ForbiddenByRule( + ServerTaskForbiddenError::FullyLoaded, + )); + } + } + + let action = user_ctx.check_upstream(&self.upstream); + self.handle_user_upstream_acl_action(action).await?; + + let user_config = user_ctx.user_config(); + + tcp_client_misc_opts = + user_config.tcp_client_misc_opts(&self.ctx.server_config.tcp_misc_opts); + // + } else { + tcp_client_misc_opts = Cow::Borrowed(&self.ctx.server_config.tcp_misc_opts); + } + // set client side socket options self.ctx .cc_info - .tcp_sock_set_raw_opts(&self.ctx.server_config.tcp_misc_opts, true) + .tcp_sock_set_raw_opts(&tcp_client_misc_opts, true) .map_err(|_| { ServerTaskError::InternalServerError("failed to set client socket options") })?; @@ -175,28 +267,42 @@ impl TlsStreamTask { let (clt_r, clt_w) = self.split_clt(clt_stream); if let Some(audit_handle) = self.audit_ctx.check_take_handle() { - let ctx = StreamInspectContext::new( - audit_handle, - self.ctx.server_config.clone(), - self.ctx.server_stats.clone(), - self.ctx.server_quit_policy.clone(), - self.ctx.idle_wheel.clone(), - &self.task_notes, - &self.tcp_notes, - ); - crate::inspect::stream::transit_with_inspection( - clt_r, - clt_w, - ups_r, - ups_w, - ctx, - self.upstream.clone(), - None, - ) - .await - } else { - self.transit_transparent(clt_r, clt_w, ups_r, ups_w).await + let audit_task = self + .task_notes + .user_ctx() + .map(|ctx| { + let user_config = &ctx.user_config().audit; + user_config.enable_protocol_inspection + && user_config + .do_task_audit() + .unwrap_or_else(|| audit_handle.do_task_audit()) + }) + .unwrap_or_else(|| audit_handle.do_task_audit()); + + if audit_task { + let ctx = StreamInspectContext::new( + audit_handle, + self.ctx.server_config.clone(), + self.ctx.server_stats.clone(), + self.ctx.server_quit_policy.clone(), + self.ctx.idle_wheel.clone(), + &self.task_notes, + &self.tcp_notes, + ); + return crate::inspect::stream::transit_with_inspection( + clt_r, + clt_w, + ups_r, + ups_w, + ctx, + self.upstream.clone(), + None, + ) + .await; + } } + + self.transit_transparent(clt_r, clt_w, ups_r, ups_w).await } fn split_clt( @@ -208,24 +314,47 @@ impl TlsStreamTask { ) { let (clt_r, clt_w) = clt_stream.into_split(); - let wrapper_stats = + let mut wrapper_stats = TcpStreamTaskCltWrapperStats::new(&self.ctx.server_stats, &self.task_stats); - let wrapper_stats = Arc::new(wrapper_stats); - let clt_speed_limit = &self.ctx.server_config.tcp_sock_speed_limit; - let clt_r = LimitedReader::local_limited( + let limit_config = if let Some(user_ctx) = self.task_notes.user_ctx() { + wrapper_stats.push_user_io_stats(user_ctx.fetch_traffic_stats( + self.ctx.server_config.name(), + self.ctx.server_stats.share_extra_tags(), + )); + + user_ctx + .user_config() + .tcp_sock_speed_limit + .shrink_as_smaller(&self.ctx.server_config.tcp_sock_speed_limit) + } else { + self.ctx.server_config.tcp_sock_speed_limit + }; + + let wrapper_stats = Arc::new(wrapper_stats); + let mut clt_r = LimitedReader::local_limited( clt_r, - clt_speed_limit.shift_millis, - clt_speed_limit.max_north, + limit_config.shift_millis, + limit_config.max_north, wrapper_stats.clone(), ); - let clt_w = LimitedWriter::local_limited( + let mut clt_w = LimitedWriter::local_limited( clt_w, - clt_speed_limit.shift_millis, - clt_speed_limit.max_south, + limit_config.shift_millis, + limit_config.max_south, wrapper_stats, ); + if let Some(user_ctx) = self.task_notes.user_ctx() { + let user = user_ctx.user(); + if let Some(limiter) = user.tcp_all_upload_speed_limit() { + clt_r.add_global_limiter(limiter.clone()); + } + if let Some(limiter) = user.tcp_all_download_speed_limit() { + clt_w.add_global_limiter(limiter.clone()); + } + } + (clt_r, clt_w) } } @@ -270,6 +399,6 @@ impl StreamTransitTask for TlsStreamTask { } fn user(&self) -> Option<&User> { - None + self.task_notes.user_ctx().map(|ctx| ctx.user().as_ref()) } } diff --git a/sphinx/g3proxy/configuration/servers/tcp_stream.rst b/sphinx/g3proxy/configuration/servers/tcp_stream.rst index f2a33af5..0a447009 100644 --- a/sphinx/g3proxy/configuration/servers/tcp_stream.rst +++ b/sphinx/g3proxy/configuration/servers/tcp_stream.rst @@ -9,6 +9,13 @@ The following common keys are supported: * :ref:`escaper ` * :ref:`auditor ` +* :ref:`user_group ` + + The user group should be `facts` authenticate type. + It will be used only if `auth_by_client_ip` is set. + + .. versionadded:: 1.13.0 + * :ref:`shared_logger ` * :ref:`listen_in_worker ` * :ref:`tcp_sock_speed_limit ` @@ -77,3 +84,14 @@ Set an explicit tls server name to do upstream tls certificate verification. If not set, the host of upstream address will be used. **default**: not set + +auth_by_client_ip +----------------- + +**optional**, **type**: bool, **conflict**: auth_by_server_ip + +Enable facts user authenticate and use client IP as the authenticate fact. + +**default**: false + +.. versionadded:: 1.13.0 diff --git a/sphinx/g3proxy/configuration/servers/tls_stream.rst b/sphinx/g3proxy/configuration/servers/tls_stream.rst index a8299b9f..440d5274 100644 --- a/sphinx/g3proxy/configuration/servers/tls_stream.rst +++ b/sphinx/g3proxy/configuration/servers/tls_stream.rst @@ -9,6 +9,13 @@ The following common keys are supported: * :ref:`escaper ` * :ref:`auditor ` +* :ref:`user_group ` + + The user group should be `facts` authenticate type. + It will be used only if `auth_by_client_ip` is set. + + .. versionadded:: 1.13.0 + * :ref:`shared_logger ` * :ref:`listen_in_worker ` * :ref:`tls_server ` @@ -82,3 +89,14 @@ Set an explicit tls server name to do upstream tls certificate verification. If not set, the host of upstream address will be used. **default**: not set + +auth_by_client_ip +----------------- + +**optional**, **type**: bool, **conflict**: auth_by_server_ip + +Enable facts user authenticate and use client IP as the authenticate fact. + +**default**: false + +.. versionadded:: 1.13.0