use macOS sendmsg_x and recvmsg_x (#319)

* g3-io-ext: use sendmsg_x and recvmsg_x on macOS

* use more macOS sendmsg_x/recvmsg_x

* fix clippy warning
This commit is contained in:
Zhang Jingqiang 2024-09-23 16:17:15 +08:00 committed by GitHub
parent 4a10503043
commit d9440682ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 486 additions and 27 deletions

View file

@ -23,6 +23,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpCopyRemoteError, UdpCopyRemoteRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpCopyPacket, UdpCopyPacketMeta};
@ -62,6 +63,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -24,6 +24,7 @@ use g3_io_ext::{AsyncUdpSend, UdpCopyRemoteError, UdpCopyRemoteSend};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{SendMsgHdr, UdpCopyPacket};
@ -90,4 +91,29 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpCopyPacket],
) -> Poll<Result<usize, UdpCopyRemoteError>> {
use std::io::IoSlice;
let mut msgs: Vec<SendMsgHdr<1>> = packets
.iter()
.map(|p| SendMsgHdr::new([IoSlice::new(p.payload())], None))
.collect();
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, &mut msgs))
.map_err(UdpCopyRemoteError::SendFailed)?;
if count == 0 {
Poll::Ready(Err(UdpCopyRemoteError::SendFailed(io::Error::new(
io::ErrorKind::WriteZero,
"write zero packet into sender",
))))
} else {
Poll::Ready(Ok(count))
}
}
}

View file

@ -24,6 +24,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpRelayRemoteError, UdpRelayRemoteRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpRelayPacket, UdpRelayPacketMeta};
use g3_types::net::UpstreamAddr;
@ -101,6 +102,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
inner: &mut T,
@ -157,6 +159,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -27,6 +27,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpCopyRemoteError, UdpCopyRemoteRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpCopyPacket, UdpCopyPacketMeta};
use g3_socks::v5::UdpInput;
@ -111,6 +112,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -24,6 +24,7 @@ use g3_io_ext::{AsyncUdpSend, UdpCopyRemoteError, UdpCopyRemoteSend};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{SendMsgHdr, UdpCopyPacket};
use g3_socks::v5::UdpOutput;
@ -107,4 +108,32 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpCopyPacket],
) -> Poll<Result<usize, UdpCopyRemoteError>> {
let mut msgs: Vec<SendMsgHdr<2>> = packets
.iter()
.map(|p| {
SendMsgHdr::new(
[IoSlice::new(&self.socks5_header), IoSlice::new(p.payload())],
None,
)
})
.collect();
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, &mut msgs))
.map_err(UdpCopyRemoteError::SendFailed)?;
if count == 0 {
Poll::Ready(Err(UdpCopyRemoteError::SendFailed(io::Error::new(
io::ErrorKind::WriteZero,
"write zero packet into sender",
))))
} else {
Poll::Ready(Ok(count))
}
}
}

View file

@ -28,6 +28,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpRelayRemoteError, UdpRelayRemoteRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpRelayPacket, UdpRelayPacketMeta};
use g3_socks::v5::UdpInput;
@ -133,6 +134,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -25,6 +25,7 @@ use g3_io_ext::{AsyncUdpSend, UdpRelayRemoteError, UdpRelayRemoteSend};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{SendMsgHdr, UdpRelayPacket};
use g3_socks::v5::SocksUdpHeader;
@ -117,4 +118,37 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpRelayPacket],
) -> Poll<Result<usize, UdpRelayRemoteError>> {
if packets.len() > self.socks_headers.len() {
self.socks_headers.resize(packets.len(), Default::default());
}
let mut msgs = Vec::with_capacity(packets.len());
for (p, h) in packets.iter().zip(self.socks_headers.iter_mut()) {
msgs.push(SendMsgHdr::new(
[
IoSlice::new(h.encode(p.upstream())),
IoSlice::new(p.payload()),
],
None,
));
}
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, &mut msgs))
.map_err(|e| UdpRelayRemoteError::SendFailed(self.local_addr, self.peer_addr, e))?;
if count == 0 {
Poll::Ready(Err(UdpRelayRemoteError::SendFailed(
self.local_addr,
self.peer_addr,
io::Error::new(io::ErrorKind::WriteZero, "write zero packet into sender"),
)))
} else {
Poll::Ready(Ok(count))
}
}
}

View file

@ -26,6 +26,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpRelayClientError, UdpRelayClientRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpRelayPacket, UdpRelayPacketMeta};
use g3_socks::v5::UdpInput;
@ -240,6 +241,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -25,6 +25,7 @@ use g3_io_ext::{AsyncUdpSend, UdpRelayClientError, UdpRelayClientSend};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{SendMsgHdr, UdpRelayPacket};
use g3_socks::v5::SocksUdpHeader;
@ -113,4 +114,36 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpRelayPacket],
) -> Poll<Result<usize, UdpRelayClientError>> {
if packets.len() > self.socks_headers.len() {
self.socks_headers.resize(packets.len(), Default::default());
}
let mut msgs = Vec::with_capacity(packets.len());
for (p, h) in packets.iter().zip(self.socks_headers.iter_mut()) {
msgs.push(SendMsgHdr::new(
[
IoSlice::new(h.encode(p.upstream())),
IoSlice::new(p.payload()),
],
None,
));
}
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, &mut msgs))
.map_err(UdpRelayClientError::SendFailed)?;
if count == 0 {
Poll::Ready(Err(UdpRelayClientError::SendFailed(io::Error::new(
io::ErrorKind::WriteZero,
"write zero packet into sender",
))))
} else {
Poll::Ready(Ok(count))
}
}
}

View file

@ -26,6 +26,7 @@ use g3_io_ext::{AsyncUdpRecv, UdpCopyClientError, UdpCopyClientRecv};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{RecvMsgHdr, UdpCopyPacket, UdpCopyPacketMeta};
use g3_socks::v5::UdpInput;
@ -166,6 +167,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,

View file

@ -24,6 +24,7 @@ use g3_io_ext::{AsyncUdpSend, UdpCopyClientError, UdpCopyClientSend};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use g3_io_ext::{SendMsgHdr, UdpCopyPacket};
use g3_socks::v5::UdpOutput;
@ -107,4 +108,32 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpCopyPacket],
) -> Poll<Result<usize, UdpCopyClientError>> {
let mut msgs: Vec<SendMsgHdr<2>> = packets
.iter()
.map(|p| {
SendMsgHdr::new(
[IoSlice::new(&self.socks5_header), IoSlice::new(p.payload())],
None,
)
})
.collect();
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, &mut msgs))
.map_err(UdpCopyClientError::SendFailed)?;
if count == 0 {
Poll::Ready(Err(UdpCopyClientError::SendFailed(io::Error::new(
io::ErrorKind::WriteZero,
"write zero packet into sender",
))))
} else {
Poll::Ready(Ok(count))
}
}
}

View file

@ -25,6 +25,7 @@ use thiserror::Error;
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::UdpCopyPacket;
@ -61,6 +62,7 @@ pub trait UdpCopyClientRecv {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -83,6 +85,7 @@ pub trait UdpCopyClientSend {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,

View file

@ -156,6 +156,7 @@ impl<'a, T: UdpCopyClientRecv + ?Sized> UdpCopyRecv for ClientRecv<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -191,6 +192,7 @@ impl<'a, T: UdpCopyRemoteRecv + ?Sized> UdpCopyRecv for RemoteRecv<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -252,6 +254,7 @@ impl<'a, T: UdpCopyClientSend + ?Sized> UdpCopySend for ClientSend<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,
@ -283,6 +286,7 @@ impl<'a, T: UdpCopyRemoteSend + ?Sized> UdpCopySend for RemoteSend<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,

View file

@ -25,6 +25,7 @@ use thiserror::Error;
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::UdpCopyPacket;
@ -61,6 +62,7 @@ pub trait UdpCopyRemoteRecv {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -83,6 +85,7 @@ pub trait UdpCopyRemoteSend {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,

View file

@ -0,0 +1,34 @@
/*
* Copyright 2024 ByteDance and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use libc::{c_int, c_uint, c_void, iovec, size_t, socklen_t, ssize_t};
#[repr(C)]
pub(super) struct msghdr_x {
pub msg_name: *mut c_void,
pub msg_namelen: socklen_t,
pub msg_iov: *mut iovec,
pub msg_iovlen: c_int,
pub msg_control: *mut c_void,
pub msg_controllen: socklen_t,
pub msg_flags: c_int,
pub msg_datalen: size_t,
}
extern "C" {
pub(super) fn sendmsg_x(s: c_int, msgp: *mut msghdr_x, cnt: c_uint, flags: c_int) -> ssize_t;
pub(super) fn recvmsg_x(s: c_int, msgp: *mut msghdr_x, cnt: c_uint, flags: c_int) -> ssize_t;
}

View file

@ -20,6 +20,8 @@ use std::io::{self, IoSlice};
use std::net::SocketAddr;
use std::task::{Context, Poll};
#[cfg(target_os = "macos")]
mod macos;
#[cfg(unix)]
mod unix;
#[cfg(unix)]
@ -58,12 +60,20 @@ pub trait UdpSocketExt {
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>>;
#[cfg(target_os = "macos")]
fn poll_batch_sendmsg_x<const C: usize>(
&self,
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>>;
#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_batch_recvmsg<const C: usize>(
&self,

View file

@ -15,12 +15,11 @@
*/
use std::cell::UnsafeCell;
use std::io::{self, IoSlice, IoSliceMut};
use std::mem;
use std::io::{IoSlice, IoSliceMut};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::fd::AsFd;
use std::ptr;
use std::task::{ready, Context, Poll};
use std::{io, mem, ptr};
use rustix::net::{
recvmsg, sendmsg, sendmsg_v4, sendmsg_v6, RecvAncillaryBuffer, RecvFlags, SendAncillaryBuffer,
@ -140,6 +139,17 @@ impl<'a, const C: usize> SendMsgHdr<'a, C> {
h.msg_iovlen = C as _;
h
}
/// # Safety
///
/// `self` should not be dropped before the returned value
#[cfg(target_os = "macos")]
unsafe fn to_msghdr_x(&self) -> super::macos::msghdr_x {
let mut h = mem::zeroed::<super::macos::msghdr_x>();
h.msg_iov = self.iov.as_ptr() as _;
h.msg_iovlen = C as _;
h
}
}
impl<'a, const C: usize> AsRef<[IoSlice<'a>]> for SendMsgHdr<'a, C> {
@ -182,6 +192,22 @@ impl<'a, const C: usize> RecvMsgHdr<'a, C> {
h.msg_iovlen = C as _;
h
}
/// # Safety
///
/// `self` should not be dropped before the returned value
#[cfg(target_os = "macos")]
unsafe fn to_msghdr_x(&self) -> super::macos::msghdr_x {
let c_addr = &mut *self.c_addr.get();
let (c_addr, c_addr_len) = c_addr.get_ptr_and_size();
let mut h = mem::zeroed::<super::macos::msghdr_x>();
h.msg_name = c_addr as _;
h.msg_namelen = c_addr_len as _;
h.msg_iov = self.iov.as_ptr() as _;
h.msg_iovlen = C as _;
h
}
}
impl UdpSocketExt for UdpSocket {
@ -342,6 +368,53 @@ impl UdpSocketExt for UdpSocket {
}
}
#[cfg(target_os = "macos")]
fn poll_batch_sendmsg_x<const C: usize>(
&self,
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>> {
use smallvec::SmallVec;
use std::os::fd::AsRawFd;
let mut msgvec: SmallVec<[_; 32]> = SmallVec::with_capacity(msgs.len());
for m in msgs.iter_mut() {
msgvec.push(unsafe { m.to_msghdr_x() });
}
let raw_fd = self.as_raw_fd();
let flags = libc::MSG_DONTWAIT;
let mut sendmsg_x = || {
let r = unsafe {
super::macos::sendmsg_x(raw_fd, msgvec.as_mut_ptr(), msgvec.len() as _, flags as _)
};
if r < 0 {
Err(io::Error::last_os_error())
} else {
Ok(r as usize)
}
};
loop {
ready!(self.poll_send_ready(cx))?;
match self.try_io(Interest::WRITABLE, &mut sendmsg_x) {
Ok(count) => {
for m in msgs.iter_mut().take(count) {
m.n_send = m.iov.iter().map(|iov| iov.len()).sum();
}
return Poll::Ready(Ok(count));
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
continue;
} else {
return Poll::Ready(Err(e));
}
}
}
}
}
#[cfg(any(
target_os = "linux",
target_os = "android",
@ -402,6 +475,57 @@ impl UdpSocketExt for UdpSocket {
}
}
}
#[cfg(target_os = "macos")]
fn poll_batch_recvmsg<const C: usize>(
&self,
cx: &mut Context<'_>,
hdr_v: &mut [RecvMsgHdr<'_, C>],
) -> Poll<io::Result<usize>> {
use smallvec::SmallVec;
use std::os::fd::AsRawFd;
let mut msgvec: SmallVec<[_; 32]> = SmallVec::with_capacity(hdr_v.len());
for m in hdr_v.iter_mut() {
msgvec.push(unsafe { m.to_msghdr_x() });
}
let raw_fd = self.as_raw_fd();
let mut recvmsg_x = || {
let r = unsafe {
super::macos::recvmsg_x(
raw_fd,
msgvec.as_mut_ptr(),
msgvec.len() as _,
libc::MSG_DONTWAIT as _,
)
};
if r < 0 {
Err(io::Error::last_os_error())
} else {
Ok(r as usize)
}
};
loop {
ready!(self.poll_recv_ready(cx))?;
match self.try_io(Interest::READABLE, &mut recvmsg_x) {
Ok(count) => {
for (m, h) in hdr_v.iter_mut().take(count).zip(msgvec) {
m.n_recv = h.msg_datalen;
}
return Poll::Ready(Ok(count));
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
continue;
} else {
return Poll::Ready(Err(e));
}
}
}
}
}
}
#[cfg(test)]
@ -415,6 +539,7 @@ mod tests {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
#[tokio::test]
async fn batch_msg_connect() {
@ -433,9 +558,14 @@ mod tests {
SendMsgHdr::new([IoSlice::new(msg_2)], None),
];
#[cfg(not(target_os = "macos"))]
let count = poll_fn(|cx| c_sock.poll_batch_sendmsg(cx, &mut msgs))
.await
.unwrap();
#[cfg(target_os = "macos")]
let count = poll_fn(|cx| c_sock.poll_batch_sendmsg_x(cx, &mut msgs))
.await
.unwrap();
assert_eq!(count, 2);
assert_eq!(msgs[0].n_send, msg_1.len());
assert_eq!(msgs[1].n_send, msg_2.len());

View file

@ -30,6 +30,7 @@ use tokio::time::{Instant, Sleep};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::RecvMsgHdr;
use crate::limit::{DatagramLimitAction, DatagramLimiter};
@ -50,6 +51,7 @@ pub trait AsyncUdpRecv {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_batch_recvmsg<const C: usize>(
&mut self,
@ -216,6 +218,7 @@ where
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_batch_recvmsg<const C: usize>(
&mut self,

View file

@ -27,6 +27,7 @@ use g3_types::net::UpstreamAddr;
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::UdpRelayPacket;
@ -65,6 +66,7 @@ pub trait UdpRelayClientRecv {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -88,6 +90,7 @@ pub trait UdpRelayClientSend {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,

View file

@ -174,6 +174,7 @@ impl<'a, T: UdpRelayClientRecv + ?Sized> UdpRelayRecv for ClientRecv<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -210,6 +211,7 @@ impl<'a, T: UdpRelayRemoteRecv + ?Sized> UdpRelayRecv for RemoteRecv<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -271,6 +273,7 @@ impl<'a, T: UdpRelayClientSend + ?Sized> UdpRelaySend for ClientSend<'a, T> {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_send_packets(
&mut self,
@ -296,13 +299,6 @@ impl<'a, T: UdpRelayRemoteSend + ?Sized> UdpRelaySend for RemoteSend<'a, T> {
.map_err(|e| UdpRelayError::RemoteError(Some(packet.ups.clone()), e))
}
#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
))]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,

View file

@ -24,13 +24,6 @@ use thiserror::Error;
use g3_resolver::ResolveError;
use g3_types::net::UpstreamAddr;
#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
))]
use super::UdpRelayPacket;
#[derive(Error, Debug)]
@ -77,6 +70,7 @@ pub trait UdpRelayRemoteRecv {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv_packets(
&mut self,
@ -94,16 +88,25 @@ pub trait UdpRelayRemoteSend {
to: &UpstreamAddr,
) -> Poll<Result<usize, UdpRelayRemoteError>>;
#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
))]
fn poll_send_packets(
&mut self,
cx: &mut Context<'_>,
packets: &[UdpRelayPacket],
) -> Poll<Result<usize, UdpRelayRemoteError>>;
) -> Poll<Result<usize, UdpRelayRemoteError>> {
let mut count = 0;
for packet in packets {
match self.poll_send_packet(cx, packet.payload(), packet.upstream()) {
Poll::Pending => {
return if count > 0 {
Poll::Ready(Ok(count))
} else {
Poll::Pending
};
}
Poll::Ready(Ok(_)) => count += 1,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
}
}
Poll::Ready(Ok(count))
}
}

View file

@ -30,6 +30,7 @@ use tokio::time::{Instant, Sleep};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::SendMsgHdr;
use crate::limit::{DatagramLimitAction, DatagramLimiter};
@ -64,6 +65,13 @@ pub trait AsyncUdpSend {
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>>;
#[cfg(target_os = "macos")]
fn poll_batch_sendmsg_x<const C: usize>(
&mut self,
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>>;
}
pub struct LimitedUdpSend<T> {
@ -345,4 +353,72 @@ where
Poll::Ready(Ok(count))
}
}
#[cfg(target_os = "macos")]
fn poll_batch_sendmsg_x<const C: usize>(
&mut self,
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>> {
use smallvec::SmallVec;
if self.limit.is_set() {
let dur_millis = self.started.elapsed().as_millis() as u64;
let mut total_size_v = SmallVec::<[usize; 32]>::with_capacity(msgs.len());
let mut total_size = 0;
for msg in msgs.iter() {
total_size += msg.iov.iter().map(|v| v.len()).sum::<usize>();
total_size_v.push(total_size);
}
match self.limit.check_packets(dur_millis, total_size_v.as_ref()) {
DatagramLimitAction::Advance(n) => {
match self.inner.poll_batch_sendmsg_x(cx, &mut msgs[0..n]) {
Poll::Ready(Ok(count)) => {
let len = msgs.iter().take(count).map(|v| v.n_send).sum();
self.limit.set_advance(count, len);
self.stats.add_send_packets(count);
self.stats.add_send_bytes(len);
Poll::Ready(Ok(count))
}
Poll::Ready(Err(e)) => {
self.limit.release_global();
Poll::Ready(Err(e))
}
Poll::Pending => {
self.limit.release_global();
Poll::Pending
}
}
}
DatagramLimitAction::DelayUntil(t) => {
self.delay.as_mut().reset(t);
match self.delay.poll_unpin(cx) {
Poll::Ready(_) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
DatagramLimitAction::DelayFor(ms) => {
self.delay
.as_mut()
.reset(self.started + Duration::from_millis(dur_millis + ms));
match self.delay.poll_unpin(cx) {
Poll::Ready(_) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => Poll::Pending,
}
}
}
} else {
let count = ready!(self.inner.poll_batch_sendmsg_x(cx, msgs))?;
self.stats.add_send_packets(count);
self.stats
.add_send_bytes(msgs.iter().take(count).map(|h| h.n_send).sum());
Poll::Ready(Ok(count))
}
}
}

View file

@ -31,6 +31,7 @@ use super::{AsyncUdpRecv, AsyncUdpSend, UdpSocketExt};
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
use super::{RecvMsgHdr, SendMsgHdr};
@ -114,6 +115,15 @@ impl AsyncUdpSend for SendHalf {
) -> Poll<io::Result<usize>> {
self.0.poll_batch_sendmsg(cx, msgs)
}
#[cfg(target_os = "macos")]
fn poll_batch_sendmsg_x<const C: usize>(
&mut self,
cx: &mut Context<'_>,
msgs: &mut [SendMsgHdr<'_, C>],
) -> Poll<io::Result<usize>> {
self.0.poll_batch_sendmsg_x(cx, msgs)
}
}
impl RecvHalf {
@ -149,6 +159,7 @@ impl AsyncUdpRecv for RecvHalf {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_batch_recvmsg<const C: usize>(
&mut self,

View file

@ -201,6 +201,7 @@ impl AsyncUdpSocket for Socks5UdpSocket {
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd",
target_os = "macos",
))]
fn poll_recv(
&self,
@ -289,7 +290,7 @@ impl AsyncUdpSocket for Socks5UdpSocket {
}
}
#[cfg(target_os = "macos")]
#[cfg(target_os = "dragonfly")]
fn poll_recv(
&self,
cx: &mut Context,

View file

@ -70,7 +70,25 @@ impl Sinker {
Ok(())
}
#[cfg(any(windows, target_os = "macos", target_os = "dragonfly"))]
#[cfg(target_os = "macos")]
async fn send_udp(&self, packets: &[Vec<u8>]) -> io::Result<()> {
use g3_io_ext::{SendMsgHdr, UdpSocketExt};
use std::future::poll_fn;
use std::io::IoSlice;
let mut msgs: Vec<_> = packets
.iter()
.map(|v| SendMsgHdr::new([IoSlice::new(v.as_slice())], None))
.collect();
let mut offset = 0;
while offset < msgs.len() {
offset +=
poll_fn(|cx| self.socket.poll_batch_sendmsg_x(cx, &mut msgs[offset..])).await?;
}
Ok(())
}
#[cfg(any(windows, target_os = "dragonfly"))]
async fn send_udp(&self, packets: &[Vec<u8>]) -> io::Result<()> {
for pkt in packets {
self.socket.send(pkt.as_slice()).await?;