// wayland_backend/rs/server_impl/common_poll.rs

1use std::{
2    os::unix::io::{AsRawFd, BorrowedFd, OwnedFd},
3    sync::{Arc, Mutex},
4};
5
6use super::{
7    handle::State, ClientId, Data, GlobalHandler, GlobalId, Handle, InnerClientId, InnerGlobalId,
8    InnerHandle, InnerObjectId, ObjectId,
9};
10use crate::{
11    core_interfaces::{WL_DISPLAY_INTERFACE, WL_REGISTRY_INTERFACE},
12    protocol::{same_interface, Argument, Message},
13    rs::map::Object,
14    types::server::InitError,
15};
16
17#[cfg(any(target_os = "linux", target_os = "android"))]
18use rustix::event::epoll;
19
20#[cfg(any(
21    target_os = "dragonfly",
22    target_os = "freebsd",
23    target_os = "netbsd",
24    target_os = "openbsd",
25    target_os = "macos"
26))]
27use rustix::event::kqueue::*;
28use smallvec::SmallVec;
29
30#[derive(Debug)]
31pub struct InnerBackend<D: 'static> {
32    state: Arc<Mutex<State<D>>>,
33}
34
35impl<D> InnerBackend<D> {
36    pub fn new() -> Result<Self, InitError> {
37        #[cfg(any(target_os = "linux", target_os = "android"))]
38        let poll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)
39            .map_err(Into::into)
40            .map_err(InitError::Io)?;
41
42        #[cfg(any(
43            target_os = "dragonfly",
44            target_os = "freebsd",
45            target_os = "netbsd",
46            target_os = "openbsd",
47            target_os = "macos"
48        ))]
49        let poll_fd = kqueue().map_err(Into::into).map_err(InitError::Io)?;
50
51        Ok(Self { state: Arc::new(Mutex::new(State::new(poll_fd))) })
52    }
53
54    pub fn flush(&self, client: Option<ClientId>) -> std::io::Result<()> {
55        self.state.lock().unwrap().flush(client)
56    }
57
58    pub fn handle(&self) -> Handle {
59        Handle { handle: InnerHandle { state: self.state.clone() as Arc<_> } }
60    }
61
62    pub fn poll_fd(&self) -> BorrowedFd {
63        let raw_fd = self.state.lock().unwrap().poll_fd.as_raw_fd();
64        // This allows the lifetime of the BorrowedFd to be tied to &self rather than the lock guard,
65        // which is the real safety concern
66        unsafe { BorrowedFd::borrow_raw(raw_fd) }
67    }
68
69    pub fn dispatch_client(
70        &self,
71        data: &mut D,
72        client_id: InnerClientId,
73    ) -> std::io::Result<usize> {
74        let ret = self.dispatch_events_for(data, client_id);
75        let cleanup = self.state.lock().unwrap().cleanup();
76        cleanup(&self.handle(), data);
77        ret
78    }
79
80    #[cfg(any(target_os = "linux", target_os = "android"))]
81    pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
82        use std::os::unix::io::AsFd;
83
84        let poll_fd = self.poll_fd();
85        let mut dispatched = 0;
86        loop {
87            let mut events = epoll::EventVec::with_capacity(32);
88            epoll::wait(poll_fd.as_fd(), &mut events, 0)?;
89
90            if events.is_empty() {
91                break;
92            }
93
94            for event in events.iter() {
95                let id = InnerClientId::from_u64(event.data.u64());
96                // remove the cb while we call it, to gracefully handle reentrancy
97                if let Ok(count) = self.dispatch_events_for(data, id) {
98                    dispatched += count;
99                }
100            }
101            let cleanup = self.state.lock().unwrap().cleanup();
102            cleanup(&self.handle(), data);
103        }
104
105        Ok(dispatched)
106    }
107
108    #[cfg(any(
109        target_os = "dragonfly",
110        target_os = "freebsd",
111        target_os = "netbsd",
112        target_os = "openbsd",
113        target_os = "macos"
114    ))]
115    pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
116        use std::time::Duration;
117
118        let poll_fd = self.poll_fd();
119        let mut dispatched = 0;
120        loop {
121            let mut events = Vec::with_capacity(32);
122            let nevents = unsafe { kevent(&poll_fd, &[], &mut events, Some(Duration::ZERO))? };
123
124            if nevents == 0 {
125                break;
126            }
127
128            for event in events.iter().take(nevents) {
129                let id = InnerClientId::from_u64(event.udata() as u64);
130                // remove the cb while we call it, to gracefully handle reentrancy
131                if let Ok(count) = self.dispatch_events_for(data, id) {
132                    dispatched += count;
133                }
134            }
135            let cleanup = self.state.lock().unwrap().cleanup();
136            cleanup(&self.handle(), data);
137        }
138
139        Ok(dispatched)
140    }
141
142    pub(crate) fn dispatch_events_for(
143        &self,
144        data: &mut D,
145        client_id: InnerClientId,
146    ) -> std::io::Result<usize> {
147        let mut dispatched = 0;
148        let handle = self.handle();
149        let mut state = self.state.lock().unwrap();
150        loop {
151            let action = {
152                let state = &mut *state;
153                if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
154                    let (message, object) = match client.next_request() {
155                        Ok(v) => v,
156                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
157                            if dispatched > 0 {
158                                break;
159                            } else {
160                                return Err(e);
161                            }
162                        }
163                        Err(e) => {
164                            #[cfg(any(target_os = "linux", target_os = "android"))]
165                            {
166                                epoll::delete(&state.poll_fd, client)?;
167                            }
168
169                            #[cfg(any(
170                                target_os = "dragonfly",
171                                target_os = "freebsd",
172                                target_os = "netbsd",
173                                target_os = "openbsd",
174                                target_os = "macos"
175                            ))]
176                            {
177                                use rustix::event::kqueue::*;
178                                use std::os::unix::io::{AsFd, AsRawFd};
179
180                                let evt = Event::new(
181                                    EventFilter::Read(client.as_fd().as_raw_fd()),
182                                    EventFlags::DELETE,
183                                    client_id.as_u64() as isize,
184                                );
185
186                                let mut events = Vec::new();
187                                unsafe {
188                                    kevent(&state.poll_fd, &[evt], &mut events, None)
189                                        .map(|_| ())?;
190                                }
191                            }
192                            return Err(e);
193                        }
194                    };
195                    dispatched += 1;
196                    if same_interface(object.interface, &WL_DISPLAY_INTERFACE) {
197                        client.handle_display_request(message, &mut state.registry);
198                        continue;
199                    } else if same_interface(object.interface, &WL_REGISTRY_INTERFACE) {
200                        if let Some((client, global, object, handler)) =
201                            client.handle_registry_request(message, &mut state.registry)
202                        {
203                            DispatchAction::Bind { client, global, object, handler }
204                        } else {
205                            continue;
206                        }
207                    } else {
208                        let object_id = InnerObjectId {
209                            id: message.sender_id,
210                            serial: object.data.serial,
211                            interface: object.interface,
212                            client_id: client.id.clone(),
213                        };
214                        let opcode = message.opcode;
215                        let (arguments, is_destructor, created_id) =
216                            match client.process_request(&object, message) {
217                                Some(args) => args,
218                                None => continue,
219                            };
220                        // Return the whole set to invoke the callback while handle is not borrower via client
221                        DispatchAction::Request {
222                            object,
223                            object_id,
224                            opcode,
225                            arguments,
226                            is_destructor,
227                            created_id,
228                        }
229                    }
230                } else {
231                    return Err(std::io::Error::new(
232                        std::io::ErrorKind::InvalidInput,
233                        "Invalid client ID",
234                    ));
235                }
236            };
237            match action {
238                DispatchAction::Request {
239                    object,
240                    object_id,
241                    opcode,
242                    arguments,
243                    is_destructor,
244                    created_id,
245                } => {
246                    // temporarily unlock the state Mutex while this request is dispatched
247                    std::mem::drop(state);
248                    let ret = object.data.user_data.clone().request(
249                        &handle.clone(),
250                        data,
251                        ClientId { id: client_id.clone() },
252                        Message {
253                            sender_id: ObjectId { id: object_id.clone() },
254                            opcode,
255                            args: arguments,
256                        },
257                    );
258                    if is_destructor {
259                        object.data.user_data.clone().destroyed(
260                            &handle.clone(),
261                            data,
262                            ClientId { id: client_id.clone() },
263                            ObjectId { id: object_id.clone() },
264                        );
265                    }
266                    // acquire the lock again and continue
267                    state = self.state.lock().unwrap();
268                    if is_destructor {
269                        if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
270                            client.send_delete_id(object_id);
271                        }
272                    }
273                    match (created_id, ret) {
274                        (Some(child_id), Some(child_data)) => {
275                            if let Ok(client) = state.clients.get_client_mut(client_id.clone()) {
276                                client
277                                    .map
278                                    .with(child_id.id, |obj| obj.data.user_data = child_data)
279                                    .unwrap();
280                            }
281                        }
282                        (None, None) => {}
283                        (Some(child_id), None) => {
284                            // Allow the callback to not return any data if the client is already dead (typically
285                            // if the callback provoked a protocol error)
286                            if let Ok(client) = state.clients.get_client(client_id.clone()) {
287                                if !client.killed {
288                                    panic!(
289                                        "Callback creating object {} did not provide any object data.",
290                                        child_id
291                                    );
292                                }
293                            }
294                        }
295                        (None, Some(_)) => {
296                            panic!("An object data was returned from a callback not creating any object");
297                        }
298                    }
299                    // dropping the object calls destructors from which users could call into wayland-backend again.
300                    // so lets release and relock the state again, to avoid a deadlock
301                    std::mem::drop(state);
302                    std::mem::drop(object);
303                    state = self.state.lock().unwrap();
304                }
305                DispatchAction::Bind { object, client, global, handler } => {
306                    // temporarily unlock the state Mutex while this request is dispatched
307                    std::mem::drop(state);
308                    let child_data = handler.bind(
309                        &handle.clone(),
310                        data,
311                        ClientId { id: client.clone() },
312                        GlobalId { id: global },
313                        ObjectId { id: object.clone() },
314                    );
315                    // acquire the lock again and continue
316                    state = self.state.lock().unwrap();
317                    if let Ok(client) = state.clients.get_client_mut(client.clone()) {
318                        client.map.with(object.id, |obj| obj.data.user_data = child_data).unwrap();
319                    }
320                }
321            }
322        }
323        Ok(dispatched)
324    }
325}
326
327enum DispatchAction<D: 'static> {
328    Request {
329        object: Object<Data<D>>,
330        object_id: InnerObjectId,
331        opcode: u16,
332        arguments: SmallVec<[Argument<ObjectId, OwnedFd>; 4]>,
333        is_destructor: bool,
334        created_id: Option<InnerObjectId>,
335    },
336    Bind {
337        object: InnerObjectId,
338        client: InnerClientId,
339        global: InnerGlobalId,
340        handler: Arc<dyn GlobalHandler<D>>,
341    },
342}