g3bench: switch to use SslConnector

This commit is contained in:
Zhang Jingqiang 2023-12-19 11:18:49 +08:00
parent 9140f2b01a
commit 015015f466
12 changed files with 129 additions and 137 deletions

8
Cargo.lock generated
View file

@ -909,6 +909,7 @@ dependencies = [
"chrono",
"digest",
"flume",
"g3-openssl",
"g3-socket",
"g3-types",
"hex",
@ -922,7 +923,6 @@ dependencies = [
"slog",
"thiserror",
"tokio",
"tokio-tongsuo",
]
[[package]]
@ -1350,6 +1350,7 @@ dependencies = [
"g3-histogram",
"g3-http",
"g3-io-ext",
"g3-openssl",
"g3-runtime",
"g3-signal",
"g3-socket",
@ -1376,7 +1377,6 @@ dependencies = [
"rustls-pemfile",
"thiserror",
"tokio",
"tokio-tongsuo",
"tongsuo",
"url",
]
@ -3632,9 +3632,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
[[package]]
name = "winnow"
version = "0.5.28"
version = "0.5.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2"
checksum = "9b5c3db89721d50d0e2a673f5043fc4722f76dcc352d7b1ab8b8288bed4ed2c5"
dependencies = [
"memchr",
]

View file

@ -25,7 +25,6 @@ quinn = { workspace = true, optional = true, features = ["tls-rustls", "runtime-
bytes.workspace = true
futures-util.workspace = true
atomic-waker.workspace = true
tokio-openssl.workspace = true
openssl.workspace = true
openssl-probe = { workspace = true, optional = true }
rustls = { workspace = true, optional = true }
@ -49,6 +48,7 @@ g3-io-ext.workspace = true
g3-statsd-client.workspace = true
g3-histogram.workspace = true
g3-tls-cert.workspace = true
g3-openssl.workspace = true
openssl-async-job.workspace = true
[build-dependencies]

View file

@ -14,23 +14,20 @@
* limitations under the License.
*/
use std::borrow::Cow;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::str::FromStr;
use std::time::Duration;
use anyhow::{anyhow, Context};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use http::{Method, StatusCode};
use openssl::ssl::SslVerifyMode;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
use url::Url;
use g3_io_ext::AggregatedIo;
use g3_openssl::SslStream;
use g3_types::collection::{SelectiveVec, WeightedValue};
use g3_types::net::{
HttpAuth, HttpProxy, OpensslClientConfig, OpensslClientConfigBuilder, Proxy, UpstreamAddr,
@ -305,25 +302,10 @@ impl BenchHttpArgs {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let tls_name = self
let tls_stream = self
.target_tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| self.host.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, self.host.port())
.context("failed to build ssl context")?;
if self.target_tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
.connect_target(tls_client, stream, &self.host)
.await?;
let (r, w) = tokio::io::split(tls_stream);
Ok((Box::new(r), Box::new(w)))
}
@ -334,26 +316,9 @@ impl BenchHttpArgs {
peer: &UpstreamAddr,
stream: TcpStream,
) -> anyhow::Result<SslStream<TcpStream>> {
let tls_name = self
.proxy_tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| peer.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, peer.port())
.context("failed to build ssl context")?;
if self.proxy_tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
self.proxy_tls
.connect_target(tls_client, stream, peer)
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Ok(tls_stream)
}
fn write_request_line<W: io::Write>(&self, buf: &mut W) -> io::Result<()> {

View file

@ -14,9 +14,7 @@
* limitations under the License.
*/
use std::borrow::Cow;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
@ -26,13 +24,12 @@ use bytes::Bytes;
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use h2::client::SendRequest;
use http::{HeaderValue, Method, StatusCode};
use openssl::ssl::SslVerifyMode;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
use url::Url;
use g3_io_ext::{AggregatedIo, LimitedStream};
use g3_openssl::SslStream;
use g3_types::collection::{SelectiveVec, WeightedValue};
use g3_types::net::{
AlpnProtocol, HttpAuth, OpensslClientConfig, OpensslClientConfigBuilder, Proxy, UpstreamAddr,
@ -309,24 +306,10 @@ impl BenchH2Args {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let tls_name = self
let tls_stream = self
.target_tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| self.host.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, self.host.port())
.context("failed to build ssl context")?;
if self.target_tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
.connect_target(tls_client, stream, &self.host)
.await?;
if let Some(alpn) = tls_stream.ssl().selected_alpn_protocol() {
if AlpnProtocol::from_buf(alpn) != Some(AlpnProtocol::Http2) {
return Err(anyhow!("invalid returned alpn protocol: {:?}", alpn));
@ -341,26 +324,9 @@ impl BenchH2Args {
peer: &UpstreamAddr,
stream: TcpStream,
) -> anyhow::Result<SslStream<TcpStream>> {
let tls_name = self
.proxy_tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| peer.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, peer.port())
.context("failed to build ssl context")?;
if self.proxy_tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
self.proxy_tls
.connect_target(tls_client, stream, peer)
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Ok(tls_stream)
}
pub(super) fn build_pre_request_header(&self) -> anyhow::Result<H2PreRequest> {

View file

@ -14,18 +14,15 @@
* limitations under the License.
*/
use std::borrow::Cow;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::time::Duration;
use anyhow::{anyhow, Context};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use openssl::ssl::SslVerifyMode;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
use g3_openssl::SslStream;
use g3_types::collection::{SelectiveVec, WeightedValue};
use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, UpstreamAddr};
@ -160,25 +157,9 @@ impl KeylessCloudflareArgs {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let tls_name = self
.tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| self.target.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, self.target.port())
.context("failed to build ssl context")?;
if self.tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
self.tls
.connect_target(tls_client, stream, &self.target)
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Ok(tls_stream)
}
}

View file

@ -14,6 +14,7 @@
* limitations under the License.
*/
use std::borrow::Cow;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
@ -22,11 +23,14 @@ use std::str::FromStr;
use anyhow::{anyhow, Context};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command, ValueHint};
use openssl::pkey::{PKey, Private};
use openssl::ssl::SslVerifyMode;
use openssl::x509::X509;
use tokio::io::{AsyncRead, AsyncWrite};
use g3_openssl::{SslConnector, SslStream};
use g3_types::net::{
AlpnProtocol, OpensslCertificatePair, OpensslClientConfig, OpensslClientConfigBuilder,
OpensslProtocol,
OpensslProtocol, UpstreamAddr,
};
const TLS_ARG_CA_CERT: &str = "tls-ca-cert";
@ -71,6 +75,35 @@ pub(crate) struct OpensslTlsClientArgs {
}
impl OpensslTlsClientArgs {
pub(crate) async fn connect_target<S>(
&self,
tls_client: &OpensslClientConfig,
stream: S,
target: &UpstreamAddr,
) -> anyhow::Result<SslStream<S>>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let tls_name = self
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| target.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, target.port())
.context("failed to build ssl context")?;
if self.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let tls_connector = SslConnector::new(ssl, stream)
.map_err(|e| anyhow!("tls connector create failed: {e}"))?;
let tls_stream = tls_connector
.connect()
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Ok(tls_stream)
}
fn parse_tls_name(&mut self, args: &ArgMatches, id: &str) {
if let Some(name) = args.get_one::<String>(id) {
self.tls_name = Some(name.to_string());

View file

@ -14,18 +14,15 @@
* limitations under the License.
*/
use std::borrow::Cow;
use std::net::{IpAddr, SocketAddr};
use std::pin::Pin;
use std::time::Duration;
use anyhow::{anyhow, Context};
use clap::{value_parser, Arg, ArgMatches, Command};
use openssl::ssl::SslVerifyMode;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
use g3_openssl::SslStream;
use g3_types::collection::{SelectiveVec, WeightedValue};
use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, UpstreamAddr};
@ -117,25 +114,9 @@ impl BenchSslArgs {
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let tls_name = self
.tls
.tls_name
.as_ref()
.map(|v| Cow::Borrowed(v.as_str()))
.unwrap_or_else(|| self.target.host_str());
let mut ssl = tls_client
.build_ssl(&tls_name, self.target.port())
.context("failed to build ssl context")?;
if self.tls.no_verify {
ssl.set_verify(SslVerifyMode::NONE);
}
let mut tls_stream = SslStream::new(ssl, stream)
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Pin::new(&mut tls_stream)
.connect()
self.tls
.connect_target(tls_client, stream, &self.target)
.await
.map_err(|e| anyhow!("tls connect to {tls_name} failed: {e}"))?;
Ok(tls_stream)
}
}

View file

@ -17,7 +17,6 @@ flume = { workspace = true, features = ["async"] }
rmp.workspace = true
rmp-serde.workspace = true
serde.workspace = true
tokio-openssl.workspace = true
tokio = { workspace = true, features = ["rt", "net", "time", "macros"] }
rand.workspace = true
digest.workspace = true
@ -26,3 +25,4 @@ hex.workspace = true
log.workspace = true
g3-types = { workspace = true, features = ["async-log", "openssl"] }
g3-socket.workspace = true
g3-openssl.workspace = true

View file

@ -15,7 +15,6 @@
*/
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::pin::Pin;
use std::time::Duration;
use anyhow::{anyhow, Context};
@ -24,8 +23,8 @@ use rand::Rng;
use rmp::encode::ValueWriteError;
use sha2::Sha512;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_openssl::SslStream;
use g3_openssl::SslConnector;
use g3_types::net::{OpensslClientConfig, OpensslClientConfigBuilder, TcpKeepAliveConfig};
use super::FluentdConnection;
@ -155,9 +154,9 @@ impl FluentdClientConfig {
let tls = tls_client
.build_ssl(&tls_name, self.server_addr.port())
.map_err(|e| anyhow!("failed to build ssl context: {e}"))?;
let mut tls_stream = SslStream::new(tls, tcp_stream)
let tls_connector = SslConnector::new(tls, tcp_stream)
.map_err(|e| anyhow!("failed to setup ssl context: {e}"))?;
Pin::new(&mut tls_stream)
let tls_stream = tls_connector
.connect()
.await
.map_err(|e| anyhow!("failed to tls connect to peer {tls_name}: {e}"))?;

View file

@ -23,8 +23,8 @@ use flume::Receiver;
use log::warn;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
use g3_openssl::SslStream;
use g3_types::log::{AsyncLogConfig, AsyncLogger, LogStats};
mod config;

View file

@ -0,0 +1,64 @@
/*
* Copyright 2023 ByteDance and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::future;
use std::task::{Context, Poll};
use openssl::error::ErrorStack;
use openssl::ssl::{self, Ssl};
use openssl_sys::{SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE};
use tokio::io::{AsyncRead, AsyncWrite};
use super::error::{SSL_ERROR_WANT_ASYNC, SSL_ERROR_WANT_ASYNC_JOB};
use super::{SslIoWrapper, SslStream};
pub struct SslConnector<S> {
inner: ssl::SslStream<SslIoWrapper<S>>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> SslConnector<S> {
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
let wrapper = SslIoWrapper::new(stream);
ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { inner })
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> SslConnector<S> {
pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
self.inner.get_mut().set_cx(cx);
match self.inner.connect() {
Ok(_) => Poll::Ready(Ok(())),
Err(e) => match e.code().as_raw() {
SSL_ERROR_WANT_READ | SSL_ERROR_WANT_WRITE => Poll::Pending,
SSL_ERROR_WANT_ASYNC => {
// TODO
todo!()
}
SSL_ERROR_WANT_ASYNC_JOB => {
cx.waker().wake_by_ref();
Poll::Pending
}
_ => Poll::Ready(Err(e)),
},
}
}
pub async fn connect(mut self) -> Result<SslStream<S>, ssl::Error> {
future::poll_fn(|cx| self.poll_connect(cx)).await?;
Ok(SslStream::new(self.inner))
}
}

View file

@ -24,3 +24,6 @@ pub use stream::SslStream;
mod accept;
pub use accept::SslAcceptor;
mod connect;
pub use connect::SslConnector;