diff --git a/lib/g3-openssl/build.rs b/lib/g3-openssl/build.rs new file mode 100644 index 00000000..5fc28159 --- /dev/null +++ b/lib/g3-openssl/build.rs @@ -0,0 +1,28 @@ +/* + * Copyright 2023 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::env; + +#[allow(clippy::unusual_byte_groupings)] +fn main() { + if let Ok(version) = env::var("DEP_OPENSSL_VERSION_NUMBER") { + let version = u64::from_str_radix(&version, 16).unwrap(); + + if version >= 0x3_00_00_00_0 { + println!("cargo:rustc-cfg=ossl300"); + } + } +} diff --git a/lib/g3-openssl/src/accept.rs b/lib/g3-openssl/src/accept.rs index bc080135..e01cbcdf 100644 --- a/lib/g3-openssl/src/accept.rs +++ b/lib/g3-openssl/src/accept.rs @@ -15,48 +15,80 @@ */ use std::future; -use std::task::{Context, Poll}; +use std::io; +use std::task::{ready, Context, Poll}; use openssl::error::ErrorStack; use openssl::ssl::{self, ErrorCode, Ssl}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{SslIoWrapper, SslStream}; +use super::{AsyncEnginePoller, SslIoWrapper, SslStream}; pub struct SslAcceptor { inner: ssl::SslStream>, + async_engine: Option, } impl SslAcceptor { + #[cfg(not(ossl300))] pub fn new(ssl: Ssl, stream: S) -> Result { let wrapper = SslIoWrapper::new(stream); - ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor { inner }) + let async_engine = AsyncEnginePoller::new(&ssl); + + ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor { + inner, + async_engine, + }) + } + + #[cfg(ossl300)] + pub fn new(ssl: Ssl, stream: S) -> Result { + let wrapper = SslIoWrapper::new(stream); + let async_engine = AsyncEnginePoller::new(&ssl)?; + + ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor { + inner, + async_engine, + }) } } impl SslAcceptor { - pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.get_mut().set_cx(cx); + #[cfg(ossl300)] + if let Some(async_engine) = &self.async_engine { + async_engine.set_cx(cx); + } - match self.inner.accept() { - Ok(_) => Poll::Ready(Ok(())), - Err(e) => match e.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending, - ErrorCode::WANT_ASYNC => { - // TODO - todo!() - } - ErrorCode::WANT_ASYNC_JOB => { - cx.waker().wake_by_ref(); - Poll::Pending - } - _ => Poll::Ready(Err(e)), - }, + loop { + match self.inner.accept() { + Ok(_) => return Poll::Ready(Ok(())), + Err(e) => match e.code() { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending, + ErrorCode::WANT_ASYNC => { + if let Some(async_engine) = &mut self.async_engine { + ready!(async_engine.poll_ready(self.inner.ssl(), cx))? + } else { + return Poll::Ready(Err(io::Error::other( + "async engine poller is not set", + ))); + } + } + ErrorCode::WANT_ASYNC_JOB => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + _ => { + return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))) + } + }, + } } } - pub async fn accept(mut self) -> Result, ssl::Error> { + pub async fn accept(mut self) -> io::Result> { future::poll_fn(|cx| self.poll_accept(cx)).await?; - Ok(SslStream::new(self.inner)) + Ok(SslStream::new(self.inner, self.async_engine)) } } diff --git a/lib/g3-openssl/src/async_mode.rs b/lib/g3-openssl/src/async_mode.rs new file mode 100644 index 00000000..8365438f --- /dev/null +++ b/lib/g3-openssl/src/async_mode.rs @@ -0,0 +1,227 @@ +/* + * Copyright 2023 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::os::fd::RawFd; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{io, ptr}; + +#[cfg(ossl300)] +use atomic_waker::AtomicWaker; +use libc::c_int; +#[cfg(ossl300)] +use libc::c_void; +use openssl::error::ErrorStack; +use openssl::foreign_types::ForeignTypeRef; +use openssl::ssl::SslRef; +#[cfg(ossl300)] +use openssl_sys::SSL; +use tokio::io::unix::AsyncFd; +use tokio::io::Interest; + +use crate::ffi; + +pub trait SslAsyncModeExt { + fn is_async(&self) -> bool; + fn waiting_for_async(&self) -> bool; + #[cfg(ossl300)] + fn async_status(&self) -> c_int; + #[cfg(ossl300)] + fn set_async_engine_waker(&self, waker: &Arc) -> Result<(), ErrorStack>; + fn get_changed_fds(&self) -> Result<(Vec, Vec), ErrorStack>; +} + +impl SslAsyncModeExt for SslRef { + fn is_async(&self) -> bool { + unsafe { (ffi::SSL_get_mode(self.as_ptr()) & 0x00000100) != 0 } + } + + fn waiting_for_async(&self) -> bool { + unsafe { ffi::SSL_waiting_for_async(self.as_ptr()) == 1 } + } + + #[cfg(ossl300)] + fn async_status(&self) -> c_int { + unsafe { ffi::SSL_get_async_status(self.as_ptr()) } + } + + #[cfg(ossl300)] + fn set_async_engine_waker(&self, waker: &Arc) -> Result<(), ErrorStack> { + let r = unsafe { ffi::SSL_set_async_callback(self.as_ptr(), Some(async_engine_wake)) }; + if r != 1 { + return Err(ErrorStack::get()); + } + + let r = unsafe { + ffi::SSL_set_async_callback_arg(self.as_ptr(), Arc::as_ptr(waker) as *mut c_void) + }; + if r != 1 { + Err(ErrorStack::get()) + } else { + Ok(()) + } + } + + fn get_changed_fds(&self) -> Result<(Vec, Vec), ErrorStack> { + let mut add_fd_count = 0usize; + let mut del_fd_count = 0usize; + let r = unsafe { + ffi::SSL_get_changed_async_fds( + self.as_ptr(), + ptr::null_mut(), + &mut add_fd_count as *mut usize, + ptr::null_mut(), + &mut del_fd_count as *mut usize, + ) + }; + if r != 1 { + return Err(ErrorStack::get()); + } + + let mut add_fds: Vec = vec![0; add_fd_count]; + let mut del_fds: Vec = vec![0; del_fd_count]; + let r = unsafe { + ffi::SSL_get_changed_async_fds( + self.as_ptr(), + add_fds.as_mut_ptr(), + &mut add_fd_count as *mut usize, + del_fds.as_mut_ptr(), + &mut del_fd_count as *mut usize, + ) + }; + if r != 1 { + return Err(ErrorStack::get()); + } + + Ok(( + add_fds.into_iter().map(RawFd::from).collect(), + del_fds.into_iter().map(RawFd::from).collect(), + )) + } +} + +pub(crate) struct AsyncEnginePoller { + tracked_fds: Vec>, + #[cfg(ossl300)] + atomic_waker: Arc, +} + +impl AsyncEnginePoller { + #[cfg(not(ossl300))] + pub(crate) fn new(ssl: &SslRef) -> Option { + if ssl.is_async() { + Some(AsyncEnginePoller { + tracked_fds: Vec::with_capacity(1), + }) + } else { + None + } + } + + #[cfg(ossl300)] + pub(crate) fn new(ssl: &SslRef) -> Result, ErrorStack> { + if !ssl.is_async() { + return Ok(None); + } + + let atomic_waker = Arc::new(AtomicWaker::new()); + ssl.set_async_engine_waker(&atomic_waker)?; + + Ok(Some(AsyncEnginePoller { + tracked_fds: Vec::with_capacity(1), + atomic_waker, + })) + } + + #[cfg(ossl300)] + pub(crate) fn set_cx(&self, cx: &mut Context<'_>) { + self.atomic_waker.register(cx.waker()); + } + + #[cfg(not(ossl300))] + pub(crate) fn poll_ready( + &mut self, + ssl: &SslRef, + cx: &mut Context<'_>, + ) -> Poll> { + let (add, del) = ssl.get_changed_fds().map_err(io::Error::other)?; + for fd in add { + let async_fd = AsyncFd::with_interest(fd, Interest::READABLE)?; + self.tracked_fds.push(async_fd); + } + for fd in del { + self.tracked_fds.retain(|v| fd.ne(v.get_ref())); + } + + for fd in &self.tracked_fds { + match fd.poll_read_ready(cx) { + Poll::Pending => {} + Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + } + } + Poll::Pending + } + + #[cfg(ossl300)] + pub(crate) fn poll_ready( + &mut self, + ssl: &SslRef, + cx: &mut Context<'_>, + ) -> Poll> { + match ssl.async_status() { + ffi::ASYNC_STATUS_UNSUPPORTED => { + let (add, del) = ssl.get_changed_fds().map_err(io::Error::other)?; + for fd in add { + let async_fd = AsyncFd::with_interest(fd, Interest::READABLE)?; + self.tracked_fds.push(async_fd); + } + for fd in del { + self.tracked_fds.retain(|v| fd.ne(v.get_ref())); + } + + for fd in &self.tracked_fds { + match fd.poll_read_ready(cx) { + Poll::Pending => {} + Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + } + } + Poll::Pending + } + ffi::ASYNC_STATUS_ERR => Poll::Ready(Err(io::Error::other(ErrorStack::get()))), + ffi::ASYNC_STATUS_OK => { + // submitted, wait for the callback + Poll::Pending + } + ffi::ASYNC_STATUS_EAGAIN => { + // engine busy, resume later + cx.waker().wake_by_ref(); + Poll::Pending + } + r => Poll::Ready(Err(io::Error::other(format!( + "SSL_get_async_status returned {r}" + )))), + } + } +} + +#[cfg(ossl300)] +extern "C" fn async_engine_wake(_ssl: *mut SSL, arg: *mut c_void) -> c_int { + let waker = unsafe { &*(arg as *const AtomicWaker) }; + waker.wake(); + 0 +} diff --git a/lib/g3-openssl/src/connect.rs b/lib/g3-openssl/src/connect.rs index f6996cfd..0dde8656 100644 --- a/lib/g3-openssl/src/connect.rs +++ b/lib/g3-openssl/src/connect.rs @@ -15,48 +15,80 @@ */ use std::future; -use std::task::{Context, Poll}; +use std::io; +use std::task::{ready, Context, Poll}; use openssl::error::ErrorStack; use openssl::ssl::{self, ErrorCode, Ssl}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{SslIoWrapper, SslStream}; +use super::{AsyncEnginePoller, SslIoWrapper, SslStream}; pub struct SslConnector { inner: ssl::SslStream>, + async_engine: Option, } impl SslConnector { + #[cfg(not(ossl300))] pub fn new(ssl: Ssl, stream: S) -> Result { let wrapper = SslIoWrapper::new(stream); - ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { inner }) + let async_engine = AsyncEnginePoller::new(&ssl); + + ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { + inner, + async_engine, + }) + } + + #[cfg(ossl300)] + pub fn new(ssl: Ssl, stream: S) -> Result { + let wrapper = SslIoWrapper::new(stream); + let async_engine = AsyncEnginePoller::new(&ssl)?; + + ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { + inner, + async_engine, + }) } } impl SslConnector { - pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll> { + pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.get_mut().set_cx(cx); + #[cfg(ossl300)] + if let Some(async_engine) = &self.async_engine { + async_engine.set_cx(cx); + } - match self.inner.connect() { - Ok(_) => Poll::Ready(Ok(())), - Err(e) => match e.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending, - ErrorCode::WANT_ASYNC => { - // TODO - todo!() - } - ErrorCode::WANT_ASYNC_JOB => { - cx.waker().wake_by_ref(); - Poll::Pending - } - _ => Poll::Ready(Err(e)), - }, + loop { + match self.inner.connect() { + Ok(_) => return Poll::Ready(Ok(())), + Err(e) => match e.code() { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending, + ErrorCode::WANT_ASYNC => { + if let Some(async_engine) = &mut self.async_engine { + ready!(async_engine.poll_ready(self.inner.ssl(), cx))? + } else { + return Poll::Ready(Err(io::Error::other( + "async engine poller is not set", + ))); + } + } + ErrorCode::WANT_ASYNC_JOB => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + _ => { + return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))) + } + }, + } } } - pub async fn connect(mut self) -> Result, ssl::Error> { + pub async fn connect(mut self) -> io::Result> { future::poll_fn(|cx| self.poll_connect(cx)).await?; - Ok(SslStream::new(self.inner)) + Ok(SslStream::new(self.inner, None)) } } diff --git a/lib/g3-openssl/src/ffi.rs b/lib/g3-openssl/src/ffi.rs new file mode 100644 index 00000000..1d8f32f2 --- /dev/null +++ b/lib/g3-openssl/src/ffi.rs @@ -0,0 +1,59 @@ +/* + * Copyright 2023 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#![allow(unused)] + +use libc::{c_int, c_long, c_void}; +use openssl_sys::{SSL_ctrl, SSL, SSL_CTRL_MODE, SSL_CTX}; +use std::ptr; + +pub const ASYNC_STATUS_UNSUPPORTED: c_int = 0; +pub const ASYNC_STATUS_ERR: c_int = 1; +pub const ASYNC_STATUS_OK: c_int = 2; +pub const ASYNC_STATUS_EAGAIN: c_int = 3; + +#[allow(non_camel_case_types)] +#[cfg(ossl300)] +pub type SSL_async_callback_fn = + Option c_int>; + +extern "C" { + pub fn SSL_waiting_for_async(s: *mut SSL) -> c_int; + pub fn SSL_get_all_async_fds(s: *mut SSL, fd: *mut c_int, numfds: *mut usize) -> c_int; + pub fn SSL_get_changed_async_fds( + s: *mut SSL, + addfd: *mut c_int, + numaddfds: *mut usize, + delfd: *mut c_int, + numdelfds: *mut usize, + ) -> c_int; + #[cfg(ossl300)] + pub fn SSL_CTX_set_async_callback(ctx: *mut SSL_CTX, callback: SSL_async_callback_fn) -> c_int; + #[cfg(ossl300)] + pub fn SSL_CTX_set_async_callback_arg(ctx: *mut SSL_CTX, arg: *mut c_void) -> c_int; + #[cfg(ossl300)] + pub fn SSL_set_async_callback(s: *mut SSL, callback: SSL_async_callback_fn) -> c_int; + #[cfg(ossl300)] + pub fn SSL_set_async_callback_arg(s: *mut SSL, arg: *mut c_void) -> c_int; + #[cfg(ossl300)] + pub fn SSL_get_async_status(s: *mut SSL) -> c_int; + +} + +#[allow(non_snake_case)] +pub unsafe fn SSL_get_mode(ctx: *mut SSL) -> c_long { + SSL_ctrl(ctx, SSL_CTRL_MODE, 0, ptr::null_mut()) +} diff --git a/lib/g3-openssl/src/lib.rs b/lib/g3-openssl/src/lib.rs index cee2b383..ac45d25d 100644 --- a/lib/g3-openssl/src/lib.rs +++ b/lib/g3-openssl/src/lib.rs @@ -14,9 +14,14 @@ * limitations under the License. */ +mod ffi; + mod wrapper; use wrapper::SslIoWrapper; +mod async_mode; +use async_mode::AsyncEnginePoller; + mod stream; pub use stream::SslStream; diff --git a/lib/g3-openssl/src/stream.rs b/lib/g3-openssl/src/stream.rs index 2544a94d..904a04fe 100644 --- a/lib/g3-openssl/src/stream.rs +++ b/lib/g3-openssl/src/stream.rs @@ -16,21 +16,28 @@ use std::io; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use openssl::ssl::{self, ErrorCode, SslRef}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use super::SslIoWrapper; +use super::{AsyncEnginePoller, SslIoWrapper}; pub struct SslStream { inner: ssl::SslStream>, + async_engine: Option, } impl SslStream { #[inline] - pub(crate) fn new(inner: ssl::SslStream>) -> Self { - SslStream { inner } + pub(crate) fn new( + inner: ssl::SslStream>, + async_engine: Option, + ) -> Self { + SslStream { + inner, + async_engine, + } } #[inline] @@ -42,15 +49,31 @@ impl SslStream { pub fn get_mut(&mut self) -> &mut S { self.inner.get_mut().get_mut() } + + fn set_cx(&mut self, cx: &mut Context<'_>) { + self.inner.get_mut().set_cx(cx); + #[cfg(ossl300)] + if let Some(async_engine) = &self.async_engine { + async_engine.set_cx(cx); + } + } + + fn poll_async_engine(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(async_engine) = &mut self.async_engine { + async_engine.poll_ready(self.inner.ssl(), cx) + } else { + Poll::Ready(Err(io::Error::other("async engine poller is not set"))) + } + } } -impl AsyncRead for SslStream { - fn poll_read( - mut self: Pin<&mut Self>, +impl SslStream { + fn poll_read_unpin( + &mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.inner.get_mut().set_cx(cx); + self.set_cx(cx); loop { match self.inner.ssl_read_uninit(unsafe { buf.unfilled_mut() }) { @@ -69,10 +92,7 @@ impl AsyncRead for SslStream { } } ErrorCode::WANT_WRITE => return Poll::Pending, - ErrorCode::WANT_ASYNC => { - // TODO - todo!() - } + ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?, ErrorCode::WANT_ASYNC_JOB => { cx.waker().wake_by_ref(); return Poll::Pending; @@ -84,15 +104,9 @@ impl AsyncRead for SslStream { } } } -} -impl AsyncWrite for SslStream { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.inner.get_mut().set_cx(cx); + fn poll_write_unpin(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.set_cx(cx); loop { match self.inner.ssl_write(buf) { @@ -106,10 +120,7 @@ impl AsyncWrite for SslStream { } } ErrorCode::WANT_WRITE => return Poll::Pending, - ErrorCode::WANT_ASYNC => { - // TODO - todo!() - } + ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?, ErrorCode::WANT_ASYNC_JOB => { cx.waker().wake_by_ref(); return Poll::Pending; @@ -122,27 +133,24 @@ impl AsyncWrite for SslStream { } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.get_mut().get_pin_mut().poll_flush(cx) - } + fn poll_shutdown_unpin(&mut self, cx: &mut Context<'_>) -> Poll> { + self.set_cx(cx); - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.inner.get_mut().set_cx(cx); - - if let Err(e) = self.inner.shutdown() { - match e.code() { - ErrorCode::ZERO_RETURN => {} - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending, - ErrorCode::WANT_ASYNC => { - // TODO - todo!() - } - ErrorCode::WANT_ASYNC_JOB => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - _ => { - return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other))); + loop { + if let Err(e) = self.inner.shutdown() { + match e.code() { + ErrorCode::ZERO_RETURN => break, + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending, + ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?, + ErrorCode::WANT_ASYNC_JOB => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + _ => { + return Poll::Ready(Err(e + .into_io_error() + .unwrap_or_else(io::Error::other))); + } } } } @@ -150,3 +158,31 @@ impl AsyncWrite for SslStream { self.inner.get_mut().get_pin_mut().poll_shutdown(cx) } } + +impl AsyncRead for SslStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.as_mut().get_mut().poll_read_unpin(cx, buf) + } +} + +impl AsyncWrite for SslStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.as_mut().get_mut().poll_write_unpin(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.get_mut().get_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().get_mut().poll_shutdown_unpin(cx) + } +}