diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4758e56e..243311d1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,7 +38,7 @@ jobs: os: windows-latest rust: stable steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} @@ -52,7 +52,7 @@ jobs: name: Rustfmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt @@ -110,7 +110,7 @@ jobs: - x86_64-unknown-redox - wasm32-wasip2 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@nightly with: components: rust-src @@ -119,11 +119,44 @@ jobs: run: cargo hack check -Z build-std=std,panic_abort --feature-powerset --target ${{ matrix.target }} - name: Check docs run: RUSTDOCFLAGS="-D warnings --cfg docsrs" cargo doc -Z build-std=std,panic_abort --no-deps --all-features --target ${{ matrix.target }} + Cross: + name: Cross-test (${{ matrix.target }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + # 32-bit Linux: size_t=4 → cmsg_len is 4 bytes, CMSG_ALIGN factor=4. + # Exercises a different CMSG_* layout than x86_64 (factor=8). + - target: i686-unknown-linux-gnu + rust: stable + # 64-bit ARM Linux: same CMSG_ALIGN factor as x86_64 but different ABI. + - target: aarch64-unknown-linux-gnu + rust: stable + # 32-bit ARM Linux: like i686 but a distinct architecture. + - target: armv7-unknown-linux-gnueabihf + rust: stable + steps: + - uses: actions/checkout@v6 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - uses: taiki-e/install-action@cross + - name: Run cmsg tests (cross + QEMU) + run: | + cross test --target ${{ matrix.target }} --all-features -- cmsg + cross test --target ${{ matrix.target }} --all-features -- control_message + - name: Run cmsg tests release (cross + QEMU) + run: | + cross test --target ${{ matrix.target }} --all-features --release -- cmsg + cross test --target ${{ matrix.target }} --all-features --release -- control_message + Clippy: name: Clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: clippy @@ -144,7 +177,7 @@ jobs: # the README for details: https://github.com/awslabs/cargo-check-external-types - nightly-2024-06-30 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install Rust ${{ matrix.rust }} uses: dtolnay/rust-toolchain@stable with: diff --git a/src/cmsg.rs b/src/cmsg.rs new file mode 100644 index 00000000..adb1fa76 --- /dev/null +++ b/src/cmsg.rs @@ -0,0 +1,202 @@ +use std::fmt; +use std::mem; + +/// Returns the space required in a control message buffer for a single message +/// with `data_len` bytes of ancillary data. +/// +/// Returns `None` if `data_len` does not fit in `libc::c_uint`. +/// +/// Corresponds to `CMSG_SPACE(3)`. +pub fn cmsg_space(data_len: usize) -> Option { + let len = libc::c_uint::try_from(data_len).ok()?; + // SAFETY: pure arithmetic. + usize::try_from(unsafe { libc::CMSG_SPACE(len) }).ok() +} + +/// A control message parsed from a `recvmsg(2)` control buffer. +/// +/// Returned by [`ControlMessages`]. +pub struct ControlMessage<'a> { + cmsg_level: i32, + cmsg_type: i32, + data: &'a [u8], +} + +impl<'a> ControlMessage<'a> { + /// Corresponds to `cmsg_level` in `cmsghdr`. + pub fn cmsg_level(&self) -> i32 { + self.cmsg_level + } + + /// Corresponds to `cmsg_type` in `cmsghdr`. + pub fn cmsg_type(&self) -> i32 { + self.cmsg_type + } + + /// The ancillary data payload. + /// + /// Corresponds to the data portion following the `cmsghdr`. + pub fn data(&self) -> &'a [u8] { + self.data + } +} + +impl<'a> fmt::Debug for ControlMessage<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "ControlMessage".fmt(fmt) + } +} + +/// Iterator over control messages in a `recvmsg(2)` control buffer. +/// +/// See [`crate::MsgHdrMut::with_control`] and [`crate::MsgHdrMut::control_len`]. +pub struct ControlMessages<'a> { + buf: &'a [u8], + offset: usize, +} + +impl<'a> ControlMessages<'a> { + /// Create a new `ControlMessages` from the filled control buffer. + /// + /// Pass `&raw_buf[..msg.control_len()]` where `raw_buf` is the slice + /// passed to [`crate::MsgHdrMut::with_control`] before calling `recvmsg(2)`. + pub fn new(buf: &'a [u8]) -> Self { + Self { buf, offset: 0 } + } +} + +impl<'a> Iterator for ControlMessages<'a> { + type Item = ControlMessage<'a>; + + #[allow(clippy::useless_conversion)] + fn next(&mut self) -> Option { + let hdr_size = mem::size_of::(); + // SAFETY: pure arithmetic; gives CMSG_ALIGN(sizeof(cmsghdr)). + let data_offset: usize = + usize::try_from(unsafe { libc::CMSG_LEN(0) }).unwrap_or(usize::MAX); + + if self.offset + hdr_size > self.buf.len() { + return None; + } + + // SAFETY: range is within `buf`; read_unaligned handles any alignment. + let cmsg: libc::cmsghdr = unsafe { + std::ptr::read_unaligned(self.buf.as_ptr().add(self.offset) as *const libc::cmsghdr) + }; + + let total_len = usize::try_from(cmsg.cmsg_len).unwrap_or(0); + if total_len < data_offset { + return None; + } + let data_len = total_len - data_offset; + + let data_abs_start = self.offset + data_offset; + let data_abs_end = data_abs_start.saturating_add(data_len); + if data_abs_end > self.buf.len() { + return None; + } + + let item = ControlMessage { + cmsg_level: cmsg.cmsg_level, + cmsg_type: cmsg.cmsg_type, + data: &self.buf[data_abs_start..data_abs_end], + }; + + // SAFETY: pure arithmetic; CMSG_SPACE(data_len) == CMSG_ALIGN(total_len). + let advance = match libc::c_uint::try_from(data_len) { + Ok(dl) => usize::try_from(unsafe { libc::CMSG_SPACE(dl) }).unwrap_or(usize::MAX), + Err(_) => return None, + }; + self.offset = self.offset.saturating_add(advance); + + Some(item) + } +} + +impl<'a> fmt::Debug for ControlMessages<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "ControlMessages".fmt(fmt) + } +} + +/// Builds a control message buffer for use with `sendmsg(2)`. +/// +/// See [`crate::MsgHdr::with_control`] and [`cmsg_space`]. +pub struct ControlMessageEncoder<'a> { + buf: &'a mut [u8], + len: usize, +} + +impl<'a> ControlMessageEncoder<'a> { + /// Create a new `ControlMessageEncoder` backed by `buf`. + /// + /// Zeroes `buf` on creation to ensure padding bytes are clean. + /// Allocate `buf` with the sum of [`cmsg_space`] for each intended message. + pub fn new(buf: &'a mut [u8]) -> Self { + buf.fill(0); + Self { buf, len: 0 } + } + + /// Append a control message carrying `data`. + /// + /// Returns `Err` if `data` exceeds `c_uint::MAX` or the buffer is too small. + pub fn push(&mut self, cmsg_level: i32, cmsg_type: i32, data: &[u8]) -> std::io::Result<()> { + let data_len_uint = libc::c_uint::try_from(data.len()).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "ancillary data payload too large (exceeds c_uint::MAX)", + ) + })?; + // SAFETY: pure arithmetic. + let space: usize = + usize::try_from(unsafe { libc::CMSG_SPACE(data_len_uint) }).unwrap_or(usize::MAX); + if self.len + space > self.buf.len() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "control message buffer too small", + )); + } + // SAFETY: pure arithmetic. + let cmsg_len = unsafe { libc::CMSG_LEN(data_len_uint) }; + unsafe { + // SAFETY: offset is within buf; write_unaligned handles alignment 1. + // Use zeroed() + field assignment to handle platform-specific padding + // (e.g. musl adds __pad1); buf is pre-zeroed but the write must be + // self-contained for correctness. + let cmsg_ptr = self.buf.as_mut_ptr().add(self.len) as *mut libc::cmsghdr; + let mut hdr: libc::cmsghdr = mem::zeroed(); + hdr.cmsg_len = cmsg_len as _; + hdr.cmsg_level = cmsg_level; + hdr.cmsg_type = cmsg_type; + std::ptr::write_unaligned(cmsg_ptr, hdr); + // SAFETY: CMSG_DATA gives the correct offset past alignment padding. + let data_ptr = libc::CMSG_DATA(cmsg_ptr); + std::ptr::copy_nonoverlapping(data.as_ptr(), data_ptr, data.len()); + } + self.len += space; + Ok(()) + } + + /// Returns the encoded bytes. + /// + /// Corresponds to the slice to pass to [`crate::MsgHdr::with_control`]. + pub fn as_bytes(&self) -> &[u8] { + &self.buf[..self.len] + } + + /// Returns the number of bytes written. + pub fn len(&self) -> usize { + self.len + } + + /// Returns `true` if no control messages have been pushed. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl<'a> fmt::Debug for ControlMessageEncoder<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "ControlMessageEncoder".fmt(fmt) + } +} diff --git a/src/lib.rs b/src/lib.rs index b846288f..8788629a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -172,6 +172,11 @@ macro_rules! man_links { }; } +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +mod cmsg; mod sockaddr; mod socket; mod sockref; @@ -188,6 +193,11 @@ compile_error!("Socket2 doesn't support the compile target"); use sys::c_int; +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +pub use cmsg::{cmsg_space, ControlMessage, ControlMessageEncoder, ControlMessages}; pub use sockaddr::{sa_family_t, socklen_t, SockAddr, SockAddrStorage}; #[cfg(not(any( target_os = "haiku", diff --git a/tests/socket.rs b/tests/socket.rs index 6dc1aea0..a08b85c3 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -1970,3 +1970,183 @@ fn set_busy_poll() { assert!(socket.busy_poll().unwrap() == i); } } + +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +#[test] +fn cmsg_space_nonzero() { + // cmsg_space(0) must be at least sizeof(cmsghdr); any positive data length + // must produce a larger result. + let space0 = socket2::cmsg_space(0).expect("cmsg_space(0) should be Some"); + let space4 = socket2::cmsg_space(4).expect("cmsg_space(4) should be Some"); + assert!(space0 > 0, "cmsg_space(0) should cover the cmsghdr header"); + assert!( + space4 > space0, + "cmsg_space(4) should be larger than cmsg_space(0)" + ); + // Overflow path: data_len > c_uint::MAX must return None. + #[cfg(target_pointer_width = "64")] + assert!( + socket2::cmsg_space(usize::MAX).is_none(), + "cmsg_space(usize::MAX) should return None" + ); +} + +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +#[test] +fn control_message_encoder_roundtrip() { + use socket2::{cmsg_space, ControlMessageEncoder, ControlMessages}; + + let level: libc::c_int = libc::SOL_SOCKET; + let ty: libc::c_int = 0x1234; // arbitrary type for the test + let payload: &[u8] = &[1u8, 2, 3, 4]; + + let space = cmsg_space(payload.len()).expect("payload fits in c_uint"); + let mut buf = vec![0u8; space]; + let mut enc = ControlMessageEncoder::new(&mut buf); + assert!(enc.is_empty()); + enc.push(level, ty, payload).expect("push should succeed"); + assert!(!enc.is_empty()); + assert_eq!(enc.len(), space); + + // Decode what we encoded. + let msgs: Vec<_> = ControlMessages::new(enc.as_bytes()).collect(); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].cmsg_level(), level); + assert_eq!(msgs[0].cmsg_type(), ty); + assert_eq!(msgs[0].data(), payload); +} + +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +#[test] +fn control_message_encoder_multiple() { + use socket2::{cmsg_space, ControlMessageEncoder, ControlMessages}; + + let entries: &[(libc::c_int, libc::c_int, &[u8])] = &[ + (libc::SOL_SOCKET, 1, &[0xAA, 0xBB]), + (libc::SOL_SOCKET, 2, &[0x11, 0x22, 0x33, 0x44]), + (libc::IPPROTO_IP, 3, &[0xFF]), + ]; + + let total: usize = entries + .iter() + .map(|(_, _, d)| cmsg_space(d.len()).expect("payload fits in c_uint")) + .sum(); + let mut buf = vec![0u8; total]; + let mut enc = ControlMessageEncoder::new(&mut buf); + + for (lvl, ty, data) in entries { + enc.push(*lvl, *ty, data).expect("push should succeed"); + } + + let msgs: Vec<_> = ControlMessages::new(enc.as_bytes()).collect(); + assert_eq!(msgs.len(), entries.len()); + for (i, (lvl, ty, data)) in entries.iter().enumerate() { + assert_eq!(msgs[i].cmsg_level(), *lvl); + assert_eq!(msgs[i].cmsg_type(), *ty); + assert_eq!(msgs[i].data(), *data); + } +} + +#[cfg(all( + unix, + not(any(target_os = "redox", target_os = "vita", target_os = "wasi")) +))] +#[test] +fn control_message_encoder_overflow() { + use socket2::{cmsg_space, ControlMessageEncoder}; + + let payload: &[u8] = &[1, 2, 3, 4]; + // Allocate space for only one message, then try to push two. + let mut buf = vec![0u8; cmsg_space(payload.len()).unwrap()]; + let mut enc = ControlMessageEncoder::new(&mut buf); + enc.push(libc::SOL_SOCKET, 1, payload) + .expect("first push ok"); + let result = enc.push(libc::SOL_SOCKET, 2, payload); + assert!( + result.is_err(), + "second push should fail — buffer too small" + ); +} + +/// End-to-end test: send a byte plus SCM_CREDENTIALS over a Unix socket pair, +/// then receive and verify the credential ancillary data. +#[cfg(target_os = "linux")] +#[test] +fn sendmsg_recvmsg_scm_credentials() { + use socket2::{ + cmsg_space, ControlMessageEncoder, ControlMessages, Domain, MaybeUninitSlice, MsgHdr, + MsgHdrMut, Socket, Type, + }; + use std::mem::MaybeUninit; + + // Enable SO_PASSCRED so the kernel attaches credentials. + let receiver = Socket::new(Domain::UNIX, Type::DGRAM, None).unwrap(); + receiver.set_passcred(true).unwrap(); + let path = std::env::temp_dir().join(format!("socket2_test_{}", std::process::id())); + let _ = std::fs::remove_file(&path); + let addr = socket2::SockAddr::unix(&path).unwrap(); + receiver.bind(&addr).unwrap(); + + let sender = Socket::new(Domain::UNIX, Type::DGRAM, None).unwrap(); + + // Build a sendmsg with one byte of data and an SCM_CREDENTIALS cmsg. + let cred = libc::ucred { + pid: unsafe { libc::getpid() }, + uid: unsafe { libc::getuid() }, + gid: unsafe { libc::getgid() }, + }; + let cred_bytes = unsafe { + std::slice::from_raw_parts( + &cred as *const libc::ucred as *const u8, + std::mem::size_of::(), + ) + }; + + let mut ctrl_buf = vec![0u8; cmsg_space(cred_bytes.len()).unwrap()]; + let mut enc = ControlMessageEncoder::new(&mut ctrl_buf); + enc.push(libc::SOL_SOCKET, libc::SCM_CREDENTIALS, cred_bytes) + .unwrap(); + + let data = b"x"; + let send_bufs = [std::io::IoSlice::new(data)]; + let msg = MsgHdr::new() + .with_addr(&addr) + .with_buffers(&send_bufs) + .with_control(enc.as_bytes()); + sender.sendmsg(&msg, 0).unwrap(); + + // Receive with a large enough control buffer. + let mut recv_data = [MaybeUninit::uninit(); 16]; + let ctrl_cap = cmsg_space(std::mem::size_of::()).unwrap(); + let mut recv_ctrl = vec![MaybeUninit::::uninit(); ctrl_cap]; + let mut recv_bufs = [MaybeUninitSlice::new(&mut recv_data)]; + + let mut recv_msg = MsgHdrMut::new() + .with_buffers(&mut recv_bufs) + .with_control(&mut recv_ctrl); + + let n = receiver.recvmsg(&mut recv_msg, 0).unwrap(); + assert_eq!(n, 1); + + let ctrl_len = recv_msg.control_len(); + let filled = unsafe { std::slice::from_raw_parts(recv_ctrl.as_ptr() as *const u8, ctrl_len) }; + + let msgs: Vec<_> = ControlMessages::new(filled).collect(); + assert!(!msgs.is_empty(), "expected at least one control message"); + + let found = msgs + .iter() + .any(|m| m.cmsg_level() == libc::SOL_SOCKET && m.cmsg_type() == libc::SCM_CREDENTIALS); + assert!(found, "SCM_CREDENTIALS cmsg not found"); + + let _ = std::fs::remove_file(&path); +}