mirror of
https://github.com/bytedance/g3.git
synced 2026-05-22 03:03:39 +00:00
g3-openssl: add async engine poller
This commit is contained in:
parent
5e5a9da481
commit
7922be6314
7 changed files with 503 additions and 84 deletions
28
lib/g3-openssl/build.rs
Normal file
28
lib/g3-openssl/build.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Copyright 2023 ByteDance and/or its affiliates.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
use std::env;
|
||||
|
||||
#[allow(clippy::unusual_byte_groupings)]
|
||||
fn main() {
|
||||
if let Ok(version) = env::var("DEP_OPENSSL_VERSION_NUMBER") {
|
||||
let version = u64::from_str_radix(&version, 16).unwrap();
|
||||
|
||||
if version >= 0x3_00_00_00_0 {
|
||||
println!("cargo:rustc-cfg=ossl300");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -15,48 +15,80 @@
|
|||
*/
|
||||
|
||||
use std::future;
|
||||
use std::task::{Context, Poll};
|
||||
use std::io;
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
use openssl::error::ErrorStack;
|
||||
use openssl::ssl::{self, ErrorCode, Ssl};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::{SslIoWrapper, SslStream};
|
||||
use super::{AsyncEnginePoller, SslIoWrapper, SslStream};
|
||||
|
||||
pub struct SslAcceptor<S> {
|
||||
inner: ssl::SslStream<SslIoWrapper<S>>,
|
||||
async_engine: Option<AsyncEnginePoller>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SslAcceptor<S> {
|
||||
#[cfg(not(ossl300))]
|
||||
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
|
||||
let wrapper = SslIoWrapper::new(stream);
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor { inner })
|
||||
let async_engine = AsyncEnginePoller::new(&ssl);
|
||||
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor {
|
||||
inner,
|
||||
async_engine,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
|
||||
let wrapper = SslIoWrapper::new(stream);
|
||||
let async_engine = AsyncEnginePoller::new(&ssl)?;
|
||||
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslAcceptor {
|
||||
inner,
|
||||
async_engine,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SslAcceptor<S> {
|
||||
pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
|
||||
pub fn poll_accept(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
#[cfg(ossl300)]
|
||||
if let Some(async_engine) = &self.async_engine {
|
||||
async_engine.set_cx(cx);
|
||||
}
|
||||
|
||||
match self.inner.accept() {
|
||||
Ok(_) => Poll::Ready(Ok(())),
|
||||
Err(e) => match e.code() {
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
_ => Poll::Ready(Err(e)),
|
||||
},
|
||||
loop {
|
||||
match self.inner.accept() {
|
||||
Ok(_) => return Poll::Ready(Ok(())),
|
||||
Err(e) => match e.code() {
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
if let Some(async_engine) = &mut self.async_engine {
|
||||
ready!(async_engine.poll_ready(self.inner.ssl(), cx))?
|
||||
} else {
|
||||
return Poll::Ready(Err(io::Error::other(
|
||||
"async engine poller is not set",
|
||||
)));
|
||||
}
|
||||
}
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
_ => {
|
||||
return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn accept(mut self) -> Result<SslStream<S>, ssl::Error> {
|
||||
pub async fn accept(mut self) -> io::Result<SslStream<S>> {
|
||||
future::poll_fn(|cx| self.poll_accept(cx)).await?;
|
||||
Ok(SslStream::new(self.inner))
|
||||
Ok(SslStream::new(self.inner, self.async_engine))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
227
lib/g3-openssl/src/async_mode.rs
Normal file
227
lib/g3-openssl/src/async_mode.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
/*
|
||||
* Copyright 2023 ByteDance and/or its affiliates.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
use std::os::fd::RawFd;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::{io, ptr};
|
||||
|
||||
#[cfg(ossl300)]
|
||||
use atomic_waker::AtomicWaker;
|
||||
use libc::c_int;
|
||||
#[cfg(ossl300)]
|
||||
use libc::c_void;
|
||||
use openssl::error::ErrorStack;
|
||||
use openssl::foreign_types::ForeignTypeRef;
|
||||
use openssl::ssl::SslRef;
|
||||
#[cfg(ossl300)]
|
||||
use openssl_sys::SSL;
|
||||
use tokio::io::unix::AsyncFd;
|
||||
use tokio::io::Interest;
|
||||
|
||||
use crate::ffi;
|
||||
|
||||
pub trait SslAsyncModeExt {
|
||||
fn is_async(&self) -> bool;
|
||||
fn waiting_for_async(&self) -> bool;
|
||||
#[cfg(ossl300)]
|
||||
fn async_status(&self) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
fn set_async_engine_waker(&self, waker: &Arc<AtomicWaker>) -> Result<(), ErrorStack>;
|
||||
fn get_changed_fds(&self) -> Result<(Vec<RawFd>, Vec<RawFd>), ErrorStack>;
|
||||
}
|
||||
|
||||
impl SslAsyncModeExt for SslRef {
|
||||
fn is_async(&self) -> bool {
|
||||
unsafe { (ffi::SSL_get_mode(self.as_ptr()) & 0x00000100) != 0 }
|
||||
}
|
||||
|
||||
fn waiting_for_async(&self) -> bool {
|
||||
unsafe { ffi::SSL_waiting_for_async(self.as_ptr()) == 1 }
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
fn async_status(&self) -> c_int {
|
||||
unsafe { ffi::SSL_get_async_status(self.as_ptr()) }
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
fn set_async_engine_waker(&self, waker: &Arc<AtomicWaker>) -> Result<(), ErrorStack> {
|
||||
let r = unsafe { ffi::SSL_set_async_callback(self.as_ptr(), Some(async_engine_wake)) };
|
||||
if r != 1 {
|
||||
return Err(ErrorStack::get());
|
||||
}
|
||||
|
||||
let r = unsafe {
|
||||
ffi::SSL_set_async_callback_arg(self.as_ptr(), Arc::as_ptr(waker) as *mut c_void)
|
||||
};
|
||||
if r != 1 {
|
||||
Err(ErrorStack::get())
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn get_changed_fds(&self) -> Result<(Vec<RawFd>, Vec<RawFd>), ErrorStack> {
|
||||
let mut add_fd_count = 0usize;
|
||||
let mut del_fd_count = 0usize;
|
||||
let r = unsafe {
|
||||
ffi::SSL_get_changed_async_fds(
|
||||
self.as_ptr(),
|
||||
ptr::null_mut(),
|
||||
&mut add_fd_count as *mut usize,
|
||||
ptr::null_mut(),
|
||||
&mut del_fd_count as *mut usize,
|
||||
)
|
||||
};
|
||||
if r != 1 {
|
||||
return Err(ErrorStack::get());
|
||||
}
|
||||
|
||||
let mut add_fds: Vec<c_int> = vec![0; add_fd_count];
|
||||
let mut del_fds: Vec<c_int> = vec![0; del_fd_count];
|
||||
let r = unsafe {
|
||||
ffi::SSL_get_changed_async_fds(
|
||||
self.as_ptr(),
|
||||
add_fds.as_mut_ptr(),
|
||||
&mut add_fd_count as *mut usize,
|
||||
del_fds.as_mut_ptr(),
|
||||
&mut del_fd_count as *mut usize,
|
||||
)
|
||||
};
|
||||
if r != 1 {
|
||||
return Err(ErrorStack::get());
|
||||
}
|
||||
|
||||
Ok((
|
||||
add_fds.into_iter().map(RawFd::from).collect(),
|
||||
del_fds.into_iter().map(RawFd::from).collect(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AsyncEnginePoller {
|
||||
tracked_fds: Vec<AsyncFd<RawFd>>,
|
||||
#[cfg(ossl300)]
|
||||
atomic_waker: Arc<AtomicWaker>,
|
||||
}
|
||||
|
||||
impl AsyncEnginePoller {
|
||||
#[cfg(not(ossl300))]
|
||||
pub(crate) fn new(ssl: &SslRef) -> Option<Self> {
|
||||
if ssl.is_async() {
|
||||
Some(AsyncEnginePoller {
|
||||
tracked_fds: Vec::with_capacity(1),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
pub(crate) fn new(ssl: &SslRef) -> Result<Option<Self>, ErrorStack> {
|
||||
if !ssl.is_async() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let atomic_waker = Arc::new(AtomicWaker::new());
|
||||
ssl.set_async_engine_waker(&atomic_waker)?;
|
||||
|
||||
Ok(Some(AsyncEnginePoller {
|
||||
tracked_fds: Vec::with_capacity(1),
|
||||
atomic_waker,
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
pub(crate) fn set_cx(&self, cx: &mut Context<'_>) {
|
||||
self.atomic_waker.register(cx.waker());
|
||||
}
|
||||
|
||||
#[cfg(not(ossl300))]
|
||||
pub(crate) fn poll_ready(
|
||||
&mut self,
|
||||
ssl: &SslRef,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let (add, del) = ssl.get_changed_fds().map_err(io::Error::other)?;
|
||||
for fd in add {
|
||||
let async_fd = AsyncFd::with_interest(fd, Interest::READABLE)?;
|
||||
self.tracked_fds.push(async_fd);
|
||||
}
|
||||
for fd in del {
|
||||
self.tracked_fds.retain(|v| fd.ne(v.get_ref()));
|
||||
}
|
||||
|
||||
for fd in &self.tracked_fds {
|
||||
match fd.poll_read_ready(cx) {
|
||||
Poll::Pending => {}
|
||||
Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
pub(crate) fn poll_ready(
|
||||
&mut self,
|
||||
ssl: &SslRef,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
match ssl.async_status() {
|
||||
ffi::ASYNC_STATUS_UNSUPPORTED => {
|
||||
let (add, del) = ssl.get_changed_fds().map_err(io::Error::other)?;
|
||||
for fd in add {
|
||||
let async_fd = AsyncFd::with_interest(fd, Interest::READABLE)?;
|
||||
self.tracked_fds.push(async_fd);
|
||||
}
|
||||
for fd in del {
|
||||
self.tracked_fds.retain(|v| fd.ne(v.get_ref()));
|
||||
}
|
||||
|
||||
for fd in &self.tracked_fds {
|
||||
match fd.poll_read_ready(cx) {
|
||||
Poll::Pending => {}
|
||||
Poll::Ready(Ok(_)) => return Poll::Ready(Ok(())),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
ffi::ASYNC_STATUS_ERR => Poll::Ready(Err(io::Error::other(ErrorStack::get()))),
|
||||
ffi::ASYNC_STATUS_OK => {
|
||||
// submitted, wait for the callback
|
||||
Poll::Pending
|
||||
}
|
||||
ffi::ASYNC_STATUS_EAGAIN => {
|
||||
// engine busy, resume later
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
r => Poll::Ready(Err(io::Error::other(format!(
|
||||
"SSL_get_async_status returned {r}"
|
||||
)))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
extern "C" fn async_engine_wake(_ssl: *mut SSL, arg: *mut c_void) -> c_int {
|
||||
let waker = unsafe { &*(arg as *const AtomicWaker) };
|
||||
waker.wake();
|
||||
0
|
||||
}
|
||||
|
|
@ -15,48 +15,80 @@
|
|||
*/
|
||||
|
||||
use std::future;
|
||||
use std::task::{Context, Poll};
|
||||
use std::io;
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
use openssl::error::ErrorStack;
|
||||
use openssl::ssl::{self, ErrorCode, Ssl};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use super::{SslIoWrapper, SslStream};
|
||||
use super::{AsyncEnginePoller, SslIoWrapper, SslStream};
|
||||
|
||||
pub struct SslConnector<S> {
|
||||
inner: ssl::SslStream<SslIoWrapper<S>>,
|
||||
async_engine: Option<AsyncEnginePoller>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SslConnector<S> {
|
||||
#[cfg(not(ossl300))]
|
||||
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
|
||||
let wrapper = SslIoWrapper::new(stream);
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector { inner })
|
||||
let async_engine = AsyncEnginePoller::new(&ssl);
|
||||
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector {
|
||||
inner,
|
||||
async_engine,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(ossl300)]
|
||||
pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
|
||||
let wrapper = SslIoWrapper::new(stream);
|
||||
let async_engine = AsyncEnginePoller::new(&ssl)?;
|
||||
|
||||
ssl::SslStream::new(ssl, wrapper).map(|inner| SslConnector {
|
||||
inner,
|
||||
async_engine,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SslConnector<S> {
|
||||
pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
|
||||
pub fn poll_connect(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
#[cfg(ossl300)]
|
||||
if let Some(async_engine) = &self.async_engine {
|
||||
async_engine.set_cx(cx);
|
||||
}
|
||||
|
||||
match self.inner.connect() {
|
||||
Ok(_) => Poll::Ready(Ok(())),
|
||||
Err(e) => match e.code() {
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
Poll::Pending
|
||||
}
|
||||
_ => Poll::Ready(Err(e)),
|
||||
},
|
||||
loop {
|
||||
match self.inner.connect() {
|
||||
Ok(_) => return Poll::Ready(Ok(())),
|
||||
Err(e) => match e.code() {
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
if let Some(async_engine) = &mut self.async_engine {
|
||||
ready!(async_engine.poll_ready(self.inner.ssl(), cx))?
|
||||
} else {
|
||||
return Poll::Ready(Err(io::Error::other(
|
||||
"async engine poller is not set",
|
||||
)));
|
||||
}
|
||||
}
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
_ => {
|
||||
return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(mut self) -> Result<SslStream<S>, ssl::Error> {
|
||||
pub async fn connect(mut self) -> io::Result<SslStream<S>> {
|
||||
future::poll_fn(|cx| self.poll_connect(cx)).await?;
|
||||
Ok(SslStream::new(self.inner))
|
||||
Ok(SslStream::new(self.inner, None))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
59
lib/g3-openssl/src/ffi.rs
Normal file
59
lib/g3-openssl/src/ffi.rs
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright 2023 ByteDance and/or its affiliates.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#![allow(unused)]
|
||||
|
||||
use libc::{c_int, c_long, c_void};
|
||||
use openssl_sys::{SSL_ctrl, SSL, SSL_CTRL_MODE, SSL_CTX};
|
||||
use std::ptr;
|
||||
|
||||
pub const ASYNC_STATUS_UNSUPPORTED: c_int = 0;
|
||||
pub const ASYNC_STATUS_ERR: c_int = 1;
|
||||
pub const ASYNC_STATUS_OK: c_int = 2;
|
||||
pub const ASYNC_STATUS_EAGAIN: c_int = 3;
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[cfg(ossl300)]
|
||||
pub type SSL_async_callback_fn =
|
||||
Option<unsafe extern "C" fn(s: *mut SSL, arg: *mut c_void) -> c_int>;
|
||||
|
||||
extern "C" {
|
||||
pub fn SSL_waiting_for_async(s: *mut SSL) -> c_int;
|
||||
pub fn SSL_get_all_async_fds(s: *mut SSL, fd: *mut c_int, numfds: *mut usize) -> c_int;
|
||||
pub fn SSL_get_changed_async_fds(
|
||||
s: *mut SSL,
|
||||
addfd: *mut c_int,
|
||||
numaddfds: *mut usize,
|
||||
delfd: *mut c_int,
|
||||
numdelfds: *mut usize,
|
||||
) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
pub fn SSL_CTX_set_async_callback(ctx: *mut SSL_CTX, callback: SSL_async_callback_fn) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
pub fn SSL_CTX_set_async_callback_arg(ctx: *mut SSL_CTX, arg: *mut c_void) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
pub fn SSL_set_async_callback(s: *mut SSL, callback: SSL_async_callback_fn) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
pub fn SSL_set_async_callback_arg(s: *mut SSL, arg: *mut c_void) -> c_int;
|
||||
#[cfg(ossl300)]
|
||||
pub fn SSL_get_async_status(s: *mut SSL) -> c_int;
|
||||
|
||||
}
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
pub unsafe fn SSL_get_mode(ctx: *mut SSL) -> c_long {
|
||||
SSL_ctrl(ctx, SSL_CTRL_MODE, 0, ptr::null_mut())
|
||||
}
|
||||
|
|
@ -14,9 +14,14 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
mod ffi;
|
||||
|
||||
mod wrapper;
|
||||
use wrapper::SslIoWrapper;
|
||||
|
||||
mod async_mode;
|
||||
use async_mode::AsyncEnginePoller;
|
||||
|
||||
mod stream;
|
||||
pub use stream::SslStream;
|
||||
|
||||
|
|
|
|||
|
|
@ -16,21 +16,28 @@
|
|||
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
use openssl::ssl::{self, ErrorCode, SslRef};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use super::SslIoWrapper;
|
||||
use super::{AsyncEnginePoller, SslIoWrapper};
|
||||
|
||||
pub struct SslStream<S> {
|
||||
inner: ssl::SslStream<SslIoWrapper<S>>,
|
||||
async_engine: Option<AsyncEnginePoller>,
|
||||
}
|
||||
|
||||
impl<S> SslStream<S> {
|
||||
#[inline]
|
||||
pub(crate) fn new(inner: ssl::SslStream<SslIoWrapper<S>>) -> Self {
|
||||
SslStream { inner }
|
||||
pub(crate) fn new(
|
||||
inner: ssl::SslStream<SslIoWrapper<S>>,
|
||||
async_engine: Option<AsyncEnginePoller>,
|
||||
) -> Self {
|
||||
SslStream {
|
||||
inner,
|
||||
async_engine,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
|
@ -42,15 +49,31 @@ impl<S> SslStream<S> {
|
|||
pub fn get_mut(&mut self) -> &mut S {
|
||||
self.inner.get_mut().get_mut()
|
||||
}
|
||||
|
||||
fn set_cx(&mut self, cx: &mut Context<'_>) {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
#[cfg(ossl300)]
|
||||
if let Some(async_engine) = &self.async_engine {
|
||||
async_engine.set_cx(cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_async_engine(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
if let Some(async_engine) = &mut self.async_engine {
|
||||
async_engine.poll_ready(self.inner.ssl(), cx)
|
||||
} else {
|
||||
Poll::Ready(Err(io::Error::other("async engine poller is not set")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SslStream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
impl<S: AsyncRead + AsyncWrite> SslStream<S> {
|
||||
fn poll_read_unpin(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
self.set_cx(cx);
|
||||
|
||||
loop {
|
||||
match self.inner.ssl_read_uninit(unsafe { buf.unfilled_mut() }) {
|
||||
|
|
@ -69,10 +92,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SslStream<S> {
|
|||
}
|
||||
}
|
||||
ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?,
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
|
|
@ -84,15 +104,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SslStream<S> {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SslStream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
fn poll_write_unpin(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
||||
self.set_cx(cx);
|
||||
|
||||
loop {
|
||||
match self.inner.ssl_write(buf) {
|
||||
|
|
@ -106,10 +120,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SslStream<S> {
|
|||
}
|
||||
}
|
||||
ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?,
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
|
|
@ -122,27 +133,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SslStream<S> {
|
|||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().get_pin_mut().poll_flush(cx)
|
||||
}
|
||||
fn poll_shutdown_unpin(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.set_cx(cx);
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().set_cx(cx);
|
||||
|
||||
if let Err(e) = self.inner.shutdown() {
|
||||
match e.code() {
|
||||
ErrorCode::ZERO_RETURN => {}
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => {
|
||||
// TODO
|
||||
todo!()
|
||||
}
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
_ => {
|
||||
return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
|
||||
loop {
|
||||
if let Err(e) = self.inner.shutdown() {
|
||||
match e.code() {
|
||||
ErrorCode::ZERO_RETURN => break,
|
||||
ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending,
|
||||
ErrorCode::WANT_ASYNC => ready!(self.poll_async_engine(cx))?,
|
||||
ErrorCode::WANT_ASYNC_JOB => {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
_ => {
|
||||
return Poll::Ready(Err(e
|
||||
.into_io_error()
|
||||
.unwrap_or_else(io::Error::other)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -150,3 +158,31 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SslStream<S> {
|
|||
self.inner.get_mut().get_pin_mut().poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for SslStream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
self.as_mut().get_mut().poll_read_unpin(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for SslStream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
self.as_mut().get_mut().poll_write_unpin(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.inner.get_mut().get_pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.as_mut().get_mut().poll_shutdown_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue