1use std::io::IoSlice;
4use std::sync::{Condvar, Mutex, MutexGuard, TryLockError};
5use std::time::Instant;
6
7use crate::connection::{
8 compute_length_field, Connection, ReplyOrError, RequestConnection, RequestKind,
9};
10use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
11use crate::errors::DisplayParsingError;
12pub use crate::errors::{ConnectError, ConnectionError, ParseError, ReplyError, ReplyOrIdError};
13use crate::extension_manager::ExtensionManager;
14use crate::protocol::bigreq::{ConnectionExt as _, EnableReply};
15use crate::protocol::xproto::{Setup, GET_INPUT_FOCUS_REQUEST, QUERY_EXTENSION_REQUEST};
16use crate::utils::RawFdContainer;
17use crate::x11_utils::{ExtensionInformation, TryParse, TryParseFd};
18use x11rb_protocol::connect::Connect;
19use x11rb_protocol::connection::{Connection as ProtoConnection, PollReply, ReplyFdKind};
20use x11rb_protocol::id_allocator::IdAllocator;
21use x11rb_protocol::{xauth::get_auth, DiscardMode, RawEventAndSeqNumber, SequenceNumber};
22
23mod packet_reader;
24mod stream;
25mod write_buffer;
26
27use packet_reader::PacketReader;
28pub use stream::{DefaultStream, PollMode, Stream};
29use write_buffer::WriteBuffer;
30
/// Buffer type in which raw replies, errors and events are handed out by this connection
/// (the `Buf` associated type of the `RequestConnection` impl for `RustConnection`).
type Buffer = <RustConnection as RequestConnection>::Buf;
/// A reply buffer together with the file descriptors that were received with it.
pub type BufWithFds = crate::connection::BufWithFds<Buffer>;
34
/// State of the lazily determined maximum request size of a connection.
#[derive(Debug)]
enum MaxRequestBytes {
    /// Nothing has been done yet to find out the maximum request size.
    Unknown,
    /// A BIG-REQUESTS `Enable` request was sent. Holds its sequence number, or `None` if
    /// sending it failed; in that case the setup's `maximum_request_length` is used later.
    Requested(Option<SequenceNumber>),
    /// The maximum request size, in bytes, has been determined.
    Known(usize),
}
41
/// State protected by the main connection mutex.
#[derive(Debug)]
struct ConnectionInner {
    /// Protocol-level connection state: hands out sequence numbers and queues received
    /// replies, errors, events and file descriptors.
    inner: ProtoConnection,
    /// Outgoing data (and file descriptors) that still has to be written to the stream.
    write_buffer: WriteBuffer,
}
47
/// Guard for the locked [`ConnectionInner`]. Helpers take and return it by value so that they
/// can temporarily release the lock (e.g. while waiting on a condition variable or polling the
/// stream) and hand the reacquired guard back to the caller.
type MutexGuardInner<'a> = MutexGuard<'a, ConnectionInner>;
49
/// Whether reading packets from the stream may block the calling thread.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) enum BlockingMode {
    /// Wait until the stream is readable, or until another thread that is currently reading
    /// has finished.
    Blocking,
    /// Only read what is available right now; do nothing if another thread is reading.
    NonBlocking,
}
55
/// An X11 connection implemented in pure Rust on top of a [`Stream`]
/// (by default [`DefaultStream`]).
///
/// All mutable state sits behind mutexes, so the connection is used through `&self`.
#[derive(Debug)]
pub struct RustConnection<S: Stream = DefaultStream> {
    /// Protocol state and write buffer; this is the main connection lock.
    inner: Mutex<ConnectionInner>,
    /// The byte stream to the X11 server.
    stream: S,
    /// Reader for incoming packets. At most one thread reads from `stream` at a time: the one
    /// that successfully `try_lock`s this mutex.
    packet_reader: Mutex<PacketReader>,
    /// Notified (`notify_all`) whenever a reading thread is done, to wake threads that are
    /// waiting because someone else was reading.
    reader_condition: Condvar,
    /// Setup information sent by the server during the connection handshake.
    setup: Setup,
    /// Cached extension information, used for opcode lookups and for parsing errors/events.
    extension_manager: Mutex<ExtensionManager>,
    /// Lazily determined maximum request size.
    maximum_request_bytes: Mutex<MaxRequestBytes>,
    /// Allocator for X11 resource IDs.
    id_allocator: Mutex<IdAllocator>,
}
77
78impl RustConnection<DefaultStream> {
109 pub fn connect(dpy_name: Option<&str>) -> Result<(Self, usize), ConnectError> {
113 let parsed_display = x11rb_protocol::parse_display::parse_display(dpy_name)?;
115 let screen = parsed_display.screen.into();
116
117 let mut error = None;
120 for addr in parsed_display.connect_instruction() {
121 let start = Instant::now();
122 match DefaultStream::connect(&addr) {
123 Ok((stream, (family, address))) => {
124 crate::trace!(
125 "Connected to X11 server via {:?} in {:?}",
126 addr,
127 start.elapsed()
128 );
129
130 let (auth_name, auth_data) = get_auth(family, &address, parsed_display.display)
132 .unwrap_or(None)
134 .unwrap_or_else(|| (Vec::new(), Vec::new()));
135 crate::trace!("Picked authentication via auth mechanism {:?}", auth_name);
136
137 return Ok((
139 Self::connect_to_stream_with_auth_info(
140 stream, screen, auth_name, auth_data,
141 )?,
142 screen,
143 ));
144 }
145 Err(e) => {
146 crate::debug!("Failed to connect to X11 server via {:?}: {:?}", addr, e);
147 error = Some(e);
148 continue;
149 }
150 }
151 }
152
153 Err(match error {
155 Some(e) => ConnectError::IoError(e),
156 None => DisplayParsingError::Unknown.into(),
157 })
158 }
159}
160
impl<S: Stream> RustConnection<S> {
    /// Establish a new X11 connection over an already connected `stream`, without sending
    /// any authentication data.
    ///
    /// `screen` is the index of the preferred screen. See
    /// [`Self::connect_to_stream_with_auth_info`] for the errors that can occur.
    pub fn connect_to_stream(stream: S, screen: usize) -> Result<Self, ConnectError> {
        Self::connect_to_stream_with_auth_info(stream, screen, Vec::new(), Vec::new())
    }

    /// Establish a new X11 connection over an already connected `stream`, authenticating
    /// with the given authorization protocol name and data.
    ///
    /// This writes the connection setup request and reads the server's answer until
    /// [`Connect::advance`] reports it complete. The answer is turned into a [`Setup`], and
    /// the connection is then built by [`Self::for_connected_stream`].
    ///
    /// # Errors
    ///
    /// - I/O errors from polling, writing or reading `stream`. This includes `WriteZero` /
    ///   `UnexpectedEof` when a write or read makes no progress.
    /// - Errors from [`Connect::into_setup`] when the server's answer cannot be used
    ///   (exact cases are defined there).
    /// - `ConnectError::InvalidScreen` if `screen` is not smaller than the number of roots.
    /// - Errors from [`Self::for_connected_stream`].
    pub fn connect_to_stream_with_auth_info(
        stream: S,
        screen: usize,
        auth_name: Vec<u8>,
        auth_data: Vec<u8>,
    ) -> Result<Self, ConnectError> {
        let (mut connect, setup_request) = Connect::with_authorization(auth_name, auth_data);

        // `fds` starts empty, so no file descriptors are sent with the setup request. The same
        // vector is also passed to the reads below; anything received there is dropped
        // together with it at the end of this function.
        let mut nwritten = 0;
        let mut fds = vec![];

        crate::trace!(
            "Writing connection setup with {} bytes",
            setup_request.len()
        );
        while nwritten != setup_request.len() {
            stream.poll(PollMode::Writable)?;
            match stream.write(&setup_request[nwritten..], &mut fds) {
                // A write that accepts nothing cannot make progress; fail instead of spinning.
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    )
                    .into())
                }
                Ok(n) => nwritten += n,
                // poll() reported writability but the write would still block: just retry.
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e.into()),
            }
        }

        // Read into the buffer provided by `connect` until it reports the answer complete.
        loop {
            stream.poll(PollMode::Readable)?;
            crate::trace!(
                "Reading connection setup with at least {} bytes remaining",
                connect.buffer().len()
            );
            let adv = match stream.read(connect.buffer(), &mut fds) {
                // EOF before the setup was complete.
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "failed to read whole buffer",
                    )
                    .into())
                }
                Ok(n) => n,
                // poll() reported readability but the read would still block: retry.
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
                Err(e) => return Err(e.into()),
            };
            crate::trace!("Read {} bytes", adv);

            // `advance` returns true once all expected setup bytes were received.
            if connect.advance(adv) {
                break;
            }
        }

        let setup = connect.into_setup()?;

        // Validate the requested screen index against the roots the server announced.
        if screen >= setup.roots.len() {
            return Err(ConnectError::InvalidScreen);
        }

        Self::for_connected_stream(stream, setup)
    }

    /// Create a connection from a `stream` on which the X11 handshake was already completed,
    /// using the `setup` the server sent.
    ///
    /// # Errors
    ///
    /// Fails if [`IdAllocator::new`] rejects the setup's `resource_id_base` /
    /// `resource_id_mask`.
    pub fn for_connected_stream(stream: S, setup: Setup) -> Result<Self, ConnectError> {
        let id_allocator = IdAllocator::new(setup.resource_id_base, setup.resource_id_mask)?;

        Ok(RustConnection {
            inner: Mutex::new(ConnectionInner {
                inner: ProtoConnection::new(),
                write_buffer: WriteBuffer::new(),
            }),
            stream,
            packet_reader: Mutex::new(PacketReader::new()),
            reader_condition: Condvar::new(),
            setup,
            extension_manager: Default::default(),
            maximum_request_bytes: Mutex::new(MaxRequestBytes::Unknown),
            id_allocator: Mutex::new(id_allocator),
        })
    }

    /// Send a request consisting of `bufs` (and the file descriptors `fds`) and return its
    /// sequence number.
    ///
    /// The request's length field is filled in by `compute_length_field`. If the protocol
    /// state does not hand out a sequence number (according to the trace message, because
    /// too many void requests are outstanding), a sync request is sent first and the request
    /// is retried. The data may remain in the write buffer until `flush_impl` runs.
    ///
    /// # Panics
    ///
    /// Panics if `bufs` is empty or `bufs[0]` is shorter than two bytes, since the major and
    /// minor opcodes are read from `bufs[0][0]` and `bufs[0][1]` for logging.
    fn send_request(
        &self,
        bufs: &[IoSlice<'_>],
        fds: Vec<RawFdContainer>,
        kind: ReplyFdKind,
    ) -> Result<SequenceNumber, ConnectionError> {
        let _guard = crate::debug_span!("send_request").entered();

        let request_info = RequestInfo {
            extension_manager: &self.extension_manager,
            major_opcode: bufs[0][0],
            minor_opcode: bufs[0][1],
        };
        crate::debug!("Sending {}", request_info);

        // `storage` backs any re-encoded buffers that `compute_length_field` returns.
        let mut storage = Default::default();
        let bufs = compute_length_field(self, bufs, &mut storage)?;

        let mut inner = self.inner.lock().unwrap();

        loop {
            let send_result = inner.inner.send_request(kind);
            match send_result {
                Some(seqno) => {
                    let _inner = self.write_all_vectored(inner, bufs, fds)?;
                    return Ok(seqno);
                }
                None => {
                    crate::trace!("Syncing with the X11 server since there are too many outstanding void requests");
                    inner = self.send_sync(inner)?;
                }
            }
        }
    }

    /// Send a `GetInputFocus` request whose reply and error are both discarded, so that it
    /// serves purely as a synchronisation point with the server.
    ///
    /// The request is 4 bytes long, so its length field is 1 (encoded in native byte order).
    /// Returns the (possibly re-acquired) connection guard.
    fn send_sync<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
    ) -> Result<MutexGuardInner<'a>, std::io::Error> {
        let length = 1u16.to_ne_bytes();
        let request = [
            GET_INPUT_FOCUS_REQUEST,
            0, /* pad */
            length[0],
            length[1],
        ];

        let seqno = inner
            .inner
            .send_request(ReplyFdKind::ReplyWithoutFDs)
            .expect("Sending a HasResponse request should not be blocked by syncs");
        inner
            .inner
            .discard_reply(seqno, DiscardMode::DiscardReplyAndError);
        let inner = self.write_all_vectored(inner, &[IoSlice::new(&request)], Vec::new())?;

        Ok(inner)
    }

    /// Write all of `bufs` and all `fds` through the write buffer.
    ///
    /// Whenever writing would block, incoming packets are read in non-blocking mode. This lets
    /// the connection make progress on input while output is stalled (presumably so that a
    /// server blocked on writing to us does not deadlock with us).
    ///
    /// # Errors
    ///
    /// Returns `WriteZero` if a write accepts nothing. Returns an `Other` error if file
    /// descriptors are left over once all bytes were written; this relies on the write calls
    /// taking FDs out of `fds` (TODO confirm in `WriteBuffer`). I/O errors are passed through.
    fn write_all_vectored<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
        mut bufs: &[IoSlice<'_>],
        mut fds: Vec<RawFdContainer>,
    ) -> std::io::Result<MutexGuardInner<'a>> {
        // Remainder of a slice of `bufs` that was only partly written. It is written on its
        // own before continuing with the rest of `bufs`.
        let mut partial_buf: &[u8] = &[];
        while !partial_buf.is_empty() || !bufs.is_empty() {
            self.stream.poll(PollMode::ReadAndWritable)?;
            let write_result = if !partial_buf.is_empty() {
                inner
                    .write_buffer
                    .write(&self.stream, partial_buf, &mut fds)
            } else {
                inner
                    .write_buffer
                    .write_vectored(&self.stream, bufs, &mut fds)
            };
            match write_result {
                Ok(0) => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::WriteZero,
                        "failed to write anything",
                    ));
                }
                Ok(mut count) => {
                    // Consume `count` bytes: first from `partial_buf`, ...
                    if count >= partial_buf.len() {
                        count -= partial_buf.len();
                        partial_buf = &[];
                    } else {
                        partial_buf = &partial_buf[count..];
                        count = 0;
                    }
                    // ... then whole slices of `bufs`. A slice that was only partly written
                    // becomes the new `partial_buf`.
                    while count > 0 {
                        if count >= bufs[0].len() {
                            count -= bufs[0].len();
                        } else {
                            partial_buf = &bufs[0][count..];
                            count = 0;
                        }
                        bufs = &bufs[1..];
                        // Skip over empty slices that follow the consumed one.
                        while bufs.first().map(|s| s.len()) == Some(0) {
                            bufs = &bufs[1..];
                        }
                    }
                }
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    crate::trace!("Writing more data would block for now");
                    inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
                }
                Err(e) => return Err(e),
            }
        }
        if !fds.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Left over FDs after sending the request",
            ));
        }
        Ok(inner)
    }

    /// Flush the write buffer to the stream, reading incoming packets in non-blocking mode
    /// whenever flushing would block.
    ///
    /// Returns the (possibly re-acquired) connection guard.
    fn flush_impl<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
    ) -> std::io::Result<MutexGuardInner<'a>> {
        while inner.write_buffer.needs_flush() {
            self.stream.poll(PollMode::ReadAndWritable)?;
            let flush_result = inner.write_buffer.flush(&self.stream);
            match flush_result {
                Ok(()) => break,
                Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    crate::trace!("Flushing more data would block for now");
                    inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
                }
                Err(e) => return Err(e),
            }
        }
        Ok(inner)
    }

    /// Read packets from the stream and enqueue them, together with any received file
    /// descriptors, into the protocol state.
    ///
    /// Only the thread that acquires `packet_reader` with `try_lock` reads:
    ///
    /// - If another thread is already reading:
    ///   - `NonBlocking` returns immediately without doing anything.
    ///   - `Blocking` waits on `reader_condition` (which releases `inner` while waiting) and
    ///     returns once woken, without reading itself.
    /// - Otherwise this thread reads:
    ///   - In `Blocking` mode, `inner` is released while waiting for the stream to become
    ///     readable, and reacquired afterwards.
    ///   - Whatever is available is then read and enqueued. Waiting threads are notified when
    ///     this function returns, including on error.
    ///
    /// Because a wake-up does not guarantee that anything was read, callers re-check their
    /// condition in a loop.
    ///
    /// # Panics
    ///
    /// Panics if the `packet_reader` mutex is poisoned.
    fn read_packet_and_enqueue<'a>(
        &'a self,
        mut inner: MutexGuardInner<'a>,
        mode: BlockingMode,
    ) -> Result<MutexGuardInner<'a>, std::io::Error> {
        match self.packet_reader.try_lock() {
            Err(TryLockError::WouldBlock) => {
                // Another thread is currently reading from the stream.
                match mode {
                    BlockingMode::NonBlocking => {
                        crate::trace!("read_packet_and_enqueue in NonBlocking mode doing nothing since reader is already locked");
                        return Ok(inner);
                    }
                    BlockingMode::Blocking => {
                        crate::trace!("read_packet_and_enqueue in Blocking mode waiting for pre-existing reader");
                    }
                }

                // Releases `inner` while sleeping and reacquires it before returning.
                Ok(self.reader_condition.wait(inner).unwrap())
            }
            Err(TryLockError::Poisoned(e)) => panic!("{}", e),
            Ok(mut packet_reader) => {
                // Wake up waiting threads however this arm is left (also via `?`).
                let notify_on_drop = NotifyOnDrop(&self.reader_condition);

                if mode == BlockingMode::Blocking {
                    // Do not hold the connection lock while sleeping in poll().
                    drop(inner);
                    self.stream.poll(PollMode::Readable)?;
                    inner = self.inner.lock().unwrap();
                }

                let mut fds = Vec::new();
                let mut packets = Vec::new();
                packet_reader.try_read_packets(&self.stream, &mut packets, &mut fds)?;

                // Reading is done; release the reader lock before touching protocol state.
                drop(packet_reader);

                inner.inner.enqueue_fds(fds);
                packets
                    .into_iter()
                    .for_each(|packet| inner.inner.enqueue_packet(packet));

                drop(notify_on_drop);

                Ok(inner)
            }
        }
    }

    /// If the maximum request size is still unknown, send a BIG-REQUESTS `Enable` request and
    /// record its sequence number, or `None` if sending failed.
    ///
    /// The caller passes the already locked `maximum_request_bytes` state.
    fn prefetch_maximum_request_bytes_impl(&self, max_bytes: &mut MutexGuard<'_, MaxRequestBytes>) {
        if let MaxRequestBytes::Unknown = **max_bytes {
            crate::info!("Prefetching maximum request length");
            let request = self
                .bigreq_enable()
                .map(|cookie| cookie.into_sequence_number())
                .ok();
            **max_bytes = MaxRequestBytes::Requested(request);
        }
    }

    /// Access the underlying stream.
    pub fn stream(&self) -> &S {
        &self.stream
    }
}
550
551impl<S: Stream> RequestConnection for RustConnection<S> {
552 type Buf = Vec<u8>;
553
554 fn send_request_with_reply<Reply>(
555 &self,
556 bufs: &[IoSlice<'_>],
557 fds: Vec<RawFdContainer>,
558 ) -> Result<Cookie<'_, Self, Reply>, ConnectionError>
559 where
560 Reply: TryParse,
561 {
562 Ok(Cookie::new(
563 self,
564 self.send_request(bufs, fds, ReplyFdKind::ReplyWithoutFDs)?,
565 ))
566 }
567
568 fn send_request_with_reply_with_fds<Reply>(
569 &self,
570 bufs: &[IoSlice<'_>],
571 fds: Vec<RawFdContainer>,
572 ) -> Result<CookieWithFds<'_, Self, Reply>, ConnectionError>
573 where
574 Reply: TryParseFd,
575 {
576 Ok(CookieWithFds::new(
577 self,
578 self.send_request(bufs, fds, ReplyFdKind::ReplyWithFDs)?,
579 ))
580 }
581
582 fn send_request_without_reply(
583 &self,
584 bufs: &[IoSlice<'_>],
585 fds: Vec<RawFdContainer>,
586 ) -> Result<VoidCookie<'_, Self>, ConnectionError> {
587 Ok(VoidCookie::new(
588 self,
589 self.send_request(bufs, fds, ReplyFdKind::NoReply)?,
590 ))
591 }
592
593 fn discard_reply(&self, sequence: SequenceNumber, _kind: RequestKind, mode: DiscardMode) {
594 crate::debug!(
595 "Discarding reply to request {} in mode {:?}",
596 sequence,
597 mode
598 );
599 self.inner
600 .lock()
601 .unwrap()
602 .inner
603 .discard_reply(sequence, mode);
604 }
605
606 fn prefetch_extension_information(
607 &self,
608 extension_name: &'static str,
609 ) -> Result<(), ConnectionError> {
610 self.extension_manager
611 .lock()
612 .unwrap()
613 .prefetch_extension_information(self, extension_name)
614 }
615
616 fn extension_information(
617 &self,
618 extension_name: &'static str,
619 ) -> Result<Option<ExtensionInformation>, ConnectionError> {
620 self.extension_manager
621 .lock()
622 .unwrap()
623 .extension_information(self, extension_name)
624 }
625
626 fn wait_for_reply_or_raw_error(
627 &self,
628 sequence: SequenceNumber,
629 ) -> Result<ReplyOrError<Vec<u8>>, ConnectionError> {
630 match self.wait_for_reply_with_fds_raw(sequence)? {
631 ReplyOrError::Reply((reply, _fds)) => Ok(ReplyOrError::Reply(reply)),
632 ReplyOrError::Error(e) => Ok(ReplyOrError::Error(e)),
633 }
634 }
635
636 fn wait_for_reply(&self, sequence: SequenceNumber) -> Result<Option<Vec<u8>>, ConnectionError> {
637 let _guard = crate::debug_span!("wait_for_reply", sequence).entered();
638
639 let mut inner = self.inner.lock().unwrap();
640 inner = self.flush_impl(inner)?;
641 loop {
642 crate::trace!({ sequence }, "Polling for reply");
643 let poll_result = inner.inner.poll_for_reply(sequence);
644 match poll_result {
645 PollReply::TryAgain => {}
646 PollReply::NoReply => return Ok(None),
647 PollReply::Reply(buffer) => return Ok(Some(buffer)),
648 }
649 inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
650 }
651 }
652
653 fn check_for_raw_error(
654 &self,
655 sequence: SequenceNumber,
656 ) -> Result<Option<Buffer>, ConnectionError> {
657 let _guard = crate::debug_span!("check_for_raw_error", sequence).entered();
658
659 let mut inner = self.inner.lock().unwrap();
660 if inner.inner.prepare_check_for_reply_or_error(sequence) {
661 crate::trace!("Inserting sync with the X11 server");
662 inner = self.send_sync(inner)?;
663 assert!(!inner.inner.prepare_check_for_reply_or_error(sequence));
664 }
665 inner = self.flush_impl(inner)?;
667 loop {
668 crate::trace!({ sequence }, "Polling for reply or error");
669 let poll_result = inner.inner.poll_check_for_reply_or_error(sequence);
670 match poll_result {
671 PollReply::TryAgain => {}
672 PollReply::NoReply => return Ok(None),
673 PollReply::Reply(buffer) => return Ok(Some(buffer)),
674 }
675 inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
676 }
677 }
678
679 fn wait_for_reply_with_fds_raw(
680 &self,
681 sequence: SequenceNumber,
682 ) -> Result<ReplyOrError<BufWithFds, Buffer>, ConnectionError> {
683 let _guard = crate::debug_span!("wait_for_reply_with_fds_raw", sequence).entered();
684
685 let mut inner = self.inner.lock().unwrap();
686 inner = self.flush_impl(inner)?;
688 loop {
689 crate::trace!({ sequence }, "Polling for reply or error");
690 if let Some(reply) = inner.inner.poll_for_reply_or_error(sequence) {
691 if reply.0[0] == 0 {
692 crate::trace!("Got error");
693 return Ok(ReplyOrError::Error(reply.0));
694 } else {
695 crate::trace!("Got reply");
696 return Ok(ReplyOrError::Reply(reply));
697 }
698 }
699 inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
700 }
701 }
702
703 fn maximum_request_bytes(&self) -> usize {
704 let mut max_bytes = self.maximum_request_bytes.lock().unwrap();
705 self.prefetch_maximum_request_bytes_impl(&mut max_bytes);
706 use MaxRequestBytes::*;
707 let max_bytes = &mut *max_bytes;
708 match max_bytes {
709 Unknown => unreachable!("We just prefetched this"),
710 Requested(seqno) => {
711 let _guard = crate::info_span!("maximum_request_bytes").entered();
712
713 let length = seqno
714 .and_then(|seqno| {
716 Cookie::<_, EnableReply>::new(self, seqno)
717 .reply()
719 .map(|reply| reply.maximum_request_length)
720 .ok()
721 })
722 .unwrap_or_else(|| self.setup.maximum_request_length.into())
724 .try_into()
726 .unwrap_or(usize::max_value());
727 let length = length * 4;
728 *max_bytes = Known(length);
729 crate::info!("Maximum request length is {} bytes", length);
730 length
731 }
732 Known(length) => *length,
733 }
734 }
735
736 fn prefetch_maximum_request_bytes(&self) {
737 let mut max_bytes = self.maximum_request_bytes.lock().unwrap();
738 self.prefetch_maximum_request_bytes_impl(&mut max_bytes);
739 }
740
741 fn parse_error(&self, error: &[u8]) -> Result<crate::x11_utils::X11Error, ParseError> {
742 let ext_mgr = self.extension_manager.lock().unwrap();
743 crate::x11_utils::X11Error::try_parse(error, &*ext_mgr)
744 }
745
746 fn parse_event(&self, event: &[u8]) -> Result<crate::protocol::Event, ParseError> {
747 let ext_mgr = self.extension_manager.lock().unwrap();
748 crate::protocol::Event::parse(event, &*ext_mgr)
749 }
750}
751
752impl<S: Stream> Connection for RustConnection<S> {
753 fn wait_for_raw_event_with_sequence(
754 &self,
755 ) -> Result<RawEventAndSeqNumber<Vec<u8>>, ConnectionError> {
756 let _guard = crate::trace_span!("wait_for_raw_event_with_sequence").entered();
757
758 let mut inner = self.inner.lock().unwrap();
759 loop {
760 if let Some(event) = inner.inner.poll_for_event_with_sequence() {
761 return Ok(event);
762 }
763 inner = self.read_packet_and_enqueue(inner, BlockingMode::Blocking)?;
764 }
765 }
766
767 fn poll_for_raw_event_with_sequence(
768 &self,
769 ) -> Result<Option<RawEventAndSeqNumber<Vec<u8>>>, ConnectionError> {
770 let _guard = crate::trace_span!("poll_for_raw_event_with_sequence").entered();
771
772 let mut inner = self.inner.lock().unwrap();
773 if let Some(event) = inner.inner.poll_for_event_with_sequence() {
774 Ok(Some(event))
775 } else {
776 inner = self.read_packet_and_enqueue(inner, BlockingMode::NonBlocking)?;
777 Ok(inner.inner.poll_for_event_with_sequence())
778 }
779 }
780
781 fn flush(&self) -> Result<(), ConnectionError> {
782 let inner = self.inner.lock().unwrap();
783 let _inner = self.flush_impl(inner)?;
784 Ok(())
785 }
786
787 fn setup(&self) -> &Setup {
788 &self.setup
789 }
790
791 fn generate_id(&self) -> Result<u32, ReplyOrIdError> {
792 let mut id_allocator = self.id_allocator.lock().unwrap();
793 if let Some(id) = id_allocator.generate_id() {
794 Ok(id)
795 } else {
796 use crate::protocol::xc_misc::{self, ConnectionExt as _};
797
798 if self
799 .extension_information(xc_misc::X11_EXTENSION_NAME)?
800 .is_none()
801 {
802 crate::error!("XIDs are exhausted and XC-MISC extension is not available");
803 Err(ReplyOrIdError::IdsExhausted)
804 } else {
805 crate::info!("XIDs are exhausted; fetching free range via XC-MISC");
806 id_allocator.update_xid_range(&self.xc_misc_get_xid_range()?.reply()?)?;
807 id_allocator
808 .generate_id()
809 .ok_or(ReplyOrIdError::IdsExhausted)
810 }
811 }
812 }
813}
814
/// Guard that wakes all threads waiting on the wrapped [`Condvar`] when it goes out of scope.
///
/// Holding one of these guarantees the notification happens on every exit path, including
/// early returns through `?`.
#[derive(Debug)]
struct NotifyOnDrop<'a>(&'a Condvar);

impl Drop for NotifyOnDrop<'_> {
    fn drop(&mut self) {
        let NotifyOnDrop(condvar) = self;
        condvar.notify_all();
    }
}
824
825struct RequestInfo<'a> {
827 extension_manager: &'a Mutex<ExtensionManager>,
828 major_opcode: u8,
829 minor_opcode: u8,
830}
831
832impl std::fmt::Display for RequestInfo<'_> {
833 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
834 if self.major_opcode == QUERY_EXTENSION_REQUEST {
837 write!(f, "QueryExtension request")
838 } else {
839 let guard = self.extension_manager.lock().unwrap();
840 write!(
841 f,
842 "{} request",
843 x11rb_protocol::protocol::get_request_name(
844 &*guard,
845 self.major_opcode,
846 self.minor_opcode
847 )
848 )
849 }
850 }
851}