x11rb/rust_connection/
stream.rs

1use rustix::fd::{AsFd, BorrowedFd};
2use std::io::{IoSlice, Result};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
6#[cfg(unix)]
7use std::os::unix::net::UnixStream;
8#[cfg(windows)]
9use std::os::windows::io::{
10    AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
11};
12
13use crate::utils::RawFdContainer;
14use x11rb_protocol::parse_display::ConnectAddress;
15use x11rb_protocol::xauth::Family;
16
/// The kind of operation that one wants to poll for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollMode {
    /// Check if the stream is readable, i.e. there is pending data to be read.
    Readable,

    /// Check if the stream is writable, i.e. some data could be successfully written to it.
    Writable,

    /// Check for both readability and writability.
    ReadAndWritable,
}

impl PollMode {
    /// Does this poll mode include readability?
    ///
    /// Returns `true` for [`PollMode::Readable`] and [`PollMode::ReadAndWritable`].
    pub fn readable(self) -> bool {
        match self {
            PollMode::Readable | PollMode::ReadAndWritable => true,
            PollMode::Writable => false,
        }
    }

    /// Does this poll mode include writability?
    ///
    /// Returns `true` for [`PollMode::Writable`] and [`PollMode::ReadAndWritable`].
    pub fn writable(self) -> bool {
        match self {
            PollMode::Writable | PollMode::ReadAndWritable => true,
            PollMode::Readable => false,
        }
    }
}
47
48/// A trait used to implement the raw communication with the X11 server.
49///
50/// None of the functions of this trait shall return [`std::io::ErrorKind::Interrupted`].
51/// If a system call fails with this error, the implementation should try again.
52pub trait Stream {
53    /// Waits for level-triggered read and/or write events on the stream.
54    ///
55    /// This function does not return what caused it to complete the poll.
56    /// Instead, callers should try to read or write and check for
57    /// [`std::io::ErrorKind::WouldBlock`].
58    ///
59    /// This function is allowed to spuriously return even if the stream
60    /// is neither readable nor writable. However, it shall not do it
61    /// continuously, which would cause a 100% CPU usage.
62    ///
63    /// # Multithreading
64    ///
65    /// If `Self` is `Send + Sync` and `poll` is used concurrently from more than
66    /// one thread, all threads should wake when the stream becomes readable (when
67    /// `read` is `true`) or writable (when `write` is `true`).
68    fn poll(&self, mode: PollMode) -> Result<()>;
69
70    /// Read some bytes and FDs from this reader without blocking, returning how many bytes
71    /// were read.
72    ///
73    /// This function works like [`std::io::Read::read`], but also supports the reception of file
74    /// descriptors. Any received file descriptors are appended to the given `fd_storage`.
75    /// Whereas implementation of [`std::io::Read::read`] are allowed to block or not to block,
76    /// this method shall never block and return `ErrorKind::WouldBlock` if needed.
77    ///
78    /// This function does not guarantee that all file descriptors were sent together with the data
79    /// with which they are received. However, file descriptors may not be received later than the
80    /// data that was sent at the same time. Instead, file descriptors may only be received
81    /// earlier.
82    ///
83    /// # Multithreading
84    ///
85    /// If `Self` is `Send + Sync` and `read` is used concurrently from more than one thread:
86    ///
87    /// * Both the data and the file descriptors shall be read in order, but possibly
88    /// interleaved across threads.
89    /// * Neither the data nor the file descriptors shall be duplicated.
90    /// * The returned value shall always be the actual number of bytes read into `buf`.
91    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;
92
93    /// Write a buffer and some FDs into this writer without blocking, returning how many
94    /// bytes were written.
95    ///
96    /// This function works like [`std::io::Write::write`], but also supports sending file
97    /// descriptors. The `fds` argument contains the file descriptors to send. The order of file
98    /// descriptors is maintained. Whereas implementation of [`std::io::Write::write`] are
99    /// allowed to block or not to block, this function must never block and return
100    /// `ErrorKind::WouldBlock` if needed.
101    ///
102    /// This function does not guarantee that all file descriptors are sent together with the data.
103    /// Any file descriptors that were sent are removed from the beginning of the given `Vec`.
104    ///
105    /// There is no guarantee that the given file descriptors are received together with the given
106    /// data. File descriptors might be received earlier than their corresponding data. It is not
107    /// allowed for file descriptors to be received later than the bytes that were sent at the same
108    /// time.
109    ///
110    /// # Multithreading
111    ///
112    /// If `Self` is `Send + Sync` and `write` is used concurrently from more than one thread:
113    ///
114    /// * Both the data and the file descriptors shall be written in order, but possibly
115    /// interleaved across threads.
116    /// * Neither the data nor the file descriptors shall be duplicated.
117    /// * The returned value shall always be the actual number of bytes written from `buf`.
118    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;
119
120    /// Like `write`, except that it writes from a slice of buffers. Like `write`, this
121    /// method must never block.
122    ///
123    /// This method must behave as a call to `write` with the buffers concatenated would.
124    ///
125    /// The default implementation calls `write` with the first nonempty buffer provided.
126    ///
127    /// # Multithreading
128    ///
129    /// Same as `write`.
130    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
131        for buf in bufs {
132            if !buf.is_empty() {
133                return self.write(buf, fds);
134            }
135        }
136        Ok(0)
137    }
138}
139
140/// A wrapper around a `TcpStream` or `UnixStream`.
141///
142/// Use by default in `RustConnection` as stream.
143#[derive(Debug)]
144pub struct DefaultStream {
145    inner: DefaultStreamInner,
146}
147
148#[cfg(unix)]
149type DefaultStreamInner = RawFdContainer;
150
151#[cfg(not(unix))]
152type DefaultStreamInner = TcpStream;
153
154/// The address of a peer in a format suitable for xauth.
155///
156/// These values can be directly given to [`x11rb_protocol::xauth::get_auth`].
157type PeerAddr = (Family, Vec<u8>);
158
159impl DefaultStream {
160    /// Try to connect to the X11 server described by the given arguments.
161    pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
162        match addr {
163            ConnectAddress::Hostname(host, port) => {
164                // connect over TCP
165                let stream = TcpStream::connect((*host, *port))?;
166                Self::from_tcp_stream(stream)
167            }
168            #[cfg(unix)]
169            ConnectAddress::Socket(path) => {
170                // Try abstract unix socket first. If that fails, fall back to normal unix socket
171                #[cfg(any(target_os = "linux", target_os = "android"))]
172                if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) {
173                    // TODO: Does it make sense to add a constructor similar to from_unix_stream()?
174                    // If this is done: Move the set_nonblocking() from
175                    // connect_abstract_unix_stream() to that new function.
176                    let stream = DefaultStream { inner: stream };
177                    return Ok((stream, peer_addr::local()));
178                }
179
180                // connect over Unix domain socket
181                let stream = UnixStream::connect(path)?;
182                Self::from_unix_stream(stream)
183            }
184            #[cfg(not(unix))]
185            ConnectAddress::Socket(_) => {
186                // Unix domain sockets are not supported on Windows
187                Err(std::io::Error::new(
188                    std::io::ErrorKind::Other,
189                    "Unix domain sockets are not supported on Windows",
190                ))
191            }
192            _ => Err(std::io::Error::new(
193                std::io::ErrorKind::Other,
194                "The given address family is not implemented",
195            )),
196        }
197    }
198
199    /// Creates a new `Stream` from an already connected `TcpStream`.
200    ///
201    /// The stream will be set in non-blocking mode.
202    ///
203    /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
204    pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
205        let peer_addr = peer_addr::tcp(&stream.peer_addr()?);
206        stream.set_nonblocking(true)?;
207        let result = Self {
208            inner: stream.into(),
209        };
210        Ok((result, peer_addr))
211    }
212
213    /// Creates a new `Stream` from an already connected `UnixStream`.
214    ///
215    /// The stream will be set in non-blocking mode.
216    ///
217    /// This returns the peer address in a format suitable for [`x11rb_protocol::xauth::get_auth`].
218    #[cfg(unix)]
219    pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
220        stream.set_nonblocking(true)?;
221        let result = Self {
222            inner: stream.into(),
223        };
224        Ok((result, peer_addr::local()))
225    }
226
227    fn as_fd(&self) -> BorrowedFd<'_> {
228        self.inner.as_fd()
229    }
230}
231
232#[cfg(unix)]
233impl AsRawFd for DefaultStream {
234    fn as_raw_fd(&self) -> RawFd {
235        self.inner.as_raw_fd()
236    }
237}
238
239#[cfg(unix)]
240impl AsFd for DefaultStream {
241    fn as_fd(&self) -> BorrowedFd<'_> {
242        self.inner.as_fd()
243    }
244}
245
246#[cfg(unix)]
247impl IntoRawFd for DefaultStream {
248    fn into_raw_fd(self) -> RawFd {
249        self.inner.into_raw_fd()
250    }
251}
252
253#[cfg(unix)]
254impl From<DefaultStream> for OwnedFd {
255    fn from(stream: DefaultStream) -> Self {
256        stream.inner
257    }
258}
259
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    /// Returns the raw socket handle of the wrapped stream; ownership stays with `self`.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}

#[cfg(windows)]
impl AsSocket for DefaultStream {
    /// Borrows the wrapped socket for as long as `self` is borrowed.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}

#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    /// Gives up ownership of the wrapped socket; the caller becomes responsible for closing it.
    fn into_raw_socket(self) -> RawSocket {
        let Self { inner } = self;
        inner.into_raw_socket()
    }
}

#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    /// Unwraps the stream into an owned socket handle.
    fn from(value: DefaultStream) -> Self {
        let DefaultStream { inner } = value;
        OwnedSocket::from(inner)
    }
}
287
288#[cfg(unix)]
289fn do_write(
290    stream: &DefaultStream,
291    bufs: &[IoSlice<'_>],
292    fds: &mut Vec<RawFdContainer>,
293) -> Result<usize> {
294    use rustix::io::Errno;
295    use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
296
297    fn sendmsg_wrapper(
298        fd: BorrowedFd<'_>,
299        iov: &[IoSlice<'_>],
300        cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
301        flags: SendFlags,
302    ) -> Result<usize> {
303        loop {
304            match sendmsg(fd, iov, cmsgs, flags) {
305                Ok(n) => return Ok(n),
306                // try again
307                Err(Errno::INTR) => {}
308                Err(e) => return Err(e.into()),
309            }
310        }
311    }
312
313    let fd = stream.as_fd();
314
315    let res = if !fds.is_empty() {
316        let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
317        let rights = SendAncillaryMessage::ScmRights(&fds);
318
319        let mut cmsg_space = vec![0u8; rights.size()];
320        let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
321        assert!(cmsg_buffer.push(rights));
322
323        sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())?
324    } else {
325        sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())?
326    };
327
328    // We successfully sent all FDs
329    fds.clear();
330
331    Ok(res)
332}
333
334impl Stream for DefaultStream {
335    fn poll(&self, mode: PollMode) -> Result<()> {
336        use rustix::event::{poll, PollFd, PollFlags};
337        use rustix::io::Errno;
338
339        let mut poll_flags = PollFlags::empty();
340        if mode.readable() {
341            poll_flags |= PollFlags::IN;
342        }
343        if mode.writable() {
344            poll_flags |= PollFlags::OUT;
345        }
346        let fd = self.as_fd();
347        let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
348        loop {
349            match poll(&mut poll_fds, -1) {
350                Ok(_) => break,
351                Err(Errno::INTR) => {}
352                Err(e) => return Err(e.into()),
353            }
354        }
355        // Let the errors (POLLERR) be handled when trying to read or write.
356        Ok(())
357    }
358
359    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
360        #[cfg(unix)]
361        {
362            use rustix::io::Errno;
363            use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
364            use std::io::IoSliceMut;
365
366            // 1024 bytes on the stack should be enough for more file descriptors than the X server will ever
367            // send, as well as the header for the ancillary data. If you can find a case where this can
368            // overflow with an actual production X11 server, I'll buy you a steak dinner.
369            let mut cmsg = [0u8; 1024];
370            let mut iov = [IoSliceMut::new(buf)];
371            let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);
372
373            let fd = self.as_fd();
374            let msg = loop {
375                match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
376                    Ok(msg) => break msg,
377                    // try again
378                    Err(Errno::INTR) => {}
379                    Err(e) => return Err(e.into()),
380                }
381            };
382
383            let fds_received = cmsg_buffer
384                .drain()
385                .filter_map(|cmsg| match cmsg {
386                    RecvAncillaryMessage::ScmRights(r) => Some(r),
387                    _ => None,
388                })
389                .flatten();
390
391            let mut cloexec_error = Ok(());
392            fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
393            cloexec_error?;
394
395            Ok(msg.bytes)
396        }
397        #[cfg(not(unix))]
398        {
399            use std::io::Read;
400            // No FDs are read, so nothing needs to be done with fd_storage
401            let _ = fd_storage;
402            loop {
403                // Use `impl Read for &TcpStream` to avoid needing a mutable `TcpStream`.
404                match (&mut &self.inner).read(buf) {
405                    Ok(n) => return Ok(n),
406                    // try again
407                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
408                    Err(e) => return Err(e),
409                }
410            }
411        }
412    }
413
414    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
415        #[cfg(unix)]
416        {
417            do_write(self, &[IoSlice::new(buf)], fds)
418        }
419        #[cfg(not(unix))]
420        {
421            use std::io::{Error, ErrorKind, Write};
422            if !fds.is_empty() {
423                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
424            }
425            loop {
426                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
427                match (&mut &self.inner).write(buf) {
428                    Ok(n) => return Ok(n),
429                    // try again
430                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
431                    Err(e) => return Err(e),
432                }
433            }
434        }
435    }
436
437    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
438        #[cfg(unix)]
439        {
440            do_write(self, bufs, fds)
441        }
442        #[cfg(not(unix))]
443        {
444            use std::io::{Error, ErrorKind, Write};
445            if !fds.is_empty() {
446                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
447            }
448            loop {
449                // Use `impl Write for &TcpStream` to avoid needing a mutable `TcpStream`.
450                match (&mut &self.inner).write_vectored(bufs) {
451                    Ok(n) => return Ok(n),
452                    // try again
453                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
454                    Err(e) => return Err(e),
455                }
456            }
457        }
458    }
459}
460
461#[cfg(any(target_os = "linux", target_os = "android"))]
462fn connect_abstract_unix_stream(
463    path: &[u8],
464) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
465    use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
466    use rustix::net::{
467        connect_unix, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
468    };
469
470    let socket = socket_with(
471        AddressFamily::UNIX,
472        SocketType::STREAM,
473        SocketFlags::CLOEXEC,
474        None,
475    )?;
476
477    connect_unix(&socket, &SocketAddrUnix::new_abstract_name(path)?)?;
478
479    // Make the FD non-blocking
480    fcntl_setfl(&socket, fcntl_getfl(&socket)? | OFlags::NONBLOCK)?;
481
482    Ok(socket)
483}
484
485/// Helper code to make sure that received FDs are marked as CLOEXEC
486#[cfg(any(
487    target_os = "android",
488    target_os = "dragonfly",
489    target_os = "freebsd",
490    target_os = "linux",
491    target_os = "netbsd",
492    target_os = "openbsd"
493))]
494mod recvmsg {
495    use super::RawFdContainer;
496    use rustix::net::RecvFlags;
497
498    pub(crate) fn flags() -> RecvFlags {
499        RecvFlags::CMSG_CLOEXEC
500    }
501
502    pub(crate) fn after_recvmsg<'a>(
503        fds: impl Iterator<Item = RawFdContainer> + 'a,
504        _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
505    ) -> impl Iterator<Item = RawFdContainer> + 'a {
506        fds
507    }
508}
509
/// Helper code to make sure that received FDs are marked as CLOEXEC.
///
/// Fallback for the remaining Unix platforms: nothing special is requested from `recvmsg`,
/// and CLOEXEC is set on each FD with `fcntl` after it has been received.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// No special receive flags; CLOEXEC is applied by `after_recvmsg` instead.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Lazily sets CLOEXEC on every FD the iterator yields.
    ///
    /// Each FD is yielded even if setting the flag fails; the most recent failure is stored in
    /// `cloexec_error` for the caller to check once the iterator has been consumed.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let outcome =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = outcome {
                *cloexec_error = Err(errno);
            }
        })
    }
}
545
546mod peer_addr {
547    use super::{Family, PeerAddr};
548    use std::net::{Ipv4Addr, SocketAddr};
549
550    // Get xauth information representing a local connection
551    pub(super) fn local() -> PeerAddr {
552        let hostname = crate::hostname()
553            .to_str()
554            .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
555        (Family::LOCAL, hostname)
556    }
557
558    // Get xauth information representing a TCP connection to the given address
559    pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
560        let ip = match addr {
561            SocketAddr::V4(addr) => *addr.ip(),
562            SocketAddr::V6(addr) => {
563                let ip = addr.ip();
564                if ip.is_loopback() {
565                    // This is a local connection.
566                    // Use LOCALHOST to cause a fall-through in the code below.
567                    Ipv4Addr::LOCALHOST
568                } else if let Some(ip) = ip.to_ipv4() {
569                    // Let the ipv4 code below handle this
570                    ip
571                } else {
572                    // Okay, this is really a v6 address
573                    return (Family::INTERNET6, ip.octets().to_vec());
574                }
575            }
576        };
577
578        // Handle the v4 address
579        if ip.is_loopback() {
580            local()
581        } else {
582            (Family::INTERNET, ip.octets().to_vec())
583        }
584    }
585}