1use rustix::fd::{AsFd, BorrowedFd};
2use std::io::{IoSlice, Result};
3use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
6#[cfg(unix)]
7use std::os::unix::net::UnixStream;
8#[cfg(windows)]
9use std::os::windows::io::{
10 AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
11};
12
13use crate::utils::RawFdContainer;
14use x11rb_protocol::parse_display::ConnectAddress;
15use x11rb_protocol::xauth::Family;
16
/// The kind of readiness a caller wants to wait for on a stream.
///
/// Besides `Debug`, `Clone` and `Copy`, the type derives `PartialEq`, `Eq` and
/// `Hash`, so modes can be compared with `==` and used as keys in hashed
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PollMode {
    /// Interested in readability only.
    Readable,

    /// Interested in writability only.
    Writable,

    /// Interested in both readability and writability.
    ReadAndWritable,
}
29
30impl PollMode {
31 pub fn readable(self) -> bool {
33 match self {
34 PollMode::Readable | PollMode::ReadAndWritable => true,
35 PollMode::Writable => false,
36 }
37 }
38
39 pub fn writable(self) -> bool {
41 match self {
42 PollMode::Writable | PollMode::ReadAndWritable => true,
43 PollMode::Readable => false,
44 }
45 }
46}
47
48pub trait Stream {
53 fn poll(&self, mode: PollMode) -> Result<()>;
69
70 fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;
92
93 fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;
119
120 fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
131 for buf in bufs {
132 if !buf.is_empty() {
133 return self.write(buf, fds);
134 }
135 }
136 Ok(0)
137 }
138}
139
/// The concrete stream type provided by this module: a thin wrapper around a
/// single connected socket.
///
/// The wrapped value is a `DefaultStreamInner`, which is a `RawFdContainer` on
/// Unix and a `TcpStream` on every other platform.
#[derive(Debug)]
pub struct DefaultStream {
    // The underlying connection; its concrete type depends on the platform.
    inner: DefaultStreamInner,
}
147
/// Unix: the connection is stored as a `RawFdContainer`; both `TcpStream` and
/// `UnixStream` values are converted into it with `.into()`.
#[cfg(unix)]
type DefaultStreamInner = RawFdContainer;

/// Other platforms: only TCP connections are created, so the `TcpStream` is
/// stored directly.
#[cfg(not(unix))]
type DefaultStreamInner = TcpStream;

/// Description of the peer returned next to a newly created `DefaultStream`:
/// a `Family` (from `x11rb_protocol::xauth`) together with an address as raw
/// bytes.
// NOTE(review): presumably this pair is what the xauth lookup expects — confirm
// against the callers.
type PeerAddr = (Family, Vec<u8>);
158
159impl DefaultStream {
160 pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
162 match addr {
163 ConnectAddress::Hostname(host, port) => {
164 let stream = TcpStream::connect((*host, *port))?;
166 Self::from_tcp_stream(stream)
167 }
168 #[cfg(unix)]
169 ConnectAddress::Socket(path) => {
170 #[cfg(any(target_os = "linux", target_os = "android"))]
172 if let Ok(stream) = connect_abstract_unix_stream(path.as_bytes()) {
173 let stream = DefaultStream { inner: stream };
177 return Ok((stream, peer_addr::local()));
178 }
179
180 let stream = UnixStream::connect(path)?;
182 Self::from_unix_stream(stream)
183 }
184 #[cfg(not(unix))]
185 ConnectAddress::Socket(_) => {
186 Err(std::io::Error::new(
188 std::io::ErrorKind::Other,
189 "Unix domain sockets are not supported on Windows",
190 ))
191 }
192 _ => Err(std::io::Error::new(
193 std::io::ErrorKind::Other,
194 "The given address family is not implemented",
195 )),
196 }
197 }
198
199 pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
205 let peer_addr = peer_addr::tcp(&stream.peer_addr()?);
206 stream.set_nonblocking(true)?;
207 let result = Self {
208 inner: stream.into(),
209 };
210 Ok((result, peer_addr))
211 }
212
213 #[cfg(unix)]
219 pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
220 stream.set_nonblocking(true)?;
221 let result = Self {
222 inner: stream.into(),
223 };
224 Ok((result, peer_addr::local()))
225 }
226
227 fn as_fd(&self) -> BorrowedFd<'_> {
228 self.inner.as_fd()
229 }
230}
231
232#[cfg(unix)]
233impl AsRawFd for DefaultStream {
234 fn as_raw_fd(&self) -> RawFd {
235 self.inner.as_raw_fd()
236 }
237}
238
239#[cfg(unix)]
240impl AsFd for DefaultStream {
241 fn as_fd(&self) -> BorrowedFd<'_> {
242 self.inner.as_fd()
243 }
244}
245
246#[cfg(unix)]
247impl IntoRawFd for DefaultStream {
248 fn into_raw_fd(self) -> RawFd {
249 self.inner.into_raw_fd()
250 }
251}
252
253#[cfg(unix)]
254impl From<DefaultStream> for OwnedFd {
255 fn from(stream: DefaultStream) -> Self {
256 stream.inner
257 }
258}
259
/// Exposes the raw socket handle of the wrapped connection without giving up
/// ownership.
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}
266
/// Lends out the wrapped connection as a borrowed socket.
#[cfg(windows)]
impl AsSocket for DefaultStream {
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}
273
/// Consumes the stream and releases the raw socket handle of the wrapped
/// connection to the caller.
#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    fn into_raw_socket(self) -> RawSocket {
        IntoRawSocket::into_raw_socket(self.inner)
    }
}
280
/// Unwraps the stream and converts the connection it was holding into an
/// owned socket.
#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    fn from(stream: DefaultStream) -> Self {
        let DefaultStream { inner } = stream;
        OwnedSocket::from(inner)
    }
}
287
288#[cfg(unix)]
289fn do_write(
290 stream: &DefaultStream,
291 bufs: &[IoSlice<'_>],
292 fds: &mut Vec<RawFdContainer>,
293) -> Result<usize> {
294 use rustix::io::Errno;
295 use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
296
297 fn sendmsg_wrapper(
298 fd: BorrowedFd<'_>,
299 iov: &[IoSlice<'_>],
300 cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
301 flags: SendFlags,
302 ) -> Result<usize> {
303 loop {
304 match sendmsg(fd, iov, cmsgs, flags) {
305 Ok(n) => return Ok(n),
306 Err(Errno::INTR) => {}
308 Err(e) => return Err(e.into()),
309 }
310 }
311 }
312
313 let fd = stream.as_fd();
314
315 let res = if !fds.is_empty() {
316 let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
317 let rights = SendAncillaryMessage::ScmRights(&fds);
318
319 let mut cmsg_space = vec![0u8; rights.size()];
320 let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
321 assert!(cmsg_buffer.push(rights));
322
323 sendmsg_wrapper(fd, bufs, &mut cmsg_buffer, SendFlags::empty())?
324 } else {
325 sendmsg_wrapper(fd, bufs, &mut Default::default(), SendFlags::empty())?
326 };
327
328 fds.clear();
330
331 Ok(res)
332}
333
334impl Stream for DefaultStream {
335 fn poll(&self, mode: PollMode) -> Result<()> {
336 use rustix::event::{poll, PollFd, PollFlags};
337 use rustix::io::Errno;
338
339 let mut poll_flags = PollFlags::empty();
340 if mode.readable() {
341 poll_flags |= PollFlags::IN;
342 }
343 if mode.writable() {
344 poll_flags |= PollFlags::OUT;
345 }
346 let fd = self.as_fd();
347 let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
348 loop {
349 match poll(&mut poll_fds, -1) {
350 Ok(_) => break,
351 Err(Errno::INTR) => {}
352 Err(e) => return Err(e.into()),
353 }
354 }
355 Ok(())
357 }
358
359 fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
360 #[cfg(unix)]
361 {
362 use rustix::io::Errno;
363 use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
364 use std::io::IoSliceMut;
365
366 let mut cmsg = [0u8; 1024];
370 let mut iov = [IoSliceMut::new(buf)];
371 let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);
372
373 let fd = self.as_fd();
374 let msg = loop {
375 match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
376 Ok(msg) => break msg,
377 Err(Errno::INTR) => {}
379 Err(e) => return Err(e.into()),
380 }
381 };
382
383 let fds_received = cmsg_buffer
384 .drain()
385 .filter_map(|cmsg| match cmsg {
386 RecvAncillaryMessage::ScmRights(r) => Some(r),
387 _ => None,
388 })
389 .flatten();
390
391 let mut cloexec_error = Ok(());
392 fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
393 cloexec_error?;
394
395 Ok(msg.bytes)
396 }
397 #[cfg(not(unix))]
398 {
399 use std::io::Read;
400 let _ = fd_storage;
402 loop {
403 match (&mut &self.inner).read(buf) {
405 Ok(n) => return Ok(n),
406 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
408 Err(e) => return Err(e),
409 }
410 }
411 }
412 }
413
414 fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
415 #[cfg(unix)]
416 {
417 do_write(self, &[IoSlice::new(buf)], fds)
418 }
419 #[cfg(not(unix))]
420 {
421 use std::io::{Error, ErrorKind, Write};
422 if !fds.is_empty() {
423 return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
424 }
425 loop {
426 match (&mut &self.inner).write(buf) {
428 Ok(n) => return Ok(n),
429 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
431 Err(e) => return Err(e),
432 }
433 }
434 }
435 }
436
437 fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
438 #[cfg(unix)]
439 {
440 do_write(self, bufs, fds)
441 }
442 #[cfg(not(unix))]
443 {
444 use std::io::{Error, ErrorKind, Write};
445 if !fds.is_empty() {
446 return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
447 }
448 loop {
449 match (&mut &self.inner).write_vectored(bufs) {
451 Ok(n) => return Ok(n),
452 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
454 Err(e) => return Err(e),
455 }
456 }
457 }
458 }
459}
460
461#[cfg(any(target_os = "linux", target_os = "android"))]
462fn connect_abstract_unix_stream(
463 path: &[u8],
464) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
465 use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
466 use rustix::net::{
467 connect_unix, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
468 };
469
470 let socket = socket_with(
471 AddressFamily::UNIX,
472 SocketType::STREAM,
473 SocketFlags::CLOEXEC,
474 None,
475 )?;
476
477 connect_unix(&socket, &SocketAddrUnix::new_abstract_name(path)?)?;
478
479 fcntl_setfl(&socket, fcntl_getfl(&socket)? | OFlags::NONBLOCK)?;
481
482 Ok(socket)
483}
484
485#[cfg(any(
487 target_os = "android",
488 target_os = "dragonfly",
489 target_os = "freebsd",
490 target_os = "linux",
491 target_os = "netbsd",
492 target_os = "openbsd"
493))]
494mod recvmsg {
495 use super::RawFdContainer;
496 use rustix::net::RecvFlags;
497
498 pub(crate) fn flags() -> RecvFlags {
499 RecvFlags::CMSG_CLOEXEC
500 }
501
502 pub(crate) fn after_recvmsg<'a>(
503 fds: impl Iterator<Item = RawFdContainer> + 'a,
504 _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
505 ) -> impl Iterator<Item = RawFdContainer> + 'a {
506 fds
507 }
508}
509
/// `recvmsg` helpers for the remaining Unix platforms, where close-on-exec is
/// set with `fcntl` on each descriptor after it has been received.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// Flags to pass to `recvmsg` on these platforms: none.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Adds `FdFlags::CLOEXEC` to every descriptor as it flows through the
    /// returned iterator.
    ///
    /// The work is lazy: it happens while the caller consumes the iterator.
    /// A failure does not drop the descriptor; the error is stored in
    /// `cloexec_error` (a later failure overwrites an earlier one) and the
    /// descriptor is still yielded.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let marked =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(failure) = marked {
                *cloexec_error = Err(failure);
            }
        })
    }
}
545
546mod peer_addr {
547 use super::{Family, PeerAddr};
548 use std::net::{Ipv4Addr, SocketAddr};
549
550 pub(super) fn local() -> PeerAddr {
552 let hostname = crate::hostname()
553 .to_str()
554 .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
555 (Family::LOCAL, hostname)
556 }
557
558 pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
560 let ip = match addr {
561 SocketAddr::V4(addr) => *addr.ip(),
562 SocketAddr::V6(addr) => {
563 let ip = addr.ip();
564 if ip.is_loopback() {
565 Ipv4Addr::LOCALHOST
568 } else if let Some(ip) = ip.to_ipv4() {
569 ip
571 } else {
572 return (Family::INTERNET6, ip.octets().to_vec());
574 }
575 }
576 };
577
578 if ip.is_loopback() {
580 local()
581 } else {
582 (Family::INTERNET, ip.octets().to_vec())
583 }
584 }
585}