wayland_backend/rs/server_impl/
handle.rs

1use std::{
2    ffi::CString,
3    os::unix::{
4        io::{OwnedFd, RawFd},
5        net::UnixStream,
6    },
7    sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11    protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12    types::server::{DisconnectReason, GlobalInfo, InvalidId},
13};
14
15use super::{
16    client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
17    InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
18};
19
/// A destructor queued for later execution.
///
/// Holds the object's data together with the id of its owning client and the id of the
/// object itself; `State::cleanup` later passes exactly these to `ObjectData::destroyed`.
pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
21
/// Concrete, `D`-typed server state.
///
/// `InnerHandle` stores it type-erased (as `dyn ErasedState + Send` behind a mutex) and
/// downcasts back to `State<D>` for the operations that need the `D` parameter.
// NOTE: fields are dropped in declaration order, so `poll_fd` is closed after the clients
// and the registry have been dropped.
#[derive(Debug)]
pub struct State<D: 'static> {
    /// Connected clients; objects are created and looked up through the client entries.
    pub(crate) clients: ClientStore<D>,
    /// Registry of globals (create / disable / remove, handler lookup).
    pub(crate) registry: Registry<D>,
    /// Destructors queued by `send_event` and client cleanup, executed later by the closure
    /// returned from `State::cleanup`.
    pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
    /// Poller fd client sockets are registered with in `insert_client` (epoll on
    /// Linux/Android, kqueue on the BSDs and macOS).
    pub(crate) poll_fd: OwnedFd,
}
29
30impl<D> State<D> {
31    pub(crate) fn new(poll_fd: OwnedFd) -> Self {
32        let debug =
33            matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
34        Self {
35            clients: ClientStore::new(debug),
36            registry: Registry::new(),
37            pending_destructors: Vec::new(),
38            poll_fd,
39        }
40    }
41
42    pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
43        let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
44        self.registry.cleanup(&dead_clients);
45        // return a closure that will do the cleanup once invoked
46        let pending_destructors = std::mem::take(&mut self.pending_destructors);
47        move |handle, data| {
48            for (object_data, client_id, object_id) in pending_destructors {
49                object_data.clone().destroyed(
50                    handle,
51                    data,
52                    ClientId { id: client_id },
53                    ObjectId { id: object_id },
54                );
55            }
56            std::mem::drop(dead_clients);
57        }
58    }
59
60    pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
61        if let Some(ClientId { id: client }) = client {
62            match self.clients.get_client_mut(client) {
63                Ok(client) => client.flush(),
64                Err(InvalidId) => Ok(()),
65            }
66        } else {
67            for client in self.clients.clients_mut() {
68                let _ = client.flush();
69            }
70            Ok(())
71        }
72    }
73}
74
/// Handle to the type-erased server state.
///
/// Cloning the handle shares the same state (through the `Arc`).
#[derive(Clone)]
pub struct InnerHandle {
    /// The server `State<D>`, erased behind `ErasedState` so the handle is not generic over
    /// `D`; every access goes through the mutex.
    pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
}
79
80impl std::fmt::Debug for InnerHandle {
81    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
83    }
84}
85
/// Non-owning counterpart of `InnerHandle`; does not keep the server state alive.
#[derive(Clone)]
pub struct WeakInnerHandle {
    /// Weak reference to the same type-erased state an `InnerHandle` holds strongly.
    pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
}
90
91impl std::fmt::Debug for WeakInnerHandle {
92    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
94    }
95}
96
97impl WeakInnerHandle {
98    pub fn upgrade(&self) -> Option<InnerHandle> {
99        self.state.upgrade().map(|state| InnerHandle { state })
100    }
101}
102
103impl InnerHandle {
104    pub fn downgrade(&self) -> WeakInnerHandle {
105        WeakInnerHandle { state: Arc::downgrade(&self.state) }
106    }
107
108    pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
109        self.state.lock().unwrap().object_info(id)
110    }
111
112    pub fn insert_client(
113        &self,
114        stream: UnixStream,
115        data: Arc<dyn ClientData>,
116    ) -> std::io::Result<InnerClientId> {
117        self.state.lock().unwrap().insert_client(stream, data)
118    }
119
120    pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
121        self.state.lock().unwrap().get_client(id)
122    }
123
124    pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
125        self.state.lock().unwrap().get_client_data(id)
126    }
127
128    pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
129        self.state.lock().unwrap().get_client_credentials(id)
130    }
131
132    pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
133        self.state.lock().unwrap().with_all_clients(&mut f)
134    }
135
136    pub fn with_all_objects_for(
137        &self,
138        client_id: InnerClientId,
139        mut f: impl FnMut(ObjectId),
140    ) -> Result<(), InvalidId> {
141        self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
142    }
143
144    pub fn object_for_protocol_id(
145        &self,
146        client_id: InnerClientId,
147        interface: &'static Interface,
148        protocol_id: u32,
149    ) -> Result<ObjectId, InvalidId> {
150        self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
151    }
152
153    pub fn create_object<D: 'static>(
154        &self,
155        client_id: InnerClientId,
156        interface: &'static Interface,
157        version: u32,
158        data: Arc<dyn ObjectData<D>>,
159    ) -> Result<ObjectId, InvalidId> {
160        let mut state = self.state.lock().unwrap();
161        let state = (&mut *state as &mut dyn ErasedState)
162            .downcast_mut::<State<D>>()
163            .expect("Wrong type parameter passed to Handle::create_object().");
164        let client = state.clients.get_client_mut(client_id)?;
165        Ok(ObjectId { id: client.create_object(interface, version, data) })
166    }
167
168    pub fn null_id() -> ObjectId {
169        ObjectId {
170            id: InnerObjectId {
171                id: 0,
172                serial: 0,
173                client_id: InnerClientId { id: 0, serial: 0 },
174                interface: &ANONYMOUS_INTERFACE,
175            },
176        }
177    }
178
179    pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
180        self.state.lock().unwrap().send_event(msg)
181    }
182
183    pub fn get_object_data<D: 'static>(
184        &self,
185        id: InnerObjectId,
186    ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
187        let mut state = self.state.lock().unwrap();
188        let state = (&mut *state as &mut dyn ErasedState)
189            .downcast_mut::<State<D>>()
190            .expect("Wrong type parameter passed to Handle::get_object_data().");
191        state.clients.get_client(id.client_id.clone())?.get_object_data(id)
192    }
193
194    pub fn get_object_data_any(
195        &self,
196        id: InnerObjectId,
197    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
198        self.state.lock().unwrap().get_object_data_any(id)
199    }
200
201    pub fn set_object_data<D: 'static>(
202        &self,
203        id: InnerObjectId,
204        data: Arc<dyn ObjectData<D>>,
205    ) -> Result<(), InvalidId> {
206        let mut state = self.state.lock().unwrap();
207        let state = (&mut *state as &mut dyn ErasedState)
208            .downcast_mut::<State<D>>()
209            .expect("Wrong type parameter passed to Handle::set_object_data().");
210        state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
211    }
212
213    pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
214        self.state.lock().unwrap().post_error(object_id, error_code, message)
215    }
216
217    pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
218        self.state.lock().unwrap().kill_client(client_id, reason)
219    }
220
221    pub fn create_global<D: 'static>(
222        &self,
223        interface: &'static Interface,
224        version: u32,
225        handler: Arc<dyn GlobalHandler<D>>,
226    ) -> InnerGlobalId {
227        let mut state = self.state.lock().unwrap();
228        let state = (&mut *state as &mut dyn ErasedState)
229            .downcast_mut::<State<D>>()
230            .expect("Wrong type parameter passed to Handle::create_global().");
231        state.registry.create_global(interface, version, handler, &mut state.clients)
232    }
233
234    pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
235        let mut state = self.state.lock().unwrap();
236        let state = (&mut *state as &mut dyn ErasedState)
237            .downcast_mut::<State<D>>()
238            .expect("Wrong type parameter passed to Handle::create_global().");
239
240        state.registry.disable_global(id, &mut state.clients)
241    }
242
243    pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
244        let mut state = self.state.lock().unwrap();
245        let state = (&mut *state as &mut dyn ErasedState)
246            .downcast_mut::<State<D>>()
247            .expect("Wrong type parameter passed to Handle::create_global().");
248
249        state.registry.remove_global(id, &mut state.clients)
250    }
251
252    pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
253        self.state.lock().unwrap().global_info(id)
254    }
255
256    pub fn get_global_handler<D: 'static>(
257        &self,
258        id: InnerGlobalId,
259    ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
260        let mut state = self.state.lock().unwrap();
261        let state = (&mut *state as &mut dyn ErasedState)
262            .downcast_mut::<State<D>>()
263            .expect("Wrong type parameter passed to Handle::get_global_handler().");
264        state.registry.get_handler(id)
265    }
266
267    pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
268        self.state.lock().unwrap().flush(client)
269    }
270}
271
272pub(crate) trait ErasedState: downcast_rs::Downcast {
273    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
274    fn insert_client(
275        &mut self,
276        stream: UnixStream,
277        data: Arc<dyn ClientData>,
278    ) -> std::io::Result<InnerClientId>;
279    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
280    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
281    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
282    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
283    fn with_all_objects_for(
284        &self,
285        client_id: InnerClientId,
286        f: &mut dyn FnMut(ObjectId),
287    ) -> Result<(), InvalidId>;
288    fn object_for_protocol_id(
289        &self,
290        client_id: InnerClientId,
291        interface: &'static Interface,
292        protocol_id: u32,
293    ) -> Result<ObjectId, InvalidId>;
294    fn get_object_data_any(
295        &self,
296        id: InnerObjectId,
297    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
298    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
299    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
300    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
301    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
302    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
303}
304
305downcast_rs::impl_downcast!(ErasedState);
306
307impl<D> ErasedState for State<D> {
308    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
309        self.clients.get_client(id.client_id.clone())?.object_info(id)
310    }
311
312    fn insert_client(
313        &mut self,
314        stream: UnixStream,
315        data: Arc<dyn ClientData>,
316    ) -> std::io::Result<InnerClientId> {
317        let id = self.clients.create_client(stream, data);
318        let client = self.clients.get_client(id.clone()).unwrap();
319
320        // register the client to the internal epoll
321        #[cfg(any(target_os = "linux", target_os = "android"))]
322        let ret = {
323            use rustix::event::epoll;
324            epoll::add(
325                &self.poll_fd,
326                client,
327                epoll::EventData::new_u64(id.as_u64()),
328                epoll::EventFlags::IN,
329            )
330        };
331
332        #[cfg(any(
333            target_os = "dragonfly",
334            target_os = "freebsd",
335            target_os = "netbsd",
336            target_os = "openbsd",
337            target_os = "macos"
338        ))]
339        let ret = {
340            use rustix::event::kqueue::*;
341            use std::os::unix::io::{AsFd, AsRawFd};
342
343            let evt = Event::new(
344                EventFilter::Read(client.as_fd().as_raw_fd()),
345                EventFlags::ADD | EventFlags::RECEIPT,
346                id.as_u64() as isize,
347            );
348
349            let mut events = Vec::new();
350            unsafe { kevent(&self.poll_fd, &[evt], &mut events, None).map(|_| ()) }
351        };
352
353        match ret {
354            Ok(()) => Ok(id),
355            Err(e) => {
356                self.kill_client(id, DisconnectReason::ConnectionClosed);
357                Err(e.into())
358            }
359        }
360    }
361
362    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
363        if self.clients.get_client(id.client_id.clone()).is_ok() {
364            Ok(ClientId { id: id.client_id })
365        } else {
366            Err(InvalidId)
367        }
368    }
369
370    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
371        let client = self.clients.get_client(id)?;
372        Ok(client.data.clone())
373    }
374
375    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
376        let client = self.clients.get_client(id)?;
377        Ok(client.get_credentials())
378    }
379
380    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
381        for client in self.clients.all_clients_id() {
382            f(client)
383        }
384    }
385
386    fn with_all_objects_for(
387        &self,
388        client_id: InnerClientId,
389        f: &mut dyn FnMut(ObjectId),
390    ) -> Result<(), InvalidId> {
391        let client = self.clients.get_client(client_id)?;
392        for object in client.all_objects() {
393            f(object)
394        }
395        Ok(())
396    }
397
398    fn object_for_protocol_id(
399        &self,
400        client_id: InnerClientId,
401        interface: &'static Interface,
402        protocol_id: u32,
403    ) -> Result<ObjectId, InvalidId> {
404        let client = self.clients.get_client(client_id)?;
405        let object = client.object_for_protocol_id(protocol_id)?;
406        if same_interface(interface, object.interface) {
407            Ok(ObjectId { id: object })
408        } else {
409            Err(InvalidId)
410        }
411    }
412
413    fn get_object_data_any(
414        &self,
415        id: InnerObjectId,
416    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
417        self.clients
418            .get_client(id.client_id.clone())?
419            .get_object_data(id)
420            .map(|arc| arc.into_any_arc())
421    }
422
423    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
424        self.clients
425            .get_client_mut(msg.sender_id.id.client_id.clone())?
426            .send_event(msg, Some(&mut self.pending_destructors))
427    }
428
429    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
430        if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
431            client.post_error(object_id, error_code, message)
432        }
433    }
434
435    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
436        if let Ok(client) = self.clients.get_client_mut(client_id) {
437            client.kill(reason)
438        }
439    }
440    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
441        self.registry.get_info(id)
442    }
443
444    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
445        self.flush(client)
446    }
447}