wayland_backend/rs/
socket.rs

1//! Wayland socket manipulation
2
3use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
6use std::os::unix::net::UnixStream;
7use std::slice;
8
9use rustix::io::retry_on_intr;
10use rustix::net::{
11    recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
12    SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
13};
14
15use crate::protocol::{ArgumentType, Message};
16
17use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
18
/// Upper bound on how many file descriptors may be attached to one socket message.
///
/// The receiving side only reserves ancillary-data space for this many descriptors.
pub const MAX_FDS_OUT: usize = 28;
/// Upper bound on how many payload bytes may be carried by one socket message.
pub const MAX_BYTES_OUT: usize = 4096;
23
24/*
25 * Socket
26 */
27
/// A Wayland socket: a thin wrapper around a [`UnixStream`].
///
/// Data is exchanged with `sendmsg`/`recvmsg` so that file descriptors can be
/// transferred as `SCM_RIGHTS` ancillary data next to the message bytes.
#[derive(Debug)]
pub struct Socket {
    /// The wrapped Unix stream socket.
    stream: UnixStream,
}
33
34impl Socket {
35    /// Send a single message to the socket
36    ///
37    /// A single socket message can contain several wayland messages
38    ///
39    /// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
40    /// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
41    /// end may lose some data.
42    pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
43        #[cfg(not(target_os = "macos"))]
44        let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
45        #[cfg(target_os = "macos")]
46        let flags = SendFlags::DONTWAIT;
47
48        if !fds.is_empty() {
49            let iov = [IoSlice::new(bytes)];
50            let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
51            let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
52            let fds =
53                unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
54            cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
55            Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
56        } else {
57            Ok(retry_on_intr(|| send(self, bytes, flags))?)
58        }
59    }
60
61    /// Receive a single message from the socket
62    ///
63    /// Return the number of bytes received and the number of Fds received.
64    ///
65    /// Errors with `WouldBlock` is no message is available.
66    ///
67    /// A single socket message can contain several wayland messages.
68    ///
69    /// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
70    /// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
71    /// be lost.
72    pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
73        #[cfg(not(target_os = "macos"))]
74        let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
75        #[cfg(target_os = "macos")]
76        let flags = RecvFlags::DONTWAIT;
77
78        let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
79        let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
80        let mut iov = [IoSliceMut::new(buffer)];
81        let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
82
83        let received_fds = cmsg_buffer
84            .drain()
85            .filter_map(|cmsg| match cmsg {
86                RecvAncillaryMessage::ScmRights(fds) => Some(fds),
87                _ => None,
88            })
89            .flatten();
90        fds.extend(received_fds);
91        #[cfg(target_os = "macos")]
92        for fd in fds.iter() {
93            if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
94                let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
95            }
96        }
97        Ok(msg.bytes)
98    }
99}
100
101impl From<UnixStream> for Socket {
102    fn from(stream: UnixStream) -> Self {
103        // macOS doesn't have MSG_NOSIGNAL, but has SO_NOSIGPIPE instead
104        #[cfg(target_os = "macos")]
105        let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
106        Self { stream }
107    }
108}
109
110impl AsFd for Socket {
111    fn as_fd(&self) -> BorrowedFd<'_> {
112        self.stream.as_fd()
113    }
114}
115
116impl AsRawFd for Socket {
117    fn as_raw_fd(&self) -> RawFd {
118        self.stream.as_raw_fd()
119    }
120}
121
122/*
123 * BufferedSocket
124 */
125
126/// An adapter around a raw Socket that directly handles buffering and
127/// conversion from/to wayland messages
128#[derive(Debug)]
129pub struct BufferedSocket {
130    socket: Socket,
131    in_data: Buffer<u8>,
132    in_fds: VecDeque<OwnedFd>,
133    out_data: Buffer<u8>,
134    out_fds: Vec<OwnedFd>,
135}
136
137impl BufferedSocket {
138    /// Wrap a Socket into a Buffered Socket
139    pub fn new(socket: Socket) -> Self {
140        Self {
141            socket,
142            in_data: Buffer::new(2 * MAX_BYTES_OUT), // Incoming buffers are twice as big in order to be
143            in_fds: VecDeque::new(),                 // able to store leftover data if needed
144            out_data: Buffer::new(MAX_BYTES_OUT),
145            out_fds: Vec::new(),
146        }
147    }
148
149    /// Flush the contents of the outgoing buffer into the socket
150    pub fn flush(&mut self) -> IoResult<()> {
151        let written = {
152            let bytes = self.out_data.get_contents();
153            if bytes.is_empty() {
154                return Ok(());
155            }
156            self.socket.send_msg(bytes, &self.out_fds)?
157        };
158        self.out_data.offset(written);
159        self.out_data.move_to_front();
160        self.out_fds.clear();
161        Ok(())
162    }
163
164    // internal method
165    //
166    // attempts to write a message in the internal out buffers,
167    // returns true if successful
168    //
169    // if false is returned, it means there is not enough space
170    // in the buffer
171    fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
172        match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
173            Ok(bytes_out) => {
174                self.out_data.advance(bytes_out);
175                Ok(true)
176            }
177            Err(MessageWriteError::BufferTooSmall) => Ok(false),
178            Err(MessageWriteError::DupFdFailed(e)) => Err(e),
179        }
180    }
181
182    /// Write a message to the outgoing buffer
183    ///
184    /// This method may flush the internal buffer if necessary (if it is full).
185    ///
186    /// If the message is too big to fit in the buffer, the error `Error::Sys(E2BIG)`
187    /// will be returned.
188    pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
189        if !self.attempt_write_message(msg)? {
190            // the attempt failed, there is not enough space in the buffer
191            // we need to flush it
192            if let Err(e) = self.flush() {
193                if e.kind() != ErrorKind::WouldBlock {
194                    return Err(e);
195                }
196            }
197            if !self.attempt_write_message(msg)? {
198                // If this fails again, this means the message is too big
199                // to be transmitted at all
200                return Err(rustix::io::Errno::TOOBIG.into());
201            }
202        }
203        Ok(())
204    }
205
206    /// Try to fill the incoming buffers of this socket, to prepare
207    /// a new round of parsing.
208    pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
209        // reorganize the buffers
210        self.in_data.move_to_front();
211        // receive a message
212        let in_bytes = {
213            let bytes = self.in_data.get_writable_storage();
214            self.socket.rcv_msg(bytes, &mut self.in_fds)?
215        };
216        if in_bytes == 0 {
217            // the other end of the socket was closed
218            return Err(rustix::io::Errno::PIPE.into());
219        }
220        // advance the storage
221        self.in_data.advance(in_bytes);
222        Ok(())
223    }
224
225    /// Read and deserialize a single message from the incoming buffers socket
226    ///
227    /// This method requires one closure that given an object id and an opcode,
228    /// must provide the signature of the associated request/event, in the form of
229    /// a `&'static [ArgumentType]`.
230    pub fn read_one_message<F>(
231        &mut self,
232        mut signature: F,
233    ) -> Result<Message<u32, OwnedFd>, MessageParseError>
234    where
235        F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
236    {
237        let (msg, read_data) = {
238            let data = self.in_data.get_contents();
239            if data.len() < 2 * 4 {
240                return Err(MessageParseError::MissingData);
241            }
242            let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
243            let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
244            let opcode = (word_2 & 0x0000_FFFF) as u16;
245            if let Some(sig) = signature(object_id, opcode) {
246                match parse_message(data, sig, &mut self.in_fds) {
247                    Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
248                    Err(e) => return Err(e),
249                }
250            } else {
251                // no signature found ?
252                return Err(MessageParseError::Malformed);
253            }
254        };
255
256        self.in_data.offset(read_data);
257
258        Ok(msg)
259    }
260}
261
262impl AsRawFd for BufferedSocket {
263    fn as_raw_fd(&self) -> RawFd {
264        self.socket.as_raw_fd()
265    }
266}
267
268impl AsFd for BufferedSocket {
269    fn as_fd(&self) -> BorrowedFd<'_> {
270        self.socket.as_fd()
271    }
272}
273
274/*
275 * Buffer
276 */
/// Fixed-capacity linear buffer with a read cursor and a write cursor.
///
/// `storage[offset..occupied]` holds written but not yet consumed elements, and
/// `storage[occupied..]` is free space.
#[derive(Debug)]
struct Buffer<T: Copy> {
    /// Backing storage; its length is the capacity and is never changed.
    storage: Vec<T>,
    /// End of the written region (write cursor).
    occupied: usize,
    /// Start of the unconsumed region (read cursor).
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Create an empty buffer backed by `size` default-initialized elements.
    fn new(size: usize) -> Self {
        let storage: Vec<T> = std::iter::repeat(T::default()).take(size).collect();
        Buffer { storage, occupied: 0, offset: 0 }
    }

    /// Mark `bytes` more elements of the free space as written.
    fn advance(&mut self, bytes: usize) {
        self.occupied += bytes;
    }

    /// Mark `bytes` more elements of the written region as consumed.
    fn offset(&mut self, bytes: usize) {
        self.offset += bytes;
    }

    /// Reset both cursors, discarding any content.
    ///
    /// The storage itself is not zeroed; old content is simply overwritten later.
    #[allow(unused)] // currently unused, kept for completeness of the buffer API
    fn clear(&mut self) {
        self.offset = 0;
        self.occupied = 0;
    }

    /// The written but not yet consumed elements.
    fn get_contents(&self) -> &[T] {
        let Buffer { storage, occupied, offset } = self;
        &storage[*offset..*occupied]
    }

    /// Mutable view of the free space that follows the written region.
    fn get_writable_storage(&mut self) -> &mut [T] {
        let start = self.occupied;
        &mut self.storage[start..]
    }

    /// Shift the unconsumed elements to the start of the storage, so that all
    /// remaining capacity becomes contiguous free space.
    fn move_to_front(&mut self) {
        let unread = self.occupied - self.offset;
        if unread != 0 {
            self.storage.copy_within(self.offset..self.occupied, 0);
        }
        self.offset = 0;
        self.occupied = unread;
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    /// Create a connected `(client, server)` pair of buffered sockets.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        (BufferedSocket::new(Socket::from(client)), BufferedSocket::new(Socket::from(server)))
    }

    /// Whether two descriptors refer to the same open file (same device and inode).
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let stat1 = rustix::fs::fstat(a).unwrap();
        let stat2 = rustix::fs::fstat(b).unwrap();
        stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
    }

    // check if two messages are equal
    //
    // if arguments contain FDs, check that the fd point to
    // the same file, rather than are the same number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
                let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                assert!(same_file(fd1, fd2));
            } else {
                assert_eq!(arg1, arg2);
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 42 && opcode == 7 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_fd() {
        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 42 && opcode == 7 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();
        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_multiple() {
        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let (mut client, mut server) = buffered_pair();

        for msg in &messages {
            client.write_message(msg).unwrap();
        }
        client.flush().unwrap();

        server.fill_incoming_buffers().unwrap();

        let mut recv_msgs = Vec::new();
        while let Ok(message) = server.read_one_message(|sender_id, opcode| {
            if sender_id == 42 {
                Some(SIGNATURES[opcode as usize])
            } else {
                None
            }
        }) {
            recv_msgs.push(message);
        }
        assert_eq!(recv_msgs.len(), 3);
        for (msg1, msg2) in messages.into_iter().zip(recv_msgs.into_iter()) {
            assert_eq_msgs(&msg1.map_fd(|fd| fd.as_raw_fd()), &msg2.map_fd(IntoRawFd::into_raw_fd));
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();

        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        server.fill_incoming_buffers().unwrap();

        let ret_msg = server
            .read_one_message(|sender_id, opcode| {
                if sender_id == 2 && opcode == 0 {
                    Some(SIGNATURE)
                } else {
                    None
                }
            })
            .unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &ret_msg.map_fd(IntoRawFd::into_raw_fd));
    }
}