From 015015f4669da5fa3d23772c5cd852d163ebedbc Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang Date: Tue, 19 Dec 2023 11:18:49 +0800 Subject: [PATCH] g3bench: switch to use SslConnector --- Cargo.lock | 8 +-- g3bench/Cargo.toml | 2 +- g3bench/src/target/h1/opts.rs | 47 ++------------ g3bench/src/target/h2/opts.rs | 46 ++----------- g3bench/src/target/keyless/cloudflare/opts.rs | 25 +------- g3bench/src/target/openssl.rs | 35 +++++++++- g3bench/src/target/ssl/opts.rs | 25 +------- lib/g3-fluentd/Cargo.toml | 2 +- lib/g3-fluentd/src/config.rs | 7 +- lib/g3-fluentd/src/lib.rs | 2 +- lib/g3-openssl/src/connect.rs | 64 +++++++++++++++++++ lib/g3-openssl/src/lib.rs | 3 + 12 files changed, 129 insertions(+), 137 deletions(-) create mode 100644 lib/g3-openssl/src/connect.rs diff --git a/Cargo.lock b/Cargo.lock index efb00912..46be1d9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -909,6 +909,7 @@ dependencies = [ "chrono", "digest", "flume", + "g3-openssl", "g3-socket", "g3-types", "hex", @@ -922,7 +923,6 @@ dependencies = [ "slog", "thiserror", "tokio", - "tokio-tongsuo", ] [[package]] @@ -1350,6 +1350,7 @@ dependencies = [ "g3-histogram", "g3-http", "g3-io-ext", + "g3-openssl", "g3-runtime", "g3-signal", "g3-socket", @@ -1376,7 +1377,6 @@ dependencies = [ "rustls-pemfile", "thiserror", "tokio", - "tokio-tongsuo", "tongsuo", "url", ] @@ -3632,9 +3632,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.28" +version = "0.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2" +checksum = "9b5c3db89721d50d0e2a673f5043fc4722f76dcc352d7b1ab8b8288bed4ed2c5" dependencies = [ "memchr", ] diff --git a/g3bench/Cargo.toml b/g3bench/Cargo.toml index 8e4e43c8..d11fbaf1 100644 --- a/g3bench/Cargo.toml +++ b/g3bench/Cargo.toml @@ -25,7 +25,6 @@ quinn = { workspace = true, optional = true, features = ["tls-rustls", "runtime- bytes.workspace = true futures-util.workspace = true atomic-waker.workspace = true -tokio-openssl.workspace = true openssl.workspace = true openssl-probe = { workspace = true, optional = true } rustls = { workspace = true, optional = true } @@ -49,6 +48,7 @@ g3-io-ext.workspace = true g3-statsd-client.workspace = true g3-histogram.workspace = true g3-tls-cert.workspace = true +g3-openssl.workspace = true openssl-async-job.workspace = true [build-dependencies] diff --git a/g3bench/src/target/h1/opts.rs b/g3bench/src/target/h1/opts.rs index 3d917e14..1709dc9d 100644 --- a/g3bench/src/target/h1/opts.rs +++ b/g3bench/src/target/h1/opts.rs @@ -14,23 +14,20 @@ * limitations under the License. */ -use std::borrow::Cow; use std::io; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::str::FromStr; use std::time::Duration; use anyhow::{anyhow, Context}; use clap::{value_parser, Arg, ArgAction, ArgMatches, Command}; use http::{Method, StatusCode}; -use openssl::ssl::SslVerifyMode; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; -use tokio_openssl::SslStream; use url::Url; use g3_io_ext::AggregatedIo; +use g3_openssl::SslStream; use g3_types::collection::{SelectiveVec, WeightedValue}; use g3_types::net::{ HttpAuth, HttpProxy, OpensslClientConfig, OpensslClientConfigBuilder, Proxy, UpstreamAddr, @@ -305,25 +302,10 @@ impl BenchHttpArgs { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let tls_name = self + let tls_stream = self .target_tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| self.host.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, self.host.port()) - .context("failed to build ssl context")?; - if self.target_tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() - .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - + .connect_target(tls_client, stream, &self.host) + .await?; let (r, w) = tokio::io::split(tls_stream); Ok((Box::new(r), Box::new(w))) } @@ -334,26 +316,9 @@ impl BenchHttpArgs { peer: &UpstreamAddr, stream: TcpStream, ) -> anyhow::Result> { - let tls_name = self - .proxy_tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| peer.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, peer.port()) - .context("failed to build ssl context")?; - if self.proxy_tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() + self.proxy_tls + .connect_target(tls_client, stream, peer) .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - - Ok(tls_stream) } fn write_request_line(&self, buf: &mut W) -> io::Result<()> { diff --git a/g3bench/src/target/h2/opts.rs b/g3bench/src/target/h2/opts.rs index f483222f..c6199eda 100644 --- a/g3bench/src/target/h2/opts.rs +++ b/g3bench/src/target/h2/opts.rs @@ -14,9 +14,7 @@ * limitations under the License. */ -use std::borrow::Cow; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -26,13 +24,12 @@ use bytes::Bytes; use clap::{value_parser, Arg, ArgAction, ArgMatches, Command}; use h2::client::SendRequest; use http::{HeaderValue, Method, StatusCode}; -use openssl::ssl::SslVerifyMode; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; use tokio::net::TcpStream; -use tokio_openssl::SslStream; use url::Url; use g3_io_ext::{AggregatedIo, LimitedStream}; +use g3_openssl::SslStream; use g3_types::collection::{SelectiveVec, WeightedValue}; use g3_types::net::{ AlpnProtocol, HttpAuth, OpensslClientConfig, OpensslClientConfigBuilder, Proxy, UpstreamAddr, @@ -309,24 +306,10 @@ impl BenchH2Args { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let tls_name = self + let tls_stream = self .target_tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| self.host.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, self.host.port()) - .context("failed to build ssl context")?; - if self.target_tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() - .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; + .connect_target(tls_client, stream, &self.host) + .await?; if let Some(alpn) = tls_stream.ssl().selected_alpn_protocol() { if AlpnProtocol::from_buf(alpn) != Some(AlpnProtocol::Http2) { return Err(anyhow!("invalid returned alpn protocol: {:?}", alpn)); @@ -341,26 +324,9 @@ impl BenchH2Args { peer: &UpstreamAddr, stream: TcpStream, ) -> anyhow::Result> { - let tls_name = self - .proxy_tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| peer.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, peer.port()) - .context("failed to build ssl context")?; - if self.proxy_tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() + self.proxy_tls + .connect_target(tls_client, stream, peer) .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - - Ok(tls_stream) } pub(super) fn build_pre_request_header(&self) -> anyhow::Result { diff --git a/g3bench/src/target/keyless/cloudflare/opts.rs b/g3bench/src/target/keyless/cloudflare/opts.rs index 414c6aff..0f39630e 100644 --- a/g3bench/src/target/keyless/cloudflare/opts.rs +++ b/g3bench/src/target/keyless/cloudflare/opts.rs @@ -14,18 +14,15 @@ * limitations under the License. */ -use std::borrow::Cow; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::time::Duration; use anyhow::{anyhow, Context}; use clap::{value_parser, Arg, ArgAction, ArgMatches, Command}; -use openssl::ssl::SslVerifyMode; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_openssl::SslStream; +use g3_openssl::SslStream; use g3_types::collection::{SelectiveVec, WeightedValue}; use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, UpstreamAddr}; @@ -160,25 +157,9 @@ impl KeylessCloudflareArgs { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let tls_name = self - .tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| self.target.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, self.target.port()) - .context("failed to build ssl context")?; - if self.tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() + self.tls + .connect_target(tls_client, stream, &self.target) .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Ok(tls_stream) } } diff --git a/g3bench/src/target/openssl.rs b/g3bench/src/target/openssl.rs index b12d66ca..4677e5dd 100644 --- a/g3bench/src/target/openssl.rs +++ b/g3bench/src/target/openssl.rs @@ -14,6 +14,7 @@ * limitations under the License. */ +use std::borrow::Cow; use std::fs::File; use std::io::Read; use std::path::{Path, PathBuf}; @@ -22,11 +23,14 @@ use std::str::FromStr; use anyhow::{anyhow, Context}; use clap::{value_parser, Arg, ArgAction, ArgMatches, Command, ValueHint}; use openssl::pkey::{PKey, Private}; +use openssl::ssl::SslVerifyMode; use openssl::x509::X509; +use tokio::io::{AsyncRead, AsyncWrite}; +use g3_openssl::{SslConnector, SslStream}; use g3_types::net::{ AlpnProtocol, OpensslCertificatePair, OpensslClientConfig, OpensslClientConfigBuilder, - OpensslProtocol, + OpensslProtocol, UpstreamAddr, }; const TLS_ARG_CA_CERT: &str = "tls-ca-cert"; @@ -71,6 +75,35 @@ pub(crate) struct OpensslTlsClientArgs { } impl OpensslTlsClientArgs { + pub(crate) async fn connect_target( + &self, + tls_client: &OpensslClientConfig, + stream: S, + target: &UpstreamAddr, + ) -> anyhow::Result> + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + let tls_name = self + .tls_name + .as_ref() + .map(|v| Cow::Borrowed(v.as_str())) + .unwrap_or_else(|| target.host_str()); + let mut ssl = tls_client + .build_ssl(&tls_name, target.port()) + .context("failed to build ssl context")?; + if self.no_verify { + ssl.set_verify(SslVerifyMode::NONE); + } + let tls_connector = SslConnector::new(ssl, stream) + .map_err(|e| anyhow!("tls connector create failed: {e}"))?; + let tls_stream = tls_connector + .connect() + .await + .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; + Ok(tls_stream) + } + fn parse_tls_name(&mut self, args: &ArgMatches, id: &str) { if let Some(name) = args.get_one::(id) { self.tls_name = Some(name.to_string()); diff --git a/g3bench/src/target/ssl/opts.rs b/g3bench/src/target/ssl/opts.rs index 3f339494..c03f9a05 100644 --- a/g3bench/src/target/ssl/opts.rs +++ b/g3bench/src/target/ssl/opts.rs @@ -14,18 +14,15 @@ * limitations under the License. */ -use std::borrow::Cow; use std::net::{IpAddr, SocketAddr}; -use std::pin::Pin; use std::time::Duration; use anyhow::{anyhow, Context}; use clap::{value_parser, Arg, ArgMatches, Command}; -use openssl::ssl::SslVerifyMode; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_openssl::SslStream; +use g3_openssl::SslStream; use g3_types::collection::{SelectiveVec, WeightedValue}; use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, UpstreamAddr}; @@ -117,25 +114,9 @@ impl BenchSslArgs { where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - let tls_name = self - .tls - .tls_name - .as_ref() - .map(|v| Cow::Borrowed(v.as_str())) - .unwrap_or_else(|| self.target.host_str()); - let mut ssl = tls_client - .build_ssl(&tls_name, self.target.port()) - .context("failed to build ssl context")?; - if self.tls.no_verify { - ssl.set_verify(SslVerifyMode::NONE); - } - let mut tls_stream = SslStream::new(ssl, stream) - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Pin::new(&mut tls_stream) - .connect() + self.tls + .connect_target(tls_client, stream, &self.target) .await - .map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?; - Ok(tls_stream) } } diff --git a/lib/g3-fluentd/Cargo.toml b/lib/g3-fluentd/Cargo.toml index 697ab054..b66d12f4 100644 --- a/lib/g3-fluentd/Cargo.toml +++ b/lib/g3-fluentd/Cargo.toml @@ -17,7 +17,6 @@ flume = { workspace = true, features = ["async"] } rmp.workspace = true rmp-serde.workspace = true serde.workspace = true -tokio-openssl.workspace = true tokio = { workspace = true, features = ["rt", "net", "time", "macros"] } rand.workspace = true digest.workspace = true @@ -26,3 +25,4 @@ hex.workspace = true log.workspace = true g3-types = { workspace = true, features = ["async-log", "openssl"] } g3-socket.workspace = true +g3-openssl.workspace = true diff --git a/lib/g3-fluentd/src/config.rs b/lib/g3-fluentd/src/config.rs index 1dd2816c..0a34998e 100644 --- a/lib/g3-fluentd/src/config.rs +++ b/lib/g3-fluentd/src/config.rs @@ -15,7 +15,6 @@ */ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::pin::Pin; use std::time::Duration; use anyhow::{anyhow, Context}; @@ -24,8 +23,8 @@ use rand::Rng; use rmp::encode::ValueWriteError; use sha2::Sha512; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio_openssl::SslStream; +use g3_openssl::SslConnector; use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, TcpKeepAliveConfig}; use super::FluentdConnection; @@ -155,9 +154,9 @@ impl FluentdClientConfig { let tls = tls_client .build_ssl(&tls_name, self.server_addr.port()) .map_err(|e| anyhow!("failed to build ssl context: {e}"))?; - let mut tls_stream = SslStream::new(tls, tcp_stream) + let tls_connector = SslConnector::new(tls, tcp_stream) .map_err(|e| anyhow!("failed to setup ssl context: {e}"))?; - Pin::new(&mut tls_stream) + let tls_stream = tls_connector .connect() .await .map_err(|e| anyhow!("failed to tls connect to peer {tls_name}: {e}"))?; diff --git a/lib/g3-fluentd/src/lib.rs b/lib/g3-fluentd/src/lib.rs index da32115b..4378597f 100644 --- a/lib/g3-fluentd/src/lib.rs +++ b/lib/g3-fluentd/src/lib.rs @@ -23,8 +23,8 @@ use flume::Receiver; use log::warn; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio_openssl::SslStream; +use g3_openssl::SslStream; use g3_types::log::{AsyncLogConfig, AsyncLogger, LogStats}; mod config; diff --git a/lib/g3-openssl/src/connect.rs b/lib/g3-openssl/src/connect.rs new file mode 100644 index 00000000..ecba9241 --- /dev/null +++ b/lib/g3-openssl/src/connect.rs @@ -0,0 +1,64 @@ +/* + * Copyright 2023 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::future; +use std::task::{Context, Poll}; + +use openssl::error::ErrorStack; +use openssl::ssl::{self, Ssl}; +use openssl_sys::{SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use super::error::{SSL_ERROR_WANT_ASYNC, SSL_ERROR_WANT_ASYNC_JOB}; +use super::{SslIoWrapper, SslStream}; + +pub struct SslConnector { + inner: ssl::SslStream>, +} + +impl SslConnector { + pub fn new(ssl: Ssl, stream: S) -> Result { + let wrapper = SslIoWrapper::new(stream); + ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { inner }) + } +} + +impl SslConnector { + pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.get_mut().set_cx(cx); + + match self.inner.connect() { + Ok(_) => Poll::Ready(Ok(())), + Err(e) => match e.code().as_raw() { + SSL_ERROR_WANT_READ | SSL_ERROR_WANT_WRITE => Poll::Pending, + SSL_ERROR_WANT_ASYNC => { + // TODO + todo!() + } + SSL_ERROR_WANT_ASYNC_JOB => { + cx.waker().wake_by_ref(); + Poll::Pending + } + _ => Poll::Ready(Err(e)), + }, + } + } + + pub async fn connect(mut self) -> Result, ssl::Error> { + future::poll_fn(|cx| self.poll_connect(cx)).await?; + Ok(SslStream::new(self.inner)) + } +} diff --git a/lib/g3-openssl/src/lib.rs b/lib/g3-openssl/src/lib.rs index a65d271e..6e6973d4 100644 --- a/lib/g3-openssl/src/lib.rs +++ b/lib/g3-openssl/src/lib.rs @@ -24,3 +24,6 @@ pub use stream::SslStream; mod accept; pub use accept::SslAcceptor; + +mod connect; +pub use connect::SslConnector;