1use std::{
2 ffi::CString,
3 os::unix::{
4 io::{OwnedFd, RawFd},
5 net::UnixStream,
6 },
7 sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11 protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12 types::server::{DisconnectReason, GlobalInfo, InvalidId},
13};
14
15use super::{
16 client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
17 InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
18};
19
/// A deferred `ObjectData::destroyed` call: the object's user data, the id of the
/// client that owned the object, and the id of the destroyed object.
pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
21
/// Server-side state. `D` is the user data type handed as `&mut D` to object
/// destructors and to global handlers.
#[derive(Debug)]
pub struct State<D: 'static> {
    /// Per-client state (sockets, objects), looked up by `InnerClientId`.
    pub(crate) clients: ClientStore<D>,
    /// Global registry: creation, disabling, removal and lookup of globals.
    pub(crate) registry: Registry<D>,
    /// Destructors queued until the next `cleanup()`, which moves them into the
    /// closure it returns.
    pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
    /// Descriptor that client sockets are registered on: epoll on Linux/Android,
    /// kqueue on the BSDs and macOS.
    pub(crate) poll_fd: OwnedFd,
}
29
30impl<D> State<D> {
31 pub(crate) fn new(poll_fd: OwnedFd) -> Self {
32 let debug =
33 matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
34 Self {
35 clients: ClientStore::new(debug),
36 registry: Registry::new(),
37 pending_destructors: Vec::new(),
38 poll_fd,
39 }
40 }
41
42 pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
43 let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
44 self.registry.cleanup(&dead_clients);
45 let pending_destructors = std::mem::take(&mut self.pending_destructors);
47 move |handle, data| {
48 for (object_data, client_id, object_id) in pending_destructors {
49 object_data.clone().destroyed(
50 handle,
51 data,
52 ClientId { id: client_id },
53 ObjectId { id: object_id },
54 );
55 }
56 std::mem::drop(dead_clients);
57 }
58 }
59
60 pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
61 if let Some(ClientId { id: client }) = client {
62 match self.clients.get_client_mut(client) {
63 Ok(client) => client.flush(),
64 Err(InvalidId) => Ok(()),
65 }
66 } else {
67 for client in self.clients.clients_mut() {
68 let _ = client.flush();
69 }
70 Ok(())
71 }
72 }
73}
74
/// Cloneable handle to the shared, type-erased server state.
///
/// The state is stored as `dyn ErasedState` behind an `Arc<Mutex<..>>`, so the
/// handle is not generic over the user data type `D`. Clones share the same state.
#[derive(Clone)]
pub struct InnerHandle {
    pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
}
79
80impl std::fmt::Debug for InnerHandle {
81 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
83 }
84}
85
/// Non-owning counterpart of `InnerHandle`: holds a `Weak` reference to the shared
/// state and does not keep it alive. Use `upgrade` to obtain an `InnerHandle`.
#[derive(Clone)]
pub struct WeakInnerHandle {
    pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
}
90
91impl std::fmt::Debug for WeakInnerHandle {
92 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
94 }
95}
96
97impl WeakInnerHandle {
98 pub fn upgrade(&self) -> Option<InnerHandle> {
99 self.state.upgrade().map(|state| InnerHandle { state })
100 }
101}
102
103impl InnerHandle {
104 pub fn downgrade(&self) -> WeakInnerHandle {
105 WeakInnerHandle { state: Arc::downgrade(&self.state) }
106 }
107
108 pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
109 self.state.lock().unwrap().object_info(id)
110 }
111
112 pub fn insert_client(
113 &self,
114 stream: UnixStream,
115 data: Arc<dyn ClientData>,
116 ) -> std::io::Result<InnerClientId> {
117 self.state.lock().unwrap().insert_client(stream, data)
118 }
119
120 pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
121 self.state.lock().unwrap().get_client(id)
122 }
123
124 pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
125 self.state.lock().unwrap().get_client_data(id)
126 }
127
128 pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
129 self.state.lock().unwrap().get_client_credentials(id)
130 }
131
132 pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
133 self.state.lock().unwrap().with_all_clients(&mut f)
134 }
135
136 pub fn with_all_objects_for(
137 &self,
138 client_id: InnerClientId,
139 mut f: impl FnMut(ObjectId),
140 ) -> Result<(), InvalidId> {
141 self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
142 }
143
144 pub fn object_for_protocol_id(
145 &self,
146 client_id: InnerClientId,
147 interface: &'static Interface,
148 protocol_id: u32,
149 ) -> Result<ObjectId, InvalidId> {
150 self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
151 }
152
153 pub fn create_object<D: 'static>(
154 &self,
155 client_id: InnerClientId,
156 interface: &'static Interface,
157 version: u32,
158 data: Arc<dyn ObjectData<D>>,
159 ) -> Result<ObjectId, InvalidId> {
160 let mut state = self.state.lock().unwrap();
161 let state = (&mut *state as &mut dyn ErasedState)
162 .downcast_mut::<State<D>>()
163 .expect("Wrong type parameter passed to Handle::create_object().");
164 let client = state.clients.get_client_mut(client_id)?;
165 Ok(ObjectId { id: client.create_object(interface, version, data) })
166 }
167
168 pub fn null_id() -> ObjectId {
169 ObjectId {
170 id: InnerObjectId {
171 id: 0,
172 serial: 0,
173 client_id: InnerClientId { id: 0, serial: 0 },
174 interface: &ANONYMOUS_INTERFACE,
175 },
176 }
177 }
178
179 pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
180 self.state.lock().unwrap().send_event(msg)
181 }
182
183 pub fn get_object_data<D: 'static>(
184 &self,
185 id: InnerObjectId,
186 ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
187 let mut state = self.state.lock().unwrap();
188 let state = (&mut *state as &mut dyn ErasedState)
189 .downcast_mut::<State<D>>()
190 .expect("Wrong type parameter passed to Handle::get_object_data().");
191 state.clients.get_client(id.client_id.clone())?.get_object_data(id)
192 }
193
194 pub fn get_object_data_any(
195 &self,
196 id: InnerObjectId,
197 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
198 self.state.lock().unwrap().get_object_data_any(id)
199 }
200
201 pub fn set_object_data<D: 'static>(
202 &self,
203 id: InnerObjectId,
204 data: Arc<dyn ObjectData<D>>,
205 ) -> Result<(), InvalidId> {
206 let mut state = self.state.lock().unwrap();
207 let state = (&mut *state as &mut dyn ErasedState)
208 .downcast_mut::<State<D>>()
209 .expect("Wrong type parameter passed to Handle::set_object_data().");
210 state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
211 }
212
213 pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
214 self.state.lock().unwrap().post_error(object_id, error_code, message)
215 }
216
217 pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
218 self.state.lock().unwrap().kill_client(client_id, reason)
219 }
220
221 pub fn create_global<D: 'static>(
222 &self,
223 interface: &'static Interface,
224 version: u32,
225 handler: Arc<dyn GlobalHandler<D>>,
226 ) -> InnerGlobalId {
227 let mut state = self.state.lock().unwrap();
228 let state = (&mut *state as &mut dyn ErasedState)
229 .downcast_mut::<State<D>>()
230 .expect("Wrong type parameter passed to Handle::create_global().");
231 state.registry.create_global(interface, version, handler, &mut state.clients)
232 }
233
234 pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
235 let mut state = self.state.lock().unwrap();
236 let state = (&mut *state as &mut dyn ErasedState)
237 .downcast_mut::<State<D>>()
238 .expect("Wrong type parameter passed to Handle::create_global().");
239
240 state.registry.disable_global(id, &mut state.clients)
241 }
242
243 pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
244 let mut state = self.state.lock().unwrap();
245 let state = (&mut *state as &mut dyn ErasedState)
246 .downcast_mut::<State<D>>()
247 .expect("Wrong type parameter passed to Handle::create_global().");
248
249 state.registry.remove_global(id, &mut state.clients)
250 }
251
252 pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
253 self.state.lock().unwrap().global_info(id)
254 }
255
256 pub fn get_global_handler<D: 'static>(
257 &self,
258 id: InnerGlobalId,
259 ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
260 let mut state = self.state.lock().unwrap();
261 let state = (&mut *state as &mut dyn ErasedState)
262 .downcast_mut::<State<D>>()
263 .expect("Wrong type parameter passed to Handle::get_global_handler().");
264 state.registry.get_handler(id)
265 }
266
267 pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
268 self.state.lock().unwrap().flush(client)
269 }
270}
271
/// Object-safe, type-erased interface to a `State<D>`.
///
/// This trait lets `InnerHandle` hold the server state without being generic over
/// `D`. Operations whose signatures mention `D` are not part of this trait; callers
/// downcast to the concrete `State<D>` instead, which the `impl_downcast!`
/// invocation following this trait enables.
pub(crate) trait ErasedState: downcast_rs::Downcast {
    /// Returns information about object `id`, resolved through its owning client.
    fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
    /// Creates a client for `stream` with client data `data` and registers its
    /// socket on the poll descriptor.
    ///
    /// In `State<D>`'s implementation, a failed registration kills the new client
    /// and returns the error.
    fn insert_client(
        &mut self,
        stream: UnixStream,
        data: Arc<dyn ClientData>,
    ) -> std::io::Result<InnerClientId>;
    /// Returns the id of the client owning object `id`, provided that client exists.
    /// The object itself is not checked.
    fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
    /// Returns the client data of client `id`.
    fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
    /// Returns the credentials of client `id`.
    fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
    /// Calls `f` once per client id.
    fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
    /// Calls `f` once per object of client `client_id`.
    fn with_all_objects_for(
        &self,
        client_id: InnerClientId,
        f: &mut dyn FnMut(ObjectId),
    ) -> Result<(), InvalidId>;
    /// Resolves `protocol_id` within client `client_id`. Fails if the object's
    /// interface is not `interface`.
    fn object_for_protocol_id(
        &self,
        client_id: InnerClientId,
        interface: &'static Interface,
        protocol_id: u32,
    ) -> Result<ObjectId, InvalidId>;
    /// Returns the user data of object `id` as a type-erased `Any`.
    fn get_object_data_any(
        &self,
        id: InnerObjectId,
    ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
    /// Sends `msg` through the client owning its sender object.
    fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
    /// Posts a protocol error on `object_id`; ignored if its client is unknown.
    fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
    /// Disconnects client `client_id` with `reason`; ignored if the client is unknown.
    fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
    /// Returns information about global `id`.
    fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
    /// Flushes one client (`Some`) or all clients (`None`).
    fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
}

// Generates `downcast_ref`/`downcast_mut`/... on `dyn ErasedState`, which
// `InnerHandle`'s generic methods use to recover the concrete `State<D>`.
downcast_rs::impl_downcast!(ErasedState);
306
307impl<D> ErasedState for State<D> {
308 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
309 self.clients.get_client(id.client_id.clone())?.object_info(id)
310 }
311
312 fn insert_client(
313 &mut self,
314 stream: UnixStream,
315 data: Arc<dyn ClientData>,
316 ) -> std::io::Result<InnerClientId> {
317 let id = self.clients.create_client(stream, data);
318 let client = self.clients.get_client(id.clone()).unwrap();
319
320 #[cfg(any(target_os = "linux", target_os = "android"))]
322 let ret = {
323 use rustix::event::epoll;
324 epoll::add(
325 &self.poll_fd,
326 client,
327 epoll::EventData::new_u64(id.as_u64()),
328 epoll::EventFlags::IN,
329 )
330 };
331
332 #[cfg(any(
333 target_os = "dragonfly",
334 target_os = "freebsd",
335 target_os = "netbsd",
336 target_os = "openbsd",
337 target_os = "macos"
338 ))]
339 let ret = {
340 use rustix::event::kqueue::*;
341 use std::os::unix::io::{AsFd, AsRawFd};
342
343 let evt = Event::new(
344 EventFilter::Read(client.as_fd().as_raw_fd()),
345 EventFlags::ADD | EventFlags::RECEIPT,
346 id.as_u64() as isize,
347 );
348
349 let mut events = Vec::new();
350 unsafe { kevent(&self.poll_fd, &[evt], &mut events, None).map(|_| ()) }
351 };
352
353 match ret {
354 Ok(()) => Ok(id),
355 Err(e) => {
356 self.kill_client(id, DisconnectReason::ConnectionClosed);
357 Err(e.into())
358 }
359 }
360 }
361
362 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
363 if self.clients.get_client(id.client_id.clone()).is_ok() {
364 Ok(ClientId { id: id.client_id })
365 } else {
366 Err(InvalidId)
367 }
368 }
369
370 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
371 let client = self.clients.get_client(id)?;
372 Ok(client.data.clone())
373 }
374
375 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
376 let client = self.clients.get_client(id)?;
377 Ok(client.get_credentials())
378 }
379
380 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
381 for client in self.clients.all_clients_id() {
382 f(client)
383 }
384 }
385
386 fn with_all_objects_for(
387 &self,
388 client_id: InnerClientId,
389 f: &mut dyn FnMut(ObjectId),
390 ) -> Result<(), InvalidId> {
391 let client = self.clients.get_client(client_id)?;
392 for object in client.all_objects() {
393 f(object)
394 }
395 Ok(())
396 }
397
398 fn object_for_protocol_id(
399 &self,
400 client_id: InnerClientId,
401 interface: &'static Interface,
402 protocol_id: u32,
403 ) -> Result<ObjectId, InvalidId> {
404 let client = self.clients.get_client(client_id)?;
405 let object = client.object_for_protocol_id(protocol_id)?;
406 if same_interface(interface, object.interface) {
407 Ok(ObjectId { id: object })
408 } else {
409 Err(InvalidId)
410 }
411 }
412
413 fn get_object_data_any(
414 &self,
415 id: InnerObjectId,
416 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
417 self.clients
418 .get_client(id.client_id.clone())?
419 .get_object_data(id)
420 .map(|arc| arc.into_any_arc())
421 }
422
423 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
424 self.clients
425 .get_client_mut(msg.sender_id.id.client_id.clone())?
426 .send_event(msg, Some(&mut self.pending_destructors))
427 }
428
429 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
430 if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
431 client.post_error(object_id, error_code, message)
432 }
433 }
434
435 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
436 if let Ok(client) = self.clients.get_client_mut(client_id) {
437 client.kill(reason)
438 }
439 }
440 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
441 self.registry.get_info(id)
442 }
443
444 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
445 self.flush(client)
446 }
447}