1use std::{
4 fmt,
5 os::unix::{
6 io::{AsRawFd, BorrowedFd, OwnedFd, RawFd},
7 net::UnixStream,
8 },
9 sync::{Arc, Condvar, Mutex, MutexGuard, Weak},
10};
11
12use crate::{
13 core_interfaces::WL_DISPLAY_INTERFACE,
14 debug,
15 protocol::{
16 check_for_signature, same_interface, same_interface_or_anonymous, AllowNull, Argument,
17 ArgumentType, Interface, Message, ObjectInfo, ProtocolError, ANONYMOUS_INTERFACE,
18 INLINE_ARGS,
19 },
20};
21use smallvec::SmallVec;
22
23use super::{
24 client::*,
25 map::{Object, ObjectMap, SERVER_ID_LIMIT},
26 socket::{BufferedSocket, Socket},
27 wire::MessageParseError,
28};
29
/// Per-object bookkeeping stored in the [`ObjectMap`] alongside each protocol object.
#[derive(Debug, Clone)]
struct Data {
    /// Set once the client has destroyed the object: after sending a destructor request
    /// (`send_request`) or after dispatching a destructor event (`dispatch_events`).
    client_destroyed: bool,
    /// Set once the server has released the object: on `wl_display.delete_id` or after a
    /// destructor event.
    server_destroyed: bool,
    /// Handler that receives this object's events and its `destroyed` notification.
    user_data: Arc<dyn ObjectData>,
    /// Generation counter compared against `InnerObjectId::serial` by `get_object`, so that
    /// stale ids referring to a reused protocol id are rejected.
    serial: u32,
}
37
38#[derive(Clone)]
40pub struct InnerObjectId {
41 serial: u32,
42 id: u32,
43 interface: &'static Interface,
44}
45
46impl std::cmp::PartialEq for InnerObjectId {
47 fn eq(&self, other: &Self) -> bool {
48 self.id == other.id
49 && self.serial == other.serial
50 && same_interface(self.interface, other.interface)
51 }
52}
53
54impl std::cmp::Eq for InnerObjectId {}
55
56impl std::hash::Hash for InnerObjectId {
57 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
58 self.serial.hash(state);
59 self.id.hash(state);
60 }
61}
62
63impl fmt::Display for InnerObjectId {
64 #[cfg_attr(coverage, coverage(off))]
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 write!(f, "{}@{}", self.interface.name, self.id)
67 }
68}
69
70impl fmt::Debug for InnerObjectId {
71 #[cfg_attr(coverage, coverage(off))]
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 write!(f, "ObjectId({}, {})", self, self.serial)
74 }
75}
76
77impl InnerObjectId {
78 pub fn is_null(&self) -> bool {
79 self.id == 0
80 }
81
82 pub fn interface(&self) -> &'static Interface {
83 self.interface
84 }
85
86 pub fn protocol_id(&self) -> u32 {
87 self.id
88 }
89}
90
/// State guarded by the protocol lock: the socket, the object map and the sticky error.
#[derive(Debug)]
struct ProtocolState {
    /// Buffered connection socket used for both sending and receiving messages.
    socket: BufferedSocket,
    /// Wire id -> object mapping.
    map: ObjectMap<Data>,
    /// First fatal error seen on this connection; once set, `no_last_error` makes
    /// `flush` and `dispatch_events` fail with it.
    last_error: Option<WaylandError>,
    /// Counter handed out (after increment) by `next_serial`.
    last_serial: u32,
    /// Whether protocol messages are printed by the `debug` module; initialized from
    /// `debug::has_debug_client_env()`.
    debug: bool,
}

/// Bookkeeping shared by all `InnerReadEventsGuard`s of a connection.
#[derive(Debug)]
struct ReadingState {
    /// Number of guards that have been created and not yet read or dropped.
    prepared_reads: usize,
    /// Notified whenever `prepared_reads` falls back to zero.
    read_condvar: Arc<Condvar>,
    /// Incremented (wrapping) each time `prepared_reads` falls back to zero; waiting
    /// readers block until it changes.
    read_serial: usize,
}

/// Shared connection state. The read path (`InnerReadEventsGuard::read`) takes the `read`
/// lock and then the `protocol` lock while still holding it.
#[derive(Debug)]
pub struct ConnectionState {
    protocol: Mutex<ProtocolState>,
    read: Mutex<ReadingState>,
}
112
113impl ConnectionState {
114 fn lock_protocol(&self) -> MutexGuard<ProtocolState> {
115 self.protocol.lock().unwrap()
116 }
117
118 fn lock_read(&self) -> MutexGuard<ReadingState> {
119 self.read.lock().unwrap()
120 }
121}
122
123#[derive(Clone, Debug)]
124pub struct InnerBackend {
125 state: Arc<ConnectionState>,
126}
127
128#[derive(Clone, Debug)]
129pub struct WeakInnerBackend {
130 state: Weak<ConnectionState>,
131}
132
133impl WeakInnerBackend {
134 pub fn upgrade(&self) -> Option<InnerBackend> {
135 Weak::upgrade(&self.state).map(|state| InnerBackend { state })
136 }
137}
138
139impl PartialEq for InnerBackend {
140 fn eq(&self, rhs: &Self) -> bool {
141 Arc::ptr_eq(&self.state, &rhs.state)
142 }
143}
144
145impl Eq for InnerBackend {}
146
147impl InnerBackend {
148 pub fn downgrade(&self) -> WeakInnerBackend {
149 WeakInnerBackend { state: Arc::downgrade(&self.state) }
150 }
151
152 pub fn connect(stream: UnixStream) -> Result<Self, NoWaylandLib> {
153 let socket = BufferedSocket::new(Socket::from(stream));
154 let mut map = ObjectMap::new();
155 map.insert_at(
156 1,
157 Object {
158 interface: &WL_DISPLAY_INTERFACE,
159 version: 1,
160 data: Data {
161 client_destroyed: false,
162 server_destroyed: false,
163 user_data: Arc::new(DumbObjectData),
164 serial: 0,
165 },
166 },
167 )
168 .unwrap();
169
170 let debug = debug::has_debug_client_env();
171
172 Ok(Self {
173 state: Arc::new(ConnectionState {
174 protocol: Mutex::new(ProtocolState {
175 socket,
176 map,
177 last_error: None,
178 last_serial: 0,
179 debug,
180 }),
181 read: Mutex::new(ReadingState {
182 prepared_reads: 0,
183 read_condvar: Arc::new(Condvar::new()),
184 read_serial: 0,
185 }),
186 }),
187 })
188 }
189
190 pub fn flush(&self) -> Result<(), WaylandError> {
192 let mut guard = self.state.lock_protocol();
193 guard.no_last_error()?;
194 if let Err(e) = guard.socket.flush() {
195 return Err(guard.store_if_not_wouldblock_and_return_error(e));
196 }
197 Ok(())
198 }
199
200 pub fn poll_fd(&self) -> BorrowedFd {
201 let raw_fd = self.state.lock_protocol().socket.as_raw_fd();
202 unsafe { BorrowedFd::borrow_raw(raw_fd) }
205 }
206}
207
/// Token representing one prepared read on a connection.
///
/// Creating it increments `ReadingState::prepared_reads`. Consuming it with `read()` or
/// dropping it decrements the counter again.
#[derive(Debug)]
pub struct InnerReadEventsGuard {
    state: Arc<ConnectionState>,
    /// Set by `read()` so that `Drop` does not decrement `prepared_reads` a second time.
    done: bool,
}
213
impl InnerReadEventsGuard {
    /// Register a prepared read on `backend`. This implementation always returns `Some`.
    pub fn try_new(backend: InnerBackend) -> Option<Self> {
        backend.state.lock_read().prepared_reads += 1;
        Some(Self { state: backend.state, done: false })
    }

    /// File descriptor of the connection socket, borrowed for the lifetime of `self`.
    pub fn connection_fd(&self) -> BorrowedFd {
        let raw_fd = self.state.lock_protocol().socket.as_raw_fd();
        // SAFETY: the fd belongs to the socket owned by `self.state`, which this guard keeps
        // alive for at least as long as the returned borrow of `self`.
        unsafe { BorrowedFd::borrow_raw(raw_fd) }
    }

    /// Release this prepared read and, if it was the last one, read and dispatch events.
    ///
    /// The last guard to release calls `dispatch_events` while still holding the read lock.
    /// It then bumps `read_serial` and wakes all waiters, even when dispatching failed.
    /// Any other guard blocks until that happens, then returns `Ok(0)`, or the connection's
    /// stored error if one was recorded.
    pub fn read(mut self) -> Result<usize, WaylandError> {
        let mut guard = self.state.lock_read();
        guard.prepared_reads -= 1;
        // Bookkeeping is done here; tell `Drop` not to decrement again.
        self.done = true;
        if guard.prepared_reads == 0 {
            // Last reader: dispatch, then publish a new read round to waiting readers.
            let ret = dispatch_events(self.state.clone());
            guard.read_serial = guard.read_serial.wrapping_add(1);
            guard.read_condvar.notify_all();
            ret
        } else {
            // Wait for the reader of the current round to finish (serial changes).
            let serial = guard.read_serial;
            let condvar = guard.read_condvar.clone();
            let _guard =
                condvar.wait_while(guard, |backend| serial == backend.read_serial).unwrap();
            // Surface an error the dispatching reader may have stored.
            self.state.lock_protocol().no_last_error()?;
            Ok(0)
        }
    }
}
264
265impl Drop for InnerReadEventsGuard {
266 fn drop(&mut self) {
267 if !self.done {
268 let mut guard = self.state.lock_read();
269 guard.prepared_reads -= 1;
270 if guard.prepared_reads == 0 {
271 guard.read_serial = guard.read_serial.wrapping_add(1);
273 guard.read_condvar.notify_all();
274 }
275 }
276 }
277}
278
279impl InnerBackend {
280 pub fn display_id(&self) -> ObjectId {
281 ObjectId { id: InnerObjectId { serial: 0, id: 1, interface: &WL_DISPLAY_INTERFACE } }
282 }
283
284 pub fn last_error(&self) -> Option<WaylandError> {
285 self.state.lock_protocol().last_error.clone()
286 }
287
288 pub fn info(&self, id: ObjectId) -> Result<ObjectInfo, InvalidId> {
289 let object = self.state.lock_protocol().get_object(id.id.clone())?;
290 if object.data.client_destroyed {
291 Err(InvalidId)
292 } else {
293 Ok(ObjectInfo { id: id.id.id, interface: object.interface, version: object.version })
294 }
295 }
296
297 pub fn null_id() -> ObjectId {
298 ObjectId { id: InnerObjectId { serial: 0, id: 0, interface: &ANONYMOUS_INTERFACE } }
299 }
300
301 pub fn send_request(
302 &self,
303 Message { sender_id: ObjectId { id }, opcode, args }: Message<ObjectId, RawFd>,
304 data: Option<Arc<dyn ObjectData>>,
305 child_spec: Option<(&'static Interface, u32)>,
306 ) -> Result<ObjectId, InvalidId> {
307 let mut guard = self.state.lock_protocol();
308 let object = guard.get_object(id.clone())?;
309
310 let message_desc = match object.interface.requests.get(opcode as usize) {
311 Some(msg) => msg,
312 None => {
313 panic!("Unknown opcode {} for object {}@{}.", opcode, object.interface.name, id.id);
314 }
315 };
316
317 if object.data.client_destroyed {
318 if guard.debug {
319 debug::print_send_message(id.interface.name, id.id, message_desc.name, &args, true);
320 }
321 return Err(InvalidId);
322 }
323
324 if !check_for_signature(message_desc.signature, &args) {
325 panic!(
326 "Unexpected signature for request {}@{}.{}: expected {:?}, got {:?}.",
327 object.interface.name, id.id, message_desc.name, message_desc.signature, args
328 );
329 }
330
331 let child_spec = if message_desc
333 .signature
334 .iter()
335 .any(|arg| matches!(arg, ArgumentType::NewId))
336 {
337 if let Some((iface, version)) = child_spec {
338 if let Some(child_interface) = message_desc.child_interface {
339 if !same_interface(child_interface, iface) {
340 panic!(
341 "Error when sending request {}@{}.{}: expected interface {} but got {}",
342 object.interface.name,
343 id.id,
344 message_desc.name,
345 child_interface.name,
346 iface.name
347 );
348 }
349 if version != object.version {
350 panic!(
351 "Error when sending request {}@{}.{}: expected version {} but got {}",
352 object.interface.name,
353 id.id,
354 message_desc.name,
355 object.version,
356 version
357 );
358 }
359 }
360 Some((iface, version))
361 } else if let Some(child_interface) = message_desc.child_interface {
362 Some((child_interface, object.version))
363 } else {
364 panic!(
365 "Error when sending request {}@{}.{}: target interface must be specified for a generic constructor.",
366 object.interface.name,
367 id.id,
368 message_desc.name
369 );
370 }
371 } else {
372 None
373 };
374
375 let child = if let Some((child_interface, child_version)) = child_spec {
376 let child_serial = guard.next_serial();
377
378 let child = Object {
379 interface: child_interface,
380 version: child_version,
381 data: Data {
382 client_destroyed: false,
383 server_destroyed: false,
384 user_data: Arc::new(DumbObjectData),
385 serial: child_serial,
386 },
387 };
388
389 let child_id = guard.map.client_insert_new(child);
390
391 guard
392 .map
393 .with(child_id, |obj| {
394 obj.data.user_data = data.expect(
395 "Sending a request creating an object without providing an object data.",
396 );
397 })
398 .unwrap();
399 Some((child_id, child_serial, child_interface))
400 } else {
401 None
402 };
403
404 let args = args.into_iter().map(|arg| {
406 if let Argument::NewId(ObjectId { id: p }) = arg {
407 if p.id != 0 {
408 panic!("The newid provided when sending request {}@{}.{} is not a placeholder.", object.interface.name, id.id, message_desc.name);
409 }
410 if let Some((child_id, child_serial, child_interface)) = child {
411 Argument::NewId(ObjectId { id: InnerObjectId { id: child_id, serial: child_serial, interface: child_interface}})
412 } else {
413 unreachable!();
414 }
415 } else {
416 arg
417 }
418 }).collect::<SmallVec<[_; INLINE_ARGS]>>();
419
420 if guard.debug {
421 debug::print_send_message(
422 object.interface.name,
423 id.id,
424 message_desc.name,
425 &args,
426 false,
427 );
428 }
429 #[cfg(feature = "log")]
430 crate::log_debug!("Sending {}.{} ({})", id, message_desc.name, debug::DisplaySlice(&args));
431
432 let mut msg_args = SmallVec::with_capacity(args.len());
435 let mut arg_interfaces = message_desc.arg_interfaces.iter();
436 for (i, arg) in args.into_iter().enumerate() {
437 msg_args.push(match arg {
438 Argument::Array(a) => Argument::Array(a),
439 Argument::Int(i) => Argument::Int(i),
440 Argument::Uint(u) => Argument::Uint(u),
441 Argument::Str(s) => Argument::Str(s),
442 Argument::Fixed(f) => Argument::Fixed(f),
443 Argument::NewId(nid) => Argument::NewId(nid.id.id),
444 Argument::Fd(f) => Argument::Fd(f),
445 Argument::Object(o) => {
446 let next_interface = arg_interfaces.next().unwrap();
447 if o.id.id != 0 {
448 let arg_object = guard.get_object(o.id.clone())?;
449 if !same_interface_or_anonymous(next_interface, arg_object.interface) {
450 panic!("Request {}@{}.{} expects an argument of interface {} but {} was provided instead.", object.interface.name, id.id, message_desc.name, next_interface.name, arg_object.interface.name);
451 }
452 } else if !matches!(message_desc.signature[i], ArgumentType::Object(AllowNull::Yes)) {
453 panic!("Request {}@{}.{} expects an non-null object argument.", object.interface.name, id.id, message_desc.name);
454 }
455 Argument::Object(o.id.id)
456 }
457 });
458 }
459
460 let msg = Message { sender_id: id.id, opcode, args: msg_args };
461
462 if let Err(err) = guard.socket.write_message(&msg) {
463 guard.last_error = Some(WaylandError::Io(err));
464 }
465
466 if message_desc.is_destructor {
468 guard
469 .map
470 .with(id.id, |obj| {
471 obj.data.client_destroyed = true;
472 })
473 .unwrap();
474 object.data.user_data.destroyed(ObjectId { id });
475 }
476 if let Some((child_id, child_serial, child_interface)) = child {
477 Ok(ObjectId {
478 id: InnerObjectId {
479 id: child_id,
480 serial: child_serial,
481 interface: child_interface,
482 },
483 })
484 } else {
485 Ok(Self::null_id())
486 }
487 }
488
489 pub fn get_data(&self, id: ObjectId) -> Result<Arc<dyn ObjectData>, InvalidId> {
490 let object = self.state.lock_protocol().get_object(id.id)?;
491 Ok(object.data.user_data)
492 }
493
494 pub fn set_data(&self, id: ObjectId, data: Arc<dyn ObjectData>) -> Result<(), InvalidId> {
495 self.state
496 .lock_protocol()
497 .map
498 .with(id.id.id, move |objdata| {
499 if objdata.data.serial != id.id.serial {
500 Err(InvalidId)
501 } else {
502 objdata.data.user_data = data;
503 Ok(())
504 }
505 })
506 .unwrap_or(Err(InvalidId))
507 }
508
509 pub fn dispatch_inner_queue(&self) -> Result<usize, WaylandError> {
511 Ok(0)
512 }
513}
514
515impl ProtocolState {
516 fn next_serial(&mut self) -> u32 {
517 self.last_serial = self.last_serial.wrapping_add(1);
518 self.last_serial
519 }
520
521 #[inline]
522 fn no_last_error(&self) -> Result<(), WaylandError> {
523 if let Some(ref err) = self.last_error {
524 Err(err.clone())
525 } else {
526 Ok(())
527 }
528 }
529
530 #[inline]
531 fn store_and_return_error(&mut self, err: impl Into<WaylandError>) -> WaylandError {
532 let err = err.into();
533 crate::log_error!("{}", err);
534 self.last_error = Some(err.clone());
535 err
536 }
537
538 #[inline]
539 fn store_if_not_wouldblock_and_return_error(&mut self, e: std::io::Error) -> WaylandError {
540 if e.kind() != std::io::ErrorKind::WouldBlock {
541 self.store_and_return_error(e)
542 } else {
543 e.into()
544 }
545 }
546
547 fn get_object(&self, id: InnerObjectId) -> Result<Object<Data>, InvalidId> {
548 let object = self.map.find(id.id).ok_or(InvalidId)?;
549 if object.data.serial != id.serial {
550 return Err(InvalidId);
551 }
552 Ok(object)
553 }
554
555 fn handle_display_event(&mut self, message: Message<u32, OwnedFd>) -> Result<(), WaylandError> {
556 if self.debug {
557 debug::print_dispatched_message(
558 "wl_display",
559 message.sender_id,
560 if message.opcode == 0 { "error" } else { "delete_id" },
561 &message.args,
562 );
563 }
564 match message.opcode {
565 0 => {
566 if let [Argument::Object(obj), Argument::Uint(code), Argument::Str(Some(ref message))] =
568 message.args[..]
569 {
570 let object = self.map.find(obj);
571 let err = WaylandError::Protocol(ProtocolError {
572 code,
573 object_id: obj,
574 object_interface: object
575 .map(|obj| obj.interface.name)
576 .unwrap_or("<unknown>")
577 .into(),
578 message: message.to_string_lossy().into(),
579 });
580 return Err(self.store_and_return_error(err));
581 } else {
582 unreachable!()
583 }
584 }
585 1 => {
586 if let [Argument::Uint(id)] = message.args[..] {
588 let client_destroyed = self
589 .map
590 .with(id, |obj| {
591 obj.data.server_destroyed = true;
592 obj.data.client_destroyed
593 })
594 .unwrap_or(false);
595 if client_destroyed {
596 self.map.remove(id);
597 }
598 } else {
599 unreachable!()
600 }
601 }
602 _ => unreachable!(),
603 }
604 Ok(())
605 }
606}
607
608fn dispatch_events(state: Arc<ConnectionState>) -> Result<usize, WaylandError> {
609 let backend = Backend { backend: InnerBackend { state } };
610 let mut guard = backend.backend.state.lock_protocol();
611 guard.no_last_error()?;
612 let mut dispatched = 0;
613 loop {
614 let ProtocolState { ref mut socket, ref map, .. } = *guard;
616 let message = match socket.read_one_message(|id, opcode| {
617 map.find(id)
618 .and_then(|o| o.interface.events.get(opcode as usize))
619 .map(|desc| desc.signature)
620 }) {
621 Ok(msg) => msg,
622 Err(MessageParseError::MissingData) | Err(MessageParseError::MissingFD) => {
623 if let Err(e) = guard.socket.fill_incoming_buffers() {
625 if e.kind() != std::io::ErrorKind::WouldBlock {
626 return Err(guard.store_and_return_error(e));
627 } else if dispatched == 0 {
628 return Err(e.into());
629 } else {
630 break;
631 }
632 }
633 continue;
634 }
635 Err(MessageParseError::Malformed) => {
636 let err = WaylandError::Protocol(ProtocolError {
638 code: 0,
639 object_id: 0,
640 object_interface: "".into(),
641 message: "Malformed Wayland message.".into(),
642 });
643 return Err(guard.store_and_return_error(err));
644 }
645 };
646
647 let receiver = guard.map.find(message.sender_id).unwrap();
650 let message_desc = receiver.interface.events.get(message.opcode as usize).unwrap();
651
652 if message.sender_id == 1 {
654 guard.handle_display_event(message)?;
655 continue;
656 }
657
658 let mut created_id = None;
659
660 let mut args = SmallVec::with_capacity(message.args.len());
662 let mut arg_interfaces = message_desc.arg_interfaces.iter();
663 for arg in message.args.into_iter() {
664 args.push(match arg {
665 Argument::Array(a) => Argument::Array(a),
666 Argument::Int(i) => Argument::Int(i),
667 Argument::Uint(u) => Argument::Uint(u),
668 Argument::Str(s) => Argument::Str(s),
669 Argument::Fixed(f) => Argument::Fixed(f),
670 Argument::Fd(f) => Argument::Fd(f),
671 Argument::Object(o) => {
672 if o != 0 {
673 let obj = match guard.map.find(o) {
675 Some(o) => o,
676 None => {
677 let err = WaylandError::Protocol(ProtocolError {
678 code: 0,
679 object_id: 0,
680 object_interface: "".into(),
681 message: format!("Unknown object {}.", o),
682 });
683 return Err(guard.store_and_return_error(err));
684 }
685 };
686 if let Some(next_interface) = arg_interfaces.next() {
687 if !same_interface_or_anonymous(next_interface, obj.interface) {
688 let err = WaylandError::Protocol(ProtocolError {
689 code: 0,
690 object_id: 0,
691 object_interface: "".into(),
692 message: format!(
693 "Protocol error: server sent object {} for interface {}, but it has interface {}.",
694 o, next_interface.name, obj.interface.name
695 ),
696 });
697 return Err(guard.store_and_return_error(err));
698 }
699 }
700 Argument::Object(ObjectId { id: InnerObjectId { id: o, serial: obj.data.serial, interface: obj.interface }})
701 } else {
702 Argument::Object(ObjectId { id: InnerObjectId { id: 0, serial: 0, interface: &ANONYMOUS_INTERFACE }})
703 }
704 }
705 Argument::NewId(new_id) => {
706 let child_interface = match message_desc.child_interface {
708 Some(iface) => iface,
709 None => panic!("Received event {}@{}.{} which creates an object without specifying its interface, this is unsupported.", receiver.interface.name, message.sender_id, message_desc.name),
710 };
711
712 let child_udata = Arc::new(UninitObjectData);
713
714 if new_id >= SERVER_ID_LIMIT
716 && guard.map.with(new_id, |obj| obj.data.client_destroyed).unwrap_or(false)
717 {
718 guard.map.remove(new_id);
719 }
720
721 let child_obj = Object {
722 interface: child_interface,
723 version: receiver.version,
724 data: Data {
725 client_destroyed: receiver.data.client_destroyed,
726 server_destroyed: false,
727 user_data: child_udata,
728 serial: guard.next_serial(),
729 }
730 };
731
732 let child_id = InnerObjectId { id: new_id, serial: child_obj.data.serial, interface: child_obj.interface };
733 created_id = Some(child_id.clone());
734
735 if let Err(()) = guard.map.insert_at(new_id, child_obj) {
736 let err = WaylandError::Protocol(ProtocolError {
738 code: 0,
739 object_id: 0,
740 object_interface: "".into(),
741 message: format!(
742 "Protocol error: server tried to create \
743 an object \"{}\" with invalid id {}.",
744 child_interface.name, new_id
745 ),
746 });
747 return Err(guard.store_and_return_error(err));
748 }
749
750 Argument::NewId(ObjectId { id: child_id })
751 }
752 });
753 }
754
755 if guard.debug {
756 debug::print_dispatched_message(
757 receiver.interface.name,
758 message.sender_id,
759 message_desc.name,
760 &args,
761 );
762 }
763
764 if receiver.data.client_destroyed {
766 continue;
767 }
768
769 let id = InnerObjectId {
771 id: message.sender_id,
772 serial: receiver.data.serial,
773 interface: receiver.interface,
774 };
775
776 std::mem::drop(guard);
778 #[cfg(feature = "log")]
779 crate::log_debug!(
780 "Dispatching {}.{} ({})",
781 id,
782 receiver.version,
783 debug::DisplaySlice(&args)
784 );
785 let ret = receiver
786 .data
787 .user_data
788 .clone()
789 .event(&backend, Message { sender_id: ObjectId { id }, opcode: message.opcode, args });
790 guard = backend.backend.state.lock_protocol();
792
793 if message_desc.is_destructor {
795 guard
796 .map
797 .with(message.sender_id, |obj| {
798 obj.data.server_destroyed = true;
799 obj.data.client_destroyed = true;
800 })
801 .unwrap();
802 receiver.data.user_data.destroyed(ObjectId {
803 id: InnerObjectId {
804 id: message.sender_id,
805 serial: receiver.data.serial,
806 interface: receiver.interface,
807 },
808 });
809 }
810
811 match (created_id, ret) {
812 (Some(child_id), Some(child_data)) => {
813 guard.map.with(child_id.id, |obj| obj.data.user_data = child_data).unwrap();
814 }
815 (None, None) => {}
816 (Some(child_id), None) => {
817 panic!("Callback creating object {} did not provide any object data.", child_id);
818 }
819 (None, Some(_)) => {
820 panic!("An object data was returned from a callback not creating any object");
821 }
822 }
823
824 dispatched += 1;
825 }
826 Ok(dispatched)
827}