wayland_backend/rs/server_impl/handle.rs

1use std::{
2    ffi::CString,
3    os::unix::{
4        io::{OwnedFd, RawFd},
5        net::UnixStream,
6    },
7    sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11    protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12    types::server::{DisconnectReason, GlobalInfo, InvalidId},
13};
14
15use super::{
16    client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
17    InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
18};
19
20pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
21
22#[derive(Debug)]
23pub struct State<D: 'static> {
24    pub(crate) clients: ClientStore<D>,
25    pub(crate) registry: Registry<D>,
26    pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
27    pub(crate) poll_fd: OwnedFd,
28}
29
30impl<D> State<D> {
31    pub(crate) fn new(poll_fd: OwnedFd) -> Self {
32        let debug =
33            matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
34        Self {
35            clients: ClientStore::new(debug),
36            registry: Registry::new(),
37            pending_destructors: Vec::new(),
38            poll_fd,
39        }
40    }
41
42    pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
43        let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
44        self.registry.cleanup(&dead_clients, &self.pending_destructors);
45        // return a closure that will do the cleanup once invoked
46        let pending_destructors = std::mem::take(&mut self.pending_destructors);
47        move |handle, data| {
48            for (object_data, client_id, object_id) in pending_destructors {
49                object_data.clone().destroyed(
50                    handle,
51                    data,
52                    ClientId { id: client_id },
53                    ObjectId { id: object_id },
54                );
55            }
56            std::mem::drop(dead_clients);
57        }
58    }
59
60    pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
61        if let Some(ClientId { id: client }) = client {
62            match self.clients.get_client_mut(client) {
63                Ok(client) => client.flush(),
64                Err(InvalidId) => Ok(()),
65            }
66        } else {
67            for client in self.clients.clients_mut() {
68                let _ = client.flush();
69            }
70            Ok(())
71        }
72    }
73}
74
75#[derive(Clone)]
76pub struct InnerHandle {
77    pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
78}
79
80impl std::fmt::Debug for InnerHandle {
81    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
83    }
84}
85
86#[derive(Clone)]
87pub struct WeakInnerHandle {
88    pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
89}
90
91impl std::fmt::Debug for WeakInnerHandle {
92    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
94    }
95}
96
97impl WeakInnerHandle {
98    pub fn upgrade(&self) -> Option<InnerHandle> {
99        self.state.upgrade().map(|state| InnerHandle { state })
100    }
101}
102
103impl InnerHandle {
104    pub fn downgrade(&self) -> WeakInnerHandle {
105        WeakInnerHandle { state: Arc::downgrade(&self.state) }
106    }
107
108    pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
109        self.state.lock().unwrap().object_info(id)
110    }
111
112    pub fn insert_client(
113        &self,
114        stream: UnixStream,
115        data: Arc<dyn ClientData>,
116    ) -> std::io::Result<InnerClientId> {
117        self.state.lock().unwrap().insert_client(stream, data)
118    }
119
120    pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
121        self.state.lock().unwrap().get_client(id)
122    }
123
124    pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
125        self.state.lock().unwrap().get_client_data(id)
126    }
127
128    pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
129        self.state.lock().unwrap().get_client_credentials(id)
130    }
131
132    pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
133        self.state.lock().unwrap().with_all_clients(&mut f)
134    }
135
136    pub fn with_all_objects_for(
137        &self,
138        client_id: InnerClientId,
139        mut f: impl FnMut(ObjectId),
140    ) -> Result<(), InvalidId> {
141        self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
142    }
143
144    pub fn object_for_protocol_id(
145        &self,
146        client_id: InnerClientId,
147        interface: &'static Interface,
148        protocol_id: u32,
149    ) -> Result<ObjectId, InvalidId> {
150        self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
151    }
152
153    pub fn create_object<D: 'static>(
154        &self,
155        client_id: InnerClientId,
156        interface: &'static Interface,
157        version: u32,
158        data: Arc<dyn ObjectData<D>>,
159    ) -> Result<ObjectId, InvalidId> {
160        let mut state = self.state.lock().unwrap();
161        let state = (&mut *state as &mut dyn ErasedState)
162            .downcast_mut::<State<D>>()
163            .expect("Wrong type parameter passed to Handle::create_object().");
164        let client = state.clients.get_client_mut(client_id)?;
165        Ok(ObjectId { id: client.create_object(interface, version, data) })
166    }
167
168    pub fn destroy_object<D: 'static>(&self, id: &ObjectId) -> Result<(), InvalidId> {
169        let mut state = self.state.lock().unwrap();
170        let state = (&mut *state as &mut dyn ErasedState)
171            .downcast_mut::<State<D>>()
172            .expect("Wrong type parameter passed to Handle::destroy_object().");
173        let client = state.clients.get_client_mut(id.id.client_id.clone())?;
174        client.destroy_object(id.id.clone(), &mut state.pending_destructors)
175    }
176
177    pub fn null_id() -> ObjectId {
178        ObjectId {
179            id: InnerObjectId {
180                id: 0,
181                serial: 0,
182                client_id: InnerClientId { id: 0, serial: 0 },
183                interface: &ANONYMOUS_INTERFACE,
184            },
185        }
186    }
187
188    pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
189        self.state.lock().unwrap().send_event(msg)
190    }
191
192    pub fn get_object_data<D: 'static>(
193        &self,
194        id: InnerObjectId,
195    ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
196        let mut state = self.state.lock().unwrap();
197        let state = (&mut *state as &mut dyn ErasedState)
198            .downcast_mut::<State<D>>()
199            .expect("Wrong type parameter passed to Handle::get_object_data().");
200        state.clients.get_client(id.client_id.clone())?.get_object_data(id)
201    }
202
203    pub fn get_object_data_any(
204        &self,
205        id: InnerObjectId,
206    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
207        self.state.lock().unwrap().get_object_data_any(id)
208    }
209
210    pub fn set_object_data<D: 'static>(
211        &self,
212        id: InnerObjectId,
213        data: Arc<dyn ObjectData<D>>,
214    ) -> Result<(), InvalidId> {
215        let mut state = self.state.lock().unwrap();
216        let state = (&mut *state as &mut dyn ErasedState)
217            .downcast_mut::<State<D>>()
218            .expect("Wrong type parameter passed to Handle::set_object_data().");
219        state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
220    }
221
222    pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
223        self.state.lock().unwrap().post_error(object_id, error_code, message)
224    }
225
226    pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
227        self.state.lock().unwrap().kill_client(client_id, reason)
228    }
229
230    pub fn create_global<D: 'static>(
231        &self,
232        interface: &'static Interface,
233        version: u32,
234        handler: Arc<dyn GlobalHandler<D>>,
235    ) -> InnerGlobalId {
236        let mut state = self.state.lock().unwrap();
237        let state = (&mut *state as &mut dyn ErasedState)
238            .downcast_mut::<State<D>>()
239            .expect("Wrong type parameter passed to Handle::create_global().");
240        state.registry.create_global(interface, version, handler, &mut state.clients)
241    }
242
243    pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
244        let mut state = self.state.lock().unwrap();
245        let state = (&mut *state as &mut dyn ErasedState)
246            .downcast_mut::<State<D>>()
247            .expect("Wrong type parameter passed to Handle::create_global().");
248
249        state.registry.disable_global(id, &mut state.clients)
250    }
251
252    pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
253        let mut state = self.state.lock().unwrap();
254        let state = (&mut *state as &mut dyn ErasedState)
255            .downcast_mut::<State<D>>()
256            .expect("Wrong type parameter passed to Handle::create_global().");
257
258        state.registry.remove_global(id, &mut state.clients)
259    }
260
261    pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
262        self.state.lock().unwrap().global_info(id)
263    }
264
265    pub fn get_global_handler<D: 'static>(
266        &self,
267        id: InnerGlobalId,
268    ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
269        let mut state = self.state.lock().unwrap();
270        let state = (&mut *state as &mut dyn ErasedState)
271            .downcast_mut::<State<D>>()
272            .expect("Wrong type parameter passed to Handle::get_global_handler().");
273        state.registry.get_handler(id)
274    }
275
276    pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
277        self.state.lock().unwrap().flush(client)
278    }
279}
280
281pub(crate) trait ErasedState: downcast_rs::Downcast {
282    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
283    fn insert_client(
284        &mut self,
285        stream: UnixStream,
286        data: Arc<dyn ClientData>,
287    ) -> std::io::Result<InnerClientId>;
288    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
289    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
290    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
291    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
292    fn with_all_objects_for(
293        &self,
294        client_id: InnerClientId,
295        f: &mut dyn FnMut(ObjectId),
296    ) -> Result<(), InvalidId>;
297    fn object_for_protocol_id(
298        &self,
299        client_id: InnerClientId,
300        interface: &'static Interface,
301        protocol_id: u32,
302    ) -> Result<ObjectId, InvalidId>;
303    fn get_object_data_any(
304        &self,
305        id: InnerObjectId,
306    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
307    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
308    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
309    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
310    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
311    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
312}
313
314downcast_rs::impl_downcast!(ErasedState);
315
316impl<D> ErasedState for State<D> {
317    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
318        self.clients.get_client(id.client_id.clone())?.object_info(id)
319    }
320
321    fn insert_client(
322        &mut self,
323        stream: UnixStream,
324        data: Arc<dyn ClientData>,
325    ) -> std::io::Result<InnerClientId> {
326        let id = self.clients.create_client(stream, data);
327        let client = self.clients.get_client(id.clone()).unwrap();
328
329        // register the client to the internal epoll
330        #[cfg(any(target_os = "linux", target_os = "android"))]
331        let ret = {
332            use rustix::event::epoll;
333            epoll::add(
334                &self.poll_fd,
335                client,
336                epoll::EventData::new_u64(id.as_u64()),
337                epoll::EventFlags::IN,
338            )
339        };
340
341        #[cfg(any(
342            target_os = "dragonfly",
343            target_os = "freebsd",
344            target_os = "netbsd",
345            target_os = "openbsd",
346            target_os = "macos"
347        ))]
348        let ret = {
349            use rustix::event::kqueue::*;
350            use std::os::unix::io::{AsFd, AsRawFd};
351
352            let evt = Event::new(
353                EventFilter::Read(client.as_fd().as_raw_fd()),
354                EventFlags::ADD | EventFlags::RECEIPT,
355                id.as_u64() as *mut _,
356            );
357
358            let events: &mut [Event] = &mut [];
359            unsafe { kevent(&self.poll_fd, &[evt], events, None).map(|_| ()) }
360        };
361
362        match ret {
363            Ok(()) => Ok(id),
364            Err(e) => {
365                self.kill_client(id, DisconnectReason::ConnectionClosed);
366                Err(e.into())
367            }
368        }
369    }
370
371    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
372        if self.clients.get_client(id.client_id.clone()).is_ok() {
373            Ok(ClientId { id: id.client_id })
374        } else {
375            Err(InvalidId)
376        }
377    }
378
379    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
380        let client = self.clients.get_client(id)?;
381        Ok(client.data.clone())
382    }
383
384    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
385        let client = self.clients.get_client(id)?;
386        Ok(client.get_credentials())
387    }
388
389    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
390        for client in self.clients.all_clients_id() {
391            f(client)
392        }
393    }
394
395    fn with_all_objects_for(
396        &self,
397        client_id: InnerClientId,
398        f: &mut dyn FnMut(ObjectId),
399    ) -> Result<(), InvalidId> {
400        let client = self.clients.get_client(client_id)?;
401        for object in client.all_objects() {
402            f(object)
403        }
404        Ok(())
405    }
406
407    fn object_for_protocol_id(
408        &self,
409        client_id: InnerClientId,
410        interface: &'static Interface,
411        protocol_id: u32,
412    ) -> Result<ObjectId, InvalidId> {
413        let client = self.clients.get_client(client_id)?;
414        let object = client.object_for_protocol_id(protocol_id)?;
415        if same_interface(interface, object.interface) {
416            Ok(ObjectId { id: object })
417        } else {
418            Err(InvalidId)
419        }
420    }
421
422    fn get_object_data_any(
423        &self,
424        id: InnerObjectId,
425    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
426        self.clients
427            .get_client(id.client_id.clone())?
428            .get_object_data(id)
429            .map(|arc| arc.into_any_arc())
430    }
431
432    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
433        self.clients
434            .get_client_mut(msg.sender_id.id.client_id.clone())?
435            .send_event(msg, Some(&mut self.pending_destructors))
436    }
437
438    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
439        if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
440            client.post_error(object_id, error_code, message)
441        }
442    }
443
444    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
445        if let Ok(client) = self.clients.get_client_mut(client_id) {
446            client.kill(reason)
447        }
448    }
449    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
450        self.registry.get_info(id)
451    }
452
453    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
454        self.flush(client)
455    }
456}