1use std::collections::VecDeque;
4use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
5use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
6use std::os::unix::net::UnixStream;
7use std::slice;
8
9use rustix::io::retry_on_intr;
10use rustix::net::{
11 recvmsg, send, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags,
12 SendAncillaryBuffer, SendAncillaryMessage, SendFlags,
13};
14
15use crate::protocol::{ArgumentType, Message};
16
17use super::wire::{parse_message, write_to_buffers, MessageParseError, MessageWriteError};
18
/// Number of file descriptors the receive path reserves ancillary-data space
/// for in a single `recvmsg` call.
pub const MAX_FDS_OUT: usize = 28;
/// Capacity, in bytes, of a `BufferedSocket`'s outgoing data buffer; its
/// incoming buffer is twice this size.
pub const MAX_BYTES_OUT: usize = 4 * 1024;
23
24#[derive(Debug)]
30pub struct Socket {
31 stream: UnixStream,
32}
33
34impl Socket {
35 pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
43 #[cfg(not(target_os = "macos"))]
44 let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
45 #[cfg(target_os = "macos")]
46 let flags = SendFlags::DONTWAIT;
47
48 if !fds.is_empty() {
49 let iov = [IoSlice::new(bytes)];
50 let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
51 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
52 let fds =
53 unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
54 cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
55 Ok(retry_on_intr(|| sendmsg(self, &iov, &mut cmsg_buffer, flags))?)
56 } else {
57 Ok(retry_on_intr(|| send(self, bytes, flags))?)
58 }
59 }
60
61 pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
73 #[cfg(not(target_os = "macos"))]
74 let flags = RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC;
75 #[cfg(target_os = "macos")]
76 let flags = RecvFlags::DONTWAIT;
77
78 let mut cmsg_space = [0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
79 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
80 let mut iov = [IoSliceMut::new(buffer)];
81 let msg = retry_on_intr(|| recvmsg(&self.stream, &mut iov[..], &mut cmsg_buffer, flags))?;
82
83 let received_fds = cmsg_buffer
84 .drain()
85 .filter_map(|cmsg| match cmsg {
86 RecvAncillaryMessage::ScmRights(fds) => Some(fds),
87 _ => None,
88 })
89 .flatten();
90 fds.extend(received_fds);
91 #[cfg(target_os = "macos")]
92 for fd in fds.iter() {
93 if let Ok(flags) = rustix::io::fcntl_getfd(fd) {
94 let _ = rustix::io::fcntl_setfd(fd, flags | rustix::io::FdFlags::CLOEXEC);
95 }
96 }
97 Ok(msg.bytes)
98 }
99}
100
101impl From<UnixStream> for Socket {
102 fn from(stream: UnixStream) -> Self {
103 #[cfg(target_os = "macos")]
105 let _ = rustix::net::sockopt::set_socket_nosigpipe(&stream, true);
106 Self { stream }
107 }
108}
109
110impl AsFd for Socket {
111 fn as_fd(&self) -> BorrowedFd<'_> {
112 self.stream.as_fd()
113 }
114}
115
116impl AsRawFd for Socket {
117 fn as_raw_fd(&self) -> RawFd {
118 self.stream.as_raw_fd()
119 }
120}
121
122#[derive(Debug)]
129pub struct BufferedSocket {
130 socket: Socket,
131 in_data: Buffer<u8>,
132 in_fds: VecDeque<OwnedFd>,
133 out_data: Buffer<u8>,
134 out_fds: Vec<OwnedFd>,
135}
136
137impl BufferedSocket {
138 pub fn new(socket: Socket) -> Self {
140 Self {
141 socket,
142 in_data: Buffer::new(2 * MAX_BYTES_OUT), in_fds: VecDeque::new(), out_data: Buffer::new(MAX_BYTES_OUT),
145 out_fds: Vec::new(),
146 }
147 }
148
149 pub fn flush(&mut self) -> IoResult<()> {
151 let written = {
152 let bytes = self.out_data.get_contents();
153 if bytes.is_empty() {
154 return Ok(());
155 }
156 self.socket.send_msg(bytes, &self.out_fds)?
157 };
158 self.out_data.offset(written);
159 self.out_data.move_to_front();
160 self.out_fds.clear();
161 Ok(())
162 }
163
164 fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
172 match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
173 Ok(bytes_out) => {
174 self.out_data.advance(bytes_out);
175 Ok(true)
176 }
177 Err(MessageWriteError::BufferTooSmall) => Ok(false),
178 Err(MessageWriteError::DupFdFailed(e)) => Err(e),
179 }
180 }
181
182 pub fn write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<()> {
189 if !self.attempt_write_message(msg)? {
190 if let Err(e) = self.flush() {
193 if e.kind() != ErrorKind::WouldBlock {
194 return Err(e);
195 }
196 }
197 if !self.attempt_write_message(msg)? {
198 return Err(rustix::io::Errno::TOOBIG.into());
201 }
202 }
203 Ok(())
204 }
205
206 pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
209 self.in_data.move_to_front();
211 let in_bytes = {
213 let bytes = self.in_data.get_writable_storage();
214 self.socket.rcv_msg(bytes, &mut self.in_fds)?
215 };
216 if in_bytes == 0 {
217 return Err(rustix::io::Errno::PIPE.into());
219 }
220 self.in_data.advance(in_bytes);
222 Ok(())
223 }
224
225 pub fn read_one_message<F>(
231 &mut self,
232 mut signature: F,
233 ) -> Result<Message<u32, OwnedFd>, MessageParseError>
234 where
235 F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
236 {
237 let (msg, read_data) = {
238 let data = self.in_data.get_contents();
239 if data.len() < 2 * 4 {
240 return Err(MessageParseError::MissingData);
241 }
242 let object_id = u32::from_ne_bytes([data[0], data[1], data[2], data[3]]);
243 let word_2 = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
244 let opcode = (word_2 & 0x0000_FFFF) as u16;
245 if let Some(sig) = signature(object_id, opcode) {
246 match parse_message(data, sig, &mut self.in_fds) {
247 Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
248 Err(e) => return Err(e),
249 }
250 } else {
251 return Err(MessageParseError::Malformed);
253 }
254 };
255
256 self.in_data.offset(read_data);
257
258 Ok(msg)
259 }
260}
261
262impl AsRawFd for BufferedSocket {
263 fn as_raw_fd(&self) -> RawFd {
264 self.socket.as_raw_fd()
265 }
266}
267
268impl AsFd for BufferedSocket {
269 fn as_fd(&self) -> BorrowedFd<'_> {
270 self.socket.as_fd()
271 }
272}
273
/// A fixed-capacity linear buffer.
///
/// The invariant `offset <= occupied <= storage.len()` splits the storage into
/// three regions:
/// - `storage[..offset]` is consumed.
/// - `storage[offset..occupied]` is the pending contents.
/// - `storage[occupied..]` is free space for new writes.
///
/// Consumed space is reclaimed only by [`Buffer::move_to_front`].
#[derive(Debug)]
struct Buffer<T: Copy> {
    storage: Vec<T>,
    occupied: usize,
    offset: usize,
}

impl<T: Copy + Default> Buffer<T> {
    /// Creates an empty buffer with room for `size` elements.
    fn new(size: usize) -> Self {
        Self { storage: vec![T::default(); size], occupied: 0, offset: 0 }
    }

    /// Commits `bytes` elements that were written into the writable storage.
    fn advance(&mut self, bytes: usize) {
        debug_assert!(
            bytes <= self.storage.len() - self.occupied,
            "Buffer::advance past the end of the storage"
        );
        self.occupied += bytes;
    }

    /// Marks `bytes` elements at the front of the contents as consumed.
    fn offset(&mut self, bytes: usize) {
        debug_assert!(
            bytes <= self.occupied - self.offset,
            "Buffer::offset past the end of the contents"
        );
        self.offset += bytes;
    }

    /// Discards all contents and consumed space.
    // `allow(unused)`: no caller needs this at the moment, but it completes
    // the buffer's small API.
    #[allow(unused)]
    fn clear(&mut self) {
        self.occupied = 0;
        self.offset = 0;
    }

    /// The pending (written but not yet consumed) elements.
    fn get_contents(&self) -> &[T] {
        &self.storage[(self.offset)..(self.occupied)]
    }

    /// The free space after the contents, available for new writes.
    fn get_writable_storage(&mut self) -> &mut [T] {
        &mut self.storage[(self.occupied)..]
    }

    /// Moves the pending contents to the start of the storage, which
    /// reclaims the consumed space for writing.
    fn move_to_front(&mut self) {
        if self.occupied > self.offset {
            self.storage.copy_within((self.offset)..(self.occupied), 0)
        }
        self.occupied -= self.offset;
        self.offset = 0;
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

    use std::ffi::CString;
    use std::os::unix::io::IntoRawFd;

    use smallvec::smallvec;

    /// Creates a connected `(client, server)` pair of buffered sockets.
    fn buffered_pair() -> (BufferedSocket, BufferedSocket) {
        let (client, server) = ::std::os::unix::net::UnixStream::pair().unwrap();
        (BufferedSocket::new(Socket::from(client)), BufferedSocket::new(Socket::from(server)))
    }

    /// Builds a signature lookup that recognises exactly one
    /// `(sender_id, opcode)` pair and rejects everything else.
    fn single_signature(
        expected_sender: u32,
        expected_opcode: u16,
        sig: &'static [ArgumentType],
    ) -> impl FnMut(u32, u16) -> Option<&'static [ArgumentType]> {
        move |sender_id, opcode| {
            if sender_id == expected_sender && opcode == expected_opcode {
                Some(sig)
            } else {
                None
            }
        }
    }

    /// Reports whether two descriptors refer to the same file, comparing the
    /// device and inode numbers returned by `fstat`.
    fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
        let first = rustix::fs::fstat(a).unwrap();
        let second = rustix::fs::fstat(b).unwrap();
        (first.st_dev, first.st_ino) == (second.st_dev, second.st_ino)
    }

    /// Asserts that two messages have the same sender, opcode and arguments.
    /// File-descriptor arguments are compared by the file they refer to,
    /// not by descriptor number.
    fn assert_eq_msgs<Fd: AsRawFd + std::fmt::Debug>(
        msg1: &Message<u32, Fd>,
        msg2: &Message<u32, Fd>,
    ) {
        assert_eq!(msg1.sender_id, msg2.sender_id);
        assert_eq!(msg1.opcode, msg2.opcode);
        assert_eq!(msg1.args.len(), msg2.args.len());
        for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
            match (arg1, arg2) {
                (Argument::Fd(fd1), Argument::Fd(fd2)) => {
                    // SAFETY: both descriptors are owned by the messages,
                    // which outlive these temporary borrows.
                    let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
                    let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
                    assert!(same_file(fd1, fd2));
                }
                _ => assert_eq!(arg1, arg2),
            }
        }
    }

    #[test]
    fn write_read_cycle() {
        static SIGNATURE: &[ArgumentType] = &[
            ArgumentType::Uint,
            ArgumentType::Fixed,
            ArgumentType::Str(AllowNull::No),
            ArgumentType::Array,
            ArgumentType::Object(AllowNull::No),
            ArgumentType::NewId,
            ArgumentType::Int,
        ];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Uint(3),
                Argument::Fixed(-89),
                Argument::Str(Some(Box::new(CString::new(&b"I like trains!"[..]).unwrap()))),
                Argument::Array(vec![1, 2, 3, 4, 5, 6, 7, 8, 9].into()),
                Argument::Object(88),
                Argument::NewId(56),
                Argument::Int(-25),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = server.read_one_message(single_signature(42, 7, SIGNATURE)).unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_fd() {
        static SIGNATURE: &[ArgumentType] = &[ArgumentType::Fd, ArgumentType::Fd];

        let msg = Message {
            sender_id: 42,
            opcode: 7,
            args: smallvec![
                Argument::Fd(1), // stdout
                Argument::Fd(0), // stdin
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = server.read_one_message(single_signature(42, 7, SIGNATURE)).unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }

    #[test]
    fn write_read_cycle_multiple() {
        static SIGNATURES: &[&[ArgumentType]] = &[
            &[ArgumentType::Int, ArgumentType::Str(AllowNull::No)],
            &[ArgumentType::Fd, ArgumentType::Fd],
            &[ArgumentType::Uint, ArgumentType::Fd],
        ];

        let messages = vec![
            Message {
                sender_id: 42,
                opcode: 0,
                args: smallvec![
                    Argument::Int(42),
                    Argument::Str(Some(Box::new(CString::new(&b"I like trains"[..]).unwrap()))),
                ],
            },
            Message {
                sender_id: 42,
                opcode: 1,
                args: smallvec![
                    Argument::Fd(1), // stdout
                    Argument::Fd(0), // stdin
                ],
            },
            Message {
                sender_id: 42,
                opcode: 2,
                args: smallvec![
                    Argument::Uint(3),
                    Argument::Fd(2), // stderr
                ],
            },
        ];

        let (mut client, mut server) = buffered_pair();

        messages.iter().for_each(|msg| client.write_message(msg).unwrap());
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        // Keep decoding until the first error (normally `MissingData`).
        let recv_msgs: Vec<_> = std::iter::from_fn(|| {
            server
                .read_one_message(|sender_id, opcode| {
                    if sender_id == 42 {
                        Some(SIGNATURES[opcode as usize])
                    } else {
                        None
                    }
                })
                .ok()
        })
        .collect();

        assert_eq!(recv_msgs.len(), 3);
        for (sent, received) in messages.into_iter().zip(recv_msgs) {
            assert_eq_msgs(
                &sent.map_fd(|fd| fd.as_raw_fd()),
                &received.map_fd(IntoRawFd::into_raw_fd),
            );
        }
    }

    #[test]
    fn parse_with_string_len_multiple_of_4() {
        static SIGNATURE: &[ArgumentType] =
            &[ArgumentType::Uint, ArgumentType::Str(AllowNull::No), ArgumentType::Uint];

        let msg = Message {
            sender_id: 2,
            opcode: 0,
            args: smallvec![
                Argument::Uint(18),
                Argument::Str(Some(Box::new(CString::new(&b"wl_shell"[..]).unwrap()))),
                Argument::Uint(1),
            ],
        };

        let (mut client, mut server) = buffered_pair();

        client.write_message(&msg).unwrap();
        client.flush().unwrap();
        server.fill_incoming_buffers().unwrap();

        let received = server.read_one_message(single_signature(2, 0, SIGNATURE)).unwrap();

        assert_eq_msgs(&msg.map_fd(|fd| fd.as_raw_fd()), &received.map_fd(IntoRawFd::into_raw_fd));
    }
}