diff --git a/lib/g3-socket/src/guard.rs b/lib/g3-socket/src/guard.rs new file mode 100644 index 00000000..22da5e77 --- /dev/null +++ b/lib/g3-socket/src/guard.rs @@ -0,0 +1,85 @@ +/* + * Copyright 2023 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::ops::{Deref, DerefMut}; +use std::os::fd::{FromRawFd, IntoRawFd, RawFd}; + +pub struct RawFdGuard +where + T: FromRawFd + IntoRawFd, +{ + inner: Option, +} + +impl RawFdGuard +where + T: FromRawFd + IntoRawFd, +{ + pub fn new(fd: RawFd) -> Self { + Self { + inner: unsafe { Some(T::from_raw_fd(fd)) }, + } + } +} + +impl Drop for RawFdGuard +where + T: FromRawFd + IntoRawFd, +{ + fn drop(&mut self) { + if let Some(resource) = self.inner.take() { + let _ = resource.into_raw_fd(); + } + } +} + +impl Deref for RawFdGuard +where + T: FromRawFd + IntoRawFd, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + // the only way setting inner to None is drop + self.inner.as_ref().unwrap() + } +} + +impl DerefMut for RawFdGuard +where + T: FromRawFd + IntoRawFd, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // the only way setting inner to None is drop + self.inner.as_mut().unwrap() + } +} + +#[cfg(test)] +mod tests { + use super::RawFdGuard; + use socket2::Socket; + + #[test] + fn not_close_fd() { + let fd = 0; + { + let _socket = RawFdGuard::::new(fd); + // unsafe { libc::close(fd) }; + } + assert!(unsafe { libc::fcntl(fd, libc::F_GETFD) } != -1); + } +} diff --git a/lib/g3-socket/src/lib.rs b/lib/g3-socket/src/lib.rs index 52d2cc8f..e1e60606 100644 --- a/lib/g3-socket/src/lib.rs +++ b/lib/g3-socket/src/lib.rs @@ -16,6 +16,7 @@ mod sockopt; +pub mod guard; pub mod tcp; pub mod udp; pub mod util; diff --git a/lib/g3-socket/src/tcp.rs b/lib/g3-socket/src/tcp.rs index 25368465..f3171135 100644 --- a/lib/g3-socket/src/tcp.rs +++ b/lib/g3-socket/src/tcp.rs @@ -16,13 +16,14 @@ use std::io; use std::net::{IpAddr, SocketAddr}; -use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::fd::{AsRawFd, RawFd}; use socket2::{Domain, SockAddr, Socket, TcpKeepalive, Type}; use tokio::net::{TcpListener, TcpSocket}; use g3_types::net::{TcpKeepAliveConfig, TcpListenConfig, TcpMiscSockOpts}; +use super::guard::RawFdGuard; use super::sockopt::{set_bind_address_no_port, set_only_ipv6}; use super::util::AddressFamily; @@ -89,10 +90,8 @@ pub fn set_raw_opts( misc_opts: &TcpMiscSockOpts, default_set_nodelay: bool, ) -> io::Result<()> { - let socket = unsafe { Socket::from_raw_fd(fd) }; - set_misc_opts(&socket, misc_opts, default_set_nodelay)?; - let _ = socket.into_raw_fd(); - Ok(()) + let socket = RawFdGuard::::new(fd); + set_misc_opts(&socket, misc_opts, default_set_nodelay) } fn set_misc_opts( diff --git a/lib/g3-socket/src/udp.rs b/lib/g3-socket/src/udp.rs index 6e672b1f..0b84953b 100644 --- a/lib/g3-socket/src/udp.rs +++ b/lib/g3-socket/src/udp.rs @@ -16,12 +16,13 @@ use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; -use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +use std::os::fd::{AsRawFd, RawFd}; use socket2::{Domain, SockAddr, Socket, Type}; use g3_types::net::{PortRange, SocketBufferConfig, UdpListenConfig, UdpMiscSockOpts}; +use super::guard::RawFdGuard; use super::sockopt::set_bind_address_no_port; use super::util::AddressFamily; @@ -158,17 +159,13 @@ pub fn new_std_rebind_listen(config: &UdpListenConfig, addr: SocketAddr) -> io:: } pub fn set_raw_opts(fd: RawFd, misc_opts: UdpMiscSockOpts) -> io::Result<()> { - let socket = unsafe { Socket::from_raw_fd(fd) }; - set_misc_opts(&socket, misc_opts)?; - let _ = socket.into_raw_fd(); - Ok(()) + let socket = RawFdGuard::::new(fd); + set_misc_opts(&socket, misc_opts) } pub fn set_raw_buf_opts(fd: RawFd, buf_conf: SocketBufferConfig) -> io::Result<()> { - let socket = unsafe { Socket::from_raw_fd(fd) }; - set_buf_opts(&socket, buf_conf)?; - let _ = socket.into_raw_fd(); - Ok(()) + let socket = RawFdGuard::::new(fd); + set_buf_opts(&socket, buf_conf) } fn set_misc_opts(socket: &Socket, misc_opts: UdpMiscSockOpts) -> io::Result<()> {