1use std::{
2 ffi::CString,
3 os::unix::{
4 io::{AsFd, BorrowedFd, OwnedFd, RawFd},
5 net::UnixStream,
6 },
7 sync::Arc,
8};
9
10use crate::{
11 core_interfaces::{WL_CALLBACK_INTERFACE, WL_DISPLAY_INTERFACE, WL_REGISTRY_INTERFACE},
12 debug,
13 protocol::{
14 check_for_signature, same_interface, same_interface_or_anonymous, AllowNull, Argument,
15 ArgumentType, Interface, Message, ObjectInfo, ProtocolError, ANONYMOUS_INTERFACE,
16 INLINE_ARGS,
17 },
18 rs::map::SERVER_ID_LIMIT,
19 types::server::{DisconnectReason, InvalidId},
20};
21
22use smallvec::SmallVec;
23
24use crate::rs::{
25 map::{Object, ObjectMap},
26 socket::{BufferedSocket, Socket},
27 wire::MessageParseError,
28};
29
30use super::{
31 handle::PendingDestructor, registry::Registry, ClientData, ClientId, Credentials, Data,
32 DumbObjectData, GlobalHandler, InnerClientId, InnerGlobalId, InnerObjectId, ObjectData,
33 ObjectId, UninitObjectData,
34};
35
36type ArgSmallVec<Fd> = SmallVec<[Argument<ObjectId, Fd>; INLINE_ARGS]>;
37
/// Error codes posted on the `wl_display` object (protocol id 1) through
/// `Client::post_display_error`, which casts the variant to its `u32` value.
#[allow(dead_code)]
#[repr(u32)]
pub(crate) enum DisplayError {
    /// Used in this module for unknown ids, invalid new_ids, null objects where
    /// none is allowed, and interface mismatches.
    InvalidObject = 0,
    /// Used in this module for unknown request opcodes.
    InvalidMethod,
    /// Not posted by this module.
    NoMemory,
    /// Not posted by this module.
    Implementation,
}
46
/// Server-side state of a single connected client.
#[derive(Debug)]
pub(crate) struct Client<D: 'static> {
    /// Buffered connection used to read requests and write events.
    socket: BufferedSocket,
    /// Protocol objects of this client, keyed by protocol id.
    pub(crate) map: ObjectMap<Data<D>>,
    /// When set, sent events and dispatched requests are printed through the `debug` module.
    debug: bool,
    /// Last serial handed out by `next_serial`; serials let stale handles to a
    /// reused protocol id be rejected.
    last_serial: u32,
    /// Identifier of this client (slot index + serial) inside its `ClientStore`.
    pub(crate) id: InnerClientId,
    /// Set once the client has been disconnected; `ClientStore::cleanup`
    /// removes clients with this flag.
    pub(crate) killed: bool,
    /// User data attached to the client; notified on initialization and disconnection.
    pub(crate) data: Arc<dyn ClientData>,
}
57
58impl<D> Client<D> {
59 fn next_serial(&mut self) -> u32 {
60 self.last_serial = self.last_serial.wrapping_add(1);
61 self.last_serial
62 }
63}
64
65impl<D> Client<D> {
66 pub(crate) fn new(
67 stream: UnixStream,
68 id: InnerClientId,
69 debug: bool,
70 data: Arc<dyn ClientData>,
71 ) -> Self {
72 let socket = BufferedSocket::new(Socket::from(stream));
73 let mut map = ObjectMap::new();
74 map.insert_at(
75 1,
76 Object {
77 interface: &WL_DISPLAY_INTERFACE,
78 version: 1,
79 data: Data { user_data: Arc::new(DumbObjectData), serial: 0 },
80 },
81 )
82 .unwrap();
83
84 data.initialized(ClientId { id: id.clone() });
85
86 Self { socket, map, debug, id, killed: false, last_serial: 0, data }
87 }
88
89 pub(crate) fn create_object(
90 &mut self,
91 interface: &'static Interface,
92 version: u32,
93 user_data: Arc<dyn ObjectData<D>>,
94 ) -> InnerObjectId {
95 let serial = self.next_serial();
96 let id = self.map.server_insert_new(Object {
97 interface,
98 version,
99 data: Data { serial, user_data },
100 });
101 InnerObjectId { id, serial, client_id: self.id.clone(), interface }
102 }
103
104 pub(crate) fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
105 let object = self.get_object(id.clone())?;
106 Ok(ObjectInfo { id: id.id, interface: object.interface, version: object.version })
107 }
108
109 pub(crate) fn send_event(
110 &mut self,
111 Message { sender_id: object_id, opcode, args }: Message<ObjectId, RawFd>,
112 pending_destructors: Option<&mut Vec<super::handle::PendingDestructor<D>>>,
113 ) -> Result<(), InvalidId> {
114 if self.killed {
115 return Ok(());
116 }
117 let object = self.get_object(object_id.id.clone())?;
118
119 let message_desc = match object.interface.events.get(opcode as usize) {
120 Some(msg) => msg,
121 None => {
122 panic!(
123 "Unknown opcode {} for object {}@{}.",
124 opcode, object.interface.name, object_id.id
125 );
126 }
127 };
128
129 if !check_for_signature(message_desc.signature, &args) {
130 panic!(
131 "Unexpected signature for event {}@{}.{}: expected {:?}, got {:?}.",
132 object.interface.name,
133 object_id.id,
134 message_desc.name,
135 message_desc.signature,
136 args
137 );
138 }
139
140 if self.debug {
141 debug::print_send_message(
142 object.interface.name,
143 object_id.id.id,
144 message_desc.name,
145 &args,
146 false,
147 );
148 }
149
150 let mut msg_args = SmallVec::with_capacity(args.len());
151 let mut arg_interfaces = message_desc.arg_interfaces.iter();
152 for (i, arg) in args.into_iter().enumerate() {
153 msg_args.push(match arg {
154 Argument::Array(a) => Argument::Array(a),
155 Argument::Int(i) => Argument::Int(i),
156 Argument::Uint(u) => Argument::Uint(u),
157 Argument::Str(s) => Argument::Str(s),
158 Argument::Fixed(f) => Argument::Fixed(f),
159 Argument::Fd(f) => Argument::Fd(f),
160 Argument::NewId(o) => {
161 if o.id.id != 0 {
162 if o.id.client_id != self.id {
163 panic!("Attempting to send an event with objects from wrong client.")
164 }
165 let object = self.get_object(o.id.clone())?;
166 let child_interface = match message_desc.child_interface {
167 Some(iface) => iface,
168 None => panic!("Trying to send event {}@{}.{} which creates an object without specifying its interface, this is unsupported.", object_id.id.interface.name, object_id.id, message_desc.name),
169 };
170 if !same_interface(child_interface, object.interface) {
171 panic!("Event {}@{}.{} expects a newid argument of interface {} but {} was provided instead.", object.interface.name, object_id.id, message_desc.name, child_interface.name, object.interface.name);
172 }
173 } else if !matches!(message_desc.signature[i], ArgumentType::NewId) {
174 panic!("Request {}@{}.{} expects an non-null newid argument.", object.interface.name, object_id.id, message_desc.name);
175 }
176 Argument::Object(o.id.id)
177 },
178 Argument::Object(o) => {
179 let next_interface = arg_interfaces.next().unwrap();
180 if o.id.id != 0 {
181 if o.id.client_id != self.id {
182 panic!("Attempting to send an event with objects from wrong client.")
183 }
184 let arg_object = self.get_object(o.id.clone())?;
185 if !same_interface_or_anonymous(next_interface, arg_object.interface) {
186 panic!("Event {}@{}.{} expects an object argument of interface {} but {} was provided instead.", object.interface.name, object_id.id, message_desc.name, next_interface.name, arg_object.interface.name);
187 }
188 } else if !matches!(message_desc.signature[i], ArgumentType::Object(AllowNull::Yes)) {
189 panic!("Request {}@{}.{} expects an non-null object argument.", object.interface.name, object_id.id, message_desc.name);
190 }
191 Argument::Object(o.id.id)
192 }
193 });
194 }
195
196 let msg = Message { sender_id: object_id.id.id, opcode, args: msg_args };
197
198 if self.socket.write_message(&msg).is_err() {
199 self.kill(DisconnectReason::ConnectionClosed);
200 }
201
202 if message_desc.is_destructor {
204 self.map.remove(object_id.id.id);
205 if let Some(vec) = pending_destructors {
206 vec.push((object.data.user_data.clone(), self.id.clone(), object_id.id.clone()));
207 }
208 self.send_delete_id(object_id.id);
209 }
210
211 Ok(())
212 }
213
214 pub(crate) fn send_delete_id(&mut self, object_id: InnerObjectId) {
215 if object_id.id < SERVER_ID_LIMIT {
217 let msg = message!(1, 1, [Argument::Uint(object_id.id)]);
218 if self.socket.write_message(&msg).is_err() {
219 self.kill(DisconnectReason::ConnectionClosed);
220 }
221 }
222 self.map.remove(object_id.id);
223 }
224
225 pub(crate) fn get_object_data(
226 &self,
227 id: InnerObjectId,
228 ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
229 let object = self.get_object(id)?;
230 Ok(object.data.user_data)
231 }
232
233 pub(crate) fn set_object_data(
234 &mut self,
235 id: InnerObjectId,
236 data: Arc<dyn ObjectData<D>>,
237 ) -> Result<(), InvalidId> {
238 self.map
239 .with(id.id, |objdata| {
240 if objdata.data.serial != id.serial {
241 Err(InvalidId)
242 } else {
243 objdata.data.user_data = data;
244 Ok(())
245 }
246 })
247 .unwrap_or(Err(InvalidId))
248 }
249
250 pub(crate) fn post_display_error(&mut self, code: DisplayError, message: CString) {
251 self.post_error(
252 InnerObjectId {
253 id: 1,
254 interface: &WL_DISPLAY_INTERFACE,
255 client_id: self.id.clone(),
256 serial: 0,
257 },
258 code as u32,
259 message,
260 )
261 }
262
263 pub(crate) fn post_error(
264 &mut self,
265 object_id: InnerObjectId,
266 error_code: u32,
267 message: CString,
268 ) {
269 let converted_message = message.to_string_lossy().into();
270 let _ = self.send_event(
272 message!(
273 ObjectId {
274 id: InnerObjectId {
275 id: 1,
276 interface: &WL_DISPLAY_INTERFACE,
277 client_id: self.id.clone(),
278 serial: 0
279 }
280 },
281 0, [
283 Argument::Object(ObjectId { id: object_id.clone() }),
284 Argument::Uint(error_code),
285 Argument::Str(Some(Box::new(message))),
286 ],
287 ),
288 None,
290 );
291 let _ = self.flush();
292 self.kill(DisconnectReason::ProtocolError(ProtocolError {
293 code: error_code,
294 object_id: object_id.id,
295 object_interface: object_id.interface.name.into(),
296 message: converted_message,
297 }));
298 }
299
300 #[cfg(any(target_os = "linux", target_os = "android"))]
301 pub(crate) fn get_credentials(&self) -> Credentials {
302 let creds =
303 rustix::net::sockopt::get_socket_peercred(&self.socket).expect("getsockopt failed!?");
304 let pid = rustix::process::Pid::as_raw(Some(creds.pid));
305 Credentials { pid, uid: creds.uid.as_raw(), gid: creds.gid.as_raw() }
306 }
307
308 #[cfg(not(any(target_os = "linux", target_os = "android")))]
309 pub(crate) fn get_credentials(&self) -> Credentials {
311 Credentials { pid: 0, uid: 0, gid: 0 }
312 }
313
314 pub(crate) fn kill(&mut self, reason: DisconnectReason) {
315 self.killed = true;
316 self.data.disconnected(ClientId { id: self.id.clone() }, reason);
317 }
318
319 pub(crate) fn flush(&mut self) -> std::io::Result<()> {
320 self.socket.flush()
321 }
322
323 pub(crate) fn all_objects(&self) -> impl Iterator<Item = ObjectId> + '_ {
324 let client_id = self.id.clone();
325 self.map.all_objects().map(move |(id, obj)| ObjectId {
326 id: InnerObjectId {
327 id,
328 client_id: client_id.clone(),
329 interface: obj.interface,
330 serial: obj.data.serial,
331 },
332 })
333 }
334
335 #[allow(clippy::type_complexity)]
336 pub(crate) fn next_request(
337 &mut self,
338 ) -> std::io::Result<(Message<u32, OwnedFd>, Object<Data<D>>)> {
339 if self.killed {
340 return Err(rustix::io::Errno::PIPE.into());
341 }
342 loop {
343 let map = &self.map;
344 let msg = match self.socket.read_one_message(|id, opcode| {
345 map.find(id)
346 .and_then(|o| o.interface.requests.get(opcode as usize))
347 .map(|desc| desc.signature)
348 }) {
349 Ok(msg) => msg,
350 Err(MessageParseError::MissingData) | Err(MessageParseError::MissingFD) => {
351 if let Err(e) = self.socket.fill_incoming_buffers() {
353 if e.kind() != std::io::ErrorKind::WouldBlock {
354 self.kill(DisconnectReason::ConnectionClosed);
355 }
356 return Err(e);
357 }
358 continue;
359 }
360 Err(MessageParseError::Malformed) => {
361 self.kill(DisconnectReason::ConnectionClosed);
362 return Err(rustix::io::Errno::PROTO.into());
363 }
364 };
365
366 let obj = self.map.find(msg.sender_id).unwrap();
367
368 if self.debug {
369 debug::print_dispatched_message(
370 obj.interface.name,
371 msg.sender_id,
372 obj.interface.requests.get(msg.opcode as usize).unwrap().name,
373 &msg.args,
374 );
375 }
376
377 return Ok((msg, obj));
378 }
379 }
380
381 fn get_object(&self, id: InnerObjectId) -> Result<Object<Data<D>>, InvalidId> {
382 let object = self.map.find(id.id).ok_or(InvalidId)?;
383 if object.data.serial != id.serial {
384 return Err(InvalidId);
385 }
386 Ok(object)
387 }
388
389 pub(crate) fn object_for_protocol_id(&self, pid: u32) -> Result<InnerObjectId, InvalidId> {
390 let object = self.map.find(pid).ok_or(InvalidId)?;
391 Ok(InnerObjectId {
392 id: pid,
393 client_id: self.id.clone(),
394 serial: object.data.serial,
395 interface: object.interface,
396 })
397 }
398
399 fn queue_all_destructors(&mut self, pending_destructors: &mut Vec<PendingDestructor<D>>) {
400 pending_destructors.extend(self.map.all_objects().map(|(id, obj)| {
401 (
402 obj.data.user_data.clone(),
403 self.id.clone(),
404 InnerObjectId {
405 id,
406 serial: obj.data.serial,
407 client_id: self.id.clone(),
408 interface: obj.interface,
409 },
410 )
411 }));
412 }
413
414 pub(crate) fn handle_display_request(
415 &mut self,
416 message: Message<u32, OwnedFd>,
417 registry: &mut Registry<D>,
418 ) {
419 match message.opcode {
420 0 => {
422 if let [Argument::NewId(new_id)] = message.args[..] {
423 let serial = self.next_serial();
424 let callback_obj = Object {
425 interface: &WL_CALLBACK_INTERFACE,
426 version: 1,
427 data: Data { user_data: Arc::new(DumbObjectData), serial },
428 };
429 if let Err(()) = self.map.insert_at(new_id, callback_obj) {
430 self.post_display_error(
431 DisplayError::InvalidObject,
432 CString::new(format!("Invalid new_id: {}.", new_id)).unwrap(),
433 );
434 return;
435 }
436 let cb_id = ObjectId {
437 id: InnerObjectId {
438 id: new_id,
439 client_id: self.id.clone(),
440 serial,
441 interface: &WL_CALLBACK_INTERFACE,
442 },
443 };
444 self.send_event(message!(cb_id, 0, [Argument::Uint(0)]), None).unwrap();
446 } else {
447 unreachable!()
448 }
449 }
450 1 => {
452 if let [Argument::NewId(new_id)] = message.args[..] {
453 let serial = self.next_serial();
454 let registry_obj = Object {
455 interface: &WL_REGISTRY_INTERFACE,
456 version: 1,
457 data: Data { user_data: Arc::new(DumbObjectData), serial },
458 };
459 let registry_id = InnerObjectId {
460 id: new_id,
461 serial,
462 client_id: self.id.clone(),
463 interface: &WL_REGISTRY_INTERFACE,
464 };
465 if let Err(()) = self.map.insert_at(new_id, registry_obj) {
466 self.post_display_error(
467 DisplayError::InvalidObject,
468 CString::new(format!("Invalid new_id: {}.", new_id)).unwrap(),
469 );
470 return;
471 }
472 let _ = registry.new_registry(registry_id, self);
473 } else {
474 unreachable!()
475 }
476 }
477 _ => {
478 self.post_display_error(
480 DisplayError::InvalidMethod,
481 CString::new(format!(
482 "Unknown opcode {} for interface wl_display.",
483 message.opcode
484 ))
485 .unwrap(),
486 );
487 }
488 }
489 }
490
491 #[allow(clippy::type_complexity)]
492 pub(crate) fn handle_registry_request(
493 &mut self,
494 message: Message<u32, OwnedFd>,
495 registry: &mut Registry<D>,
496 ) -> Option<(InnerClientId, InnerGlobalId, InnerObjectId, Arc<dyn GlobalHandler<D>>)> {
497 match message.opcode {
498 0 => {
500 if let [Argument::Uint(name), Argument::Str(Some(ref interface_name)), Argument::Uint(version), Argument::NewId(new_id)] =
501 message.args[..]
502 {
503 if let Some((interface, global_id, handler)) =
504 registry.check_bind(self, name, interface_name, version)
505 {
506 let serial = self.next_serial();
507 let object = Object {
508 interface,
509 version,
510 data: Data { serial, user_data: Arc::new(UninitObjectData) },
511 };
512 if let Err(()) = self.map.insert_at(new_id, object) {
513 self.post_display_error(
514 DisplayError::InvalidObject,
515 CString::new(format!("Invalid new_id: {}.", new_id)).unwrap(),
516 );
517 return None;
518 }
519 Some((
520 self.id.clone(),
521 global_id,
522 InnerObjectId {
523 id: new_id,
524 client_id: self.id.clone(),
525 interface,
526 serial,
527 },
528 handler.clone(),
529 ))
530 } else {
531 self.post_display_error(
532 DisplayError::InvalidObject,
533 CString::new(format!(
534 "Invalid binding of {} version {} for global {}.",
535 interface_name.to_string_lossy(),
536 version,
537 name
538 ))
539 .unwrap(),
540 );
541 None
542 }
543 } else {
544 unreachable!()
545 }
546 }
547 _ => {
548 self.post_display_error(
550 DisplayError::InvalidMethod,
551 CString::new(format!(
552 "Unknown opcode {} for interface wl_registry.",
553 message.opcode
554 ))
555 .unwrap(),
556 );
557 None
558 }
559 }
560 }
561
562 pub(crate) fn process_request(
563 &mut self,
564 object: &Object<Data<D>>,
565 message: Message<u32, OwnedFd>,
566 ) -> Option<(ArgSmallVec<OwnedFd>, bool, Option<InnerObjectId>)> {
567 let message_desc = object.interface.requests.get(message.opcode as usize).unwrap();
568 let mut new_args = SmallVec::with_capacity(message.args.len());
570 let mut arg_interfaces = message_desc.arg_interfaces.iter();
571 let mut created_id = None;
572 for (i, arg) in message.args.into_iter().enumerate() {
573 new_args.push(match arg {
574 Argument::Array(a) => Argument::Array(a),
575 Argument::Int(i) => Argument::Int(i),
576 Argument::Uint(u) => Argument::Uint(u),
577 Argument::Str(s) => Argument::Str(s),
578 Argument::Fixed(f) => Argument::Fixed(f),
579 Argument::Fd(f) => Argument::Fd(f),
580 Argument::Object(o) => {
581 let next_interface = arg_interfaces.next();
582 if o != 0 {
583 let obj = match self.map.find(o) {
585 Some(o) => o,
586 None => {
587 self.post_display_error(
588 DisplayError::InvalidObject,
589 CString::new(format!("Unknown id: {}.", o)).unwrap()
590 );
591 return None;
592 }
593 };
594 if let Some(next_interface) = next_interface {
595 if !same_interface_or_anonymous(next_interface, obj.interface) {
596 self.post_display_error(
597 DisplayError::InvalidObject,
598 CString::new(format!(
599 "Invalid object {} in request {}.{}: expected {} but got {}.",
600 o,
601 object.interface.name,
602 message_desc.name,
603 next_interface.name,
604 obj.interface.name,
605 )).unwrap()
606 );
607 return None;
608 }
609 }
610 Argument::Object(ObjectId { id: InnerObjectId { id: o, client_id: self.id.clone(), serial: obj.data.serial, interface: obj.interface }})
611 } else if matches!(message_desc.signature[i], ArgumentType::Object(AllowNull::Yes)) {
612 Argument::Object(ObjectId { id: InnerObjectId { id: 0, client_id: self.id.clone(), serial: 0, interface: &ANONYMOUS_INTERFACE }})
613 } else {
614 self.post_display_error(
615 DisplayError::InvalidObject,
616 CString::new(format!(
617 "Invalid null object in request {}.{}.",
618 object.interface.name,
619 message_desc.name,
620 )).unwrap()
621 );
622 return None;
623 }
624 }
625 Argument::NewId(new_id) => {
626 let child_interface = match message_desc.child_interface {
628 Some(iface) => iface,
629 None => panic!("Received request {}@{}.{} which creates an object without specifying its interface, this is unsupported.", object.interface.name, message.sender_id, message_desc.name),
630 };
631
632 let child_udata = Arc::new(UninitObjectData);
633
634 let child_obj = Object {
635 interface: child_interface,
636 version: object.version,
637 data: Data {
638 user_data: child_udata,
639 serial: self.next_serial(),
640 }
641 };
642
643 let child_id = InnerObjectId { id: new_id, client_id: self.id.clone(), serial: child_obj.data.serial, interface: child_obj.interface };
644 created_id = Some(child_id.clone());
645
646 if let Err(()) = self.map.insert_at(new_id, child_obj) {
647 self.post_display_error(
649 DisplayError::InvalidObject,
650 CString::new(format!("Invalid new_id: {}.", new_id)).unwrap()
651 );
652 return None;
653 }
654
655 Argument::NewId(ObjectId { id: child_id })
656 }
657 });
658 }
659 Some((new_args, message_desc.is_destructor, created_id))
660 }
661}
662
663impl<D> AsFd for Client<D> {
664 fn as_fd(&self) -> BorrowedFd<'_> {
665 self.socket.as_fd()
666 }
667}
668
/// Slot-based storage of all clients.
///
/// Freed slots (`None`) are reused by `create_client`. A client id pairs the
/// slot index with a store-wide serial, so a handle to a previous occupant
/// of a reused slot is rejected.
#[derive(Debug)]
pub(crate) struct ClientStore<D: 'static> {
    /// One entry per slot index; `None` marks a free slot.
    clients: Vec<Option<Client<D>>>,
    /// Last serial assigned to a created client.
    last_serial: u32,
    /// Debug flag passed to every newly created `Client`.
    debug: bool,
}
675
676impl<D> ClientStore<D> {
677 pub(crate) fn new(debug: bool) -> Self {
678 Self { clients: Vec::new(), last_serial: 0, debug }
679 }
680
681 pub(crate) fn create_client(
682 &mut self,
683 stream: UnixStream,
684 data: Arc<dyn ClientData>,
685 ) -> InnerClientId {
686 let serial = self.next_serial();
687 let (id, place) = match self.clients.iter_mut().enumerate().find(|(_, c)| c.is_none()) {
689 Some((id, place)) => (id, place),
690 None => {
691 self.clients.push(None);
692 (self.clients.len() - 1, self.clients.last_mut().unwrap())
693 }
694 };
695
696 let id = InnerClientId { id: id as u32, serial };
697
698 *place = Some(Client::new(stream, id.clone(), self.debug, data));
699
700 id
701 }
702
703 pub(crate) fn get_client(&self, id: InnerClientId) -> Result<&Client<D>, InvalidId> {
704 match self.clients.get(id.id as usize) {
705 Some(Some(client)) if client.id == id => Ok(client),
706 _ => Err(InvalidId),
707 }
708 }
709
710 pub(crate) fn get_client_mut(
711 &mut self,
712 id: InnerClientId,
713 ) -> Result<&mut Client<D>, InvalidId> {
714 match self.clients.get_mut(id.id as usize) {
715 Some(&mut Some(ref mut client)) if client.id == id => Ok(client),
716 _ => Err(InvalidId),
717 }
718 }
719
720 pub(crate) fn cleanup(
721 &mut self,
722 pending_destructors: &mut Vec<PendingDestructor<D>>,
723 ) -> SmallVec<[Client<D>; 1]> {
724 let mut cleaned = SmallVec::new();
725 for place in &mut self.clients {
726 if place.as_ref().map(|client| client.killed).unwrap_or(false) {
727 let mut client = place.take().unwrap();
729 client.queue_all_destructors(pending_destructors);
730 let _ = client.flush();
731 cleaned.push(client);
732 }
733 }
734 cleaned
735 }
736
737 fn next_serial(&mut self) -> u32 {
738 self.last_serial = self.last_serial.wrapping_add(1);
739 self.last_serial
740 }
741
742 pub(crate) fn clients_mut(&mut self) -> impl Iterator<Item = &mut Client<D>> {
743 self.clients.iter_mut().flat_map(|o| o.as_mut()).filter(|c| !c.killed)
744 }
745
746 pub(crate) fn all_clients_id(&self) -> impl Iterator<Item = ClientId> + '_ {
747 self.clients.iter().flat_map(|opt| {
748 opt.as_ref().filter(|c| !c.killed).map(|client| ClientId { id: client.id.clone() })
749 })
750 }
751}