1use std::{
2 ffi::CString,
3 os::unix::{
4 io::{OwnedFd, RawFd},
5 net::UnixStream,
6 },
7 sync::{Arc, Mutex, Weak},
8};
9
10use crate::{
11 protocol::{same_interface, Interface, Message, ObjectInfo, ANONYMOUS_INTERFACE},
12 types::server::{DisconnectReason, GlobalInfo, InvalidId},
13};
14
15use super::{
16 client::ClientStore, registry::Registry, ClientData, ClientId, Credentials, GlobalHandler,
17 InnerClientId, InnerGlobalId, InnerObjectId, ObjectData, ObjectId,
18};
19
20pub(crate) type PendingDestructor<D> = (Arc<dyn ObjectData<D>>, InnerClientId, InnerObjectId);
21
22#[derive(Debug)]
23pub struct State<D: 'static> {
24 pub(crate) clients: ClientStore<D>,
25 pub(crate) registry: Registry<D>,
26 pub(crate) pending_destructors: Vec<PendingDestructor<D>>,
27 pub(crate) poll_fd: OwnedFd,
28}
29
30impl<D> State<D> {
31 pub(crate) fn new(poll_fd: OwnedFd) -> Self {
32 let debug =
33 matches!(std::env::var_os("WAYLAND_DEBUG"), Some(str) if str == "1" || str == "server");
34 Self {
35 clients: ClientStore::new(debug),
36 registry: Registry::new(),
37 pending_destructors: Vec::new(),
38 poll_fd,
39 }
40 }
41
42 pub(crate) fn cleanup<'a>(&mut self) -> impl FnOnce(&super::Handle, &mut D) + 'a {
43 let dead_clients = self.clients.cleanup(&mut self.pending_destructors);
44 self.registry.cleanup(&dead_clients, &self.pending_destructors);
45 let pending_destructors = std::mem::take(&mut self.pending_destructors);
47 move |handle, data| {
48 for (object_data, client_id, object_id) in pending_destructors {
49 object_data.clone().destroyed(
50 handle,
51 data,
52 ClientId { id: client_id },
53 ObjectId { id: object_id },
54 );
55 }
56 std::mem::drop(dead_clients);
57 }
58 }
59
60 pub(crate) fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
61 if let Some(ClientId { id: client }) = client {
62 match self.clients.get_client_mut(client) {
63 Ok(client) => client.flush(),
64 Err(InvalidId) => Ok(()),
65 }
66 } else {
67 for client in self.clients.clients_mut() {
68 let _ = client.flush();
69 }
70 Ok(())
71 }
72 }
73}
74
75#[derive(Clone)]
76pub struct InnerHandle {
77 pub(crate) state: Arc<Mutex<dyn ErasedState + Send>>,
78}
79
80impl std::fmt::Debug for InnerHandle {
81 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 fmt.debug_struct("InnerHandle[rs]").finish_non_exhaustive()
83 }
84}
85
86#[derive(Clone)]
87pub struct WeakInnerHandle {
88 pub(crate) state: Weak<Mutex<dyn ErasedState + Send>>,
89}
90
91impl std::fmt::Debug for WeakInnerHandle {
92 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 fmt.debug_struct("WeakInnerHandle[rs]").finish_non_exhaustive()
94 }
95}
96
97impl WeakInnerHandle {
98 pub fn upgrade(&self) -> Option<InnerHandle> {
99 self.state.upgrade().map(|state| InnerHandle { state })
100 }
101}
102
103impl InnerHandle {
104 pub fn downgrade(&self) -> WeakInnerHandle {
105 WeakInnerHandle { state: Arc::downgrade(&self.state) }
106 }
107
108 pub fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
109 self.state.lock().unwrap().object_info(id)
110 }
111
112 pub fn insert_client(
113 &self,
114 stream: UnixStream,
115 data: Arc<dyn ClientData>,
116 ) -> std::io::Result<InnerClientId> {
117 self.state.lock().unwrap().insert_client(stream, data)
118 }
119
120 pub fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
121 self.state.lock().unwrap().get_client(id)
122 }
123
124 pub fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
125 self.state.lock().unwrap().get_client_data(id)
126 }
127
128 pub fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
129 self.state.lock().unwrap().get_client_credentials(id)
130 }
131
132 pub fn with_all_clients(&self, mut f: impl FnMut(ClientId)) {
133 self.state.lock().unwrap().with_all_clients(&mut f)
134 }
135
136 pub fn with_all_objects_for(
137 &self,
138 client_id: InnerClientId,
139 mut f: impl FnMut(ObjectId),
140 ) -> Result<(), InvalidId> {
141 self.state.lock().unwrap().with_all_objects_for(client_id, &mut f)
142 }
143
144 pub fn object_for_protocol_id(
145 &self,
146 client_id: InnerClientId,
147 interface: &'static Interface,
148 protocol_id: u32,
149 ) -> Result<ObjectId, InvalidId> {
150 self.state.lock().unwrap().object_for_protocol_id(client_id, interface, protocol_id)
151 }
152
153 pub fn create_object<D: 'static>(
154 &self,
155 client_id: InnerClientId,
156 interface: &'static Interface,
157 version: u32,
158 data: Arc<dyn ObjectData<D>>,
159 ) -> Result<ObjectId, InvalidId> {
160 let mut state = self.state.lock().unwrap();
161 let state = (&mut *state as &mut dyn ErasedState)
162 .downcast_mut::<State<D>>()
163 .expect("Wrong type parameter passed to Handle::create_object().");
164 let client = state.clients.get_client_mut(client_id)?;
165 Ok(ObjectId { id: client.create_object(interface, version, data) })
166 }
167
168 pub fn destroy_object<D: 'static>(&self, id: &ObjectId) -> Result<(), InvalidId> {
169 let mut state = self.state.lock().unwrap();
170 let state = (&mut *state as &mut dyn ErasedState)
171 .downcast_mut::<State<D>>()
172 .expect("Wrong type parameter passed to Handle::destroy_object().");
173 let client = state.clients.get_client_mut(id.id.client_id.clone())?;
174 client.destroy_object(id.id.clone(), &mut state.pending_destructors)
175 }
176
177 pub fn null_id() -> ObjectId {
178 ObjectId {
179 id: InnerObjectId {
180 id: 0,
181 serial: 0,
182 client_id: InnerClientId { id: 0, serial: 0 },
183 interface: &ANONYMOUS_INTERFACE,
184 },
185 }
186 }
187
188 pub fn send_event(&self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
189 self.state.lock().unwrap().send_event(msg)
190 }
191
192 pub fn get_object_data<D: 'static>(
193 &self,
194 id: InnerObjectId,
195 ) -> Result<Arc<dyn ObjectData<D>>, InvalidId> {
196 let mut state = self.state.lock().unwrap();
197 let state = (&mut *state as &mut dyn ErasedState)
198 .downcast_mut::<State<D>>()
199 .expect("Wrong type parameter passed to Handle::get_object_data().");
200 state.clients.get_client(id.client_id.clone())?.get_object_data(id)
201 }
202
203 pub fn get_object_data_any(
204 &self,
205 id: InnerObjectId,
206 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
207 self.state.lock().unwrap().get_object_data_any(id)
208 }
209
210 pub fn set_object_data<D: 'static>(
211 &self,
212 id: InnerObjectId,
213 data: Arc<dyn ObjectData<D>>,
214 ) -> Result<(), InvalidId> {
215 let mut state = self.state.lock().unwrap();
216 let state = (&mut *state as &mut dyn ErasedState)
217 .downcast_mut::<State<D>>()
218 .expect("Wrong type parameter passed to Handle::set_object_data().");
219 state.clients.get_client_mut(id.client_id.clone())?.set_object_data(id, data)
220 }
221
222 pub fn post_error(&self, object_id: InnerObjectId, error_code: u32, message: CString) {
223 self.state.lock().unwrap().post_error(object_id, error_code, message)
224 }
225
226 pub fn kill_client(&self, client_id: InnerClientId, reason: DisconnectReason) {
227 self.state.lock().unwrap().kill_client(client_id, reason)
228 }
229
230 pub fn create_global<D: 'static>(
231 &self,
232 interface: &'static Interface,
233 version: u32,
234 handler: Arc<dyn GlobalHandler<D>>,
235 ) -> InnerGlobalId {
236 let mut state = self.state.lock().unwrap();
237 let state = (&mut *state as &mut dyn ErasedState)
238 .downcast_mut::<State<D>>()
239 .expect("Wrong type parameter passed to Handle::create_global().");
240 state.registry.create_global(interface, version, handler, &mut state.clients)
241 }
242
243 pub fn disable_global<D: 'static>(&self, id: InnerGlobalId) {
244 let mut state = self.state.lock().unwrap();
245 let state = (&mut *state as &mut dyn ErasedState)
246 .downcast_mut::<State<D>>()
247 .expect("Wrong type parameter passed to Handle::create_global().");
248
249 state.registry.disable_global(id, &mut state.clients)
250 }
251
252 pub fn remove_global<D: 'static>(&self, id: InnerGlobalId) {
253 let mut state = self.state.lock().unwrap();
254 let state = (&mut *state as &mut dyn ErasedState)
255 .downcast_mut::<State<D>>()
256 .expect("Wrong type parameter passed to Handle::create_global().");
257
258 state.registry.remove_global(id, &mut state.clients)
259 }
260
261 pub fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
262 self.state.lock().unwrap().global_info(id)
263 }
264
265 pub fn get_global_handler<D: 'static>(
266 &self,
267 id: InnerGlobalId,
268 ) -> Result<Arc<dyn GlobalHandler<D>>, InvalidId> {
269 let mut state = self.state.lock().unwrap();
270 let state = (&mut *state as &mut dyn ErasedState)
271 .downcast_mut::<State<D>>()
272 .expect("Wrong type parameter passed to Handle::get_global_handler().");
273 state.registry.get_handler(id)
274 }
275
276 pub fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
277 self.state.lock().unwrap().flush(client)
278 }
279}
280
281pub(crate) trait ErasedState: downcast_rs::Downcast {
282 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId>;
283 fn insert_client(
284 &mut self,
285 stream: UnixStream,
286 data: Arc<dyn ClientData>,
287 ) -> std::io::Result<InnerClientId>;
288 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId>;
289 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId>;
290 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId>;
291 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId));
292 fn with_all_objects_for(
293 &self,
294 client_id: InnerClientId,
295 f: &mut dyn FnMut(ObjectId),
296 ) -> Result<(), InvalidId>;
297 fn object_for_protocol_id(
298 &self,
299 client_id: InnerClientId,
300 interface: &'static Interface,
301 protocol_id: u32,
302 ) -> Result<ObjectId, InvalidId>;
303 fn get_object_data_any(
304 &self,
305 id: InnerObjectId,
306 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId>;
307 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId>;
308 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString);
309 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason);
310 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId>;
311 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()>;
312}
313
314downcast_rs::impl_downcast!(ErasedState);
315
316impl<D> ErasedState for State<D> {
317 fn object_info(&self, id: InnerObjectId) -> Result<ObjectInfo, InvalidId> {
318 self.clients.get_client(id.client_id.clone())?.object_info(id)
319 }
320
321 fn insert_client(
322 &mut self,
323 stream: UnixStream,
324 data: Arc<dyn ClientData>,
325 ) -> std::io::Result<InnerClientId> {
326 let id = self.clients.create_client(stream, data);
327 let client = self.clients.get_client(id.clone()).unwrap();
328
329 #[cfg(any(target_os = "linux", target_os = "android"))]
331 let ret = {
332 use rustix::event::epoll;
333 epoll::add(
334 &self.poll_fd,
335 client,
336 epoll::EventData::new_u64(id.as_u64()),
337 epoll::EventFlags::IN,
338 )
339 };
340
341 #[cfg(any(
342 target_os = "dragonfly",
343 target_os = "freebsd",
344 target_os = "netbsd",
345 target_os = "openbsd",
346 target_os = "macos"
347 ))]
348 let ret = {
349 use rustix::event::kqueue::*;
350 use std::os::unix::io::{AsFd, AsRawFd};
351
352 let evt = Event::new(
353 EventFilter::Read(client.as_fd().as_raw_fd()),
354 EventFlags::ADD | EventFlags::RECEIPT,
355 id.as_u64() as *mut _,
356 );
357
358 let events: &mut [Event] = &mut [];
359 unsafe { kevent(&self.poll_fd, &[evt], events, None).map(|_| ()) }
360 };
361
362 match ret {
363 Ok(()) => Ok(id),
364 Err(e) => {
365 self.kill_client(id, DisconnectReason::ConnectionClosed);
366 Err(e.into())
367 }
368 }
369 }
370
371 fn get_client(&self, id: InnerObjectId) -> Result<ClientId, InvalidId> {
372 if self.clients.get_client(id.client_id.clone()).is_ok() {
373 Ok(ClientId { id: id.client_id })
374 } else {
375 Err(InvalidId)
376 }
377 }
378
379 fn get_client_data(&self, id: InnerClientId) -> Result<Arc<dyn ClientData>, InvalidId> {
380 let client = self.clients.get_client(id)?;
381 Ok(client.data.clone())
382 }
383
384 fn get_client_credentials(&self, id: InnerClientId) -> Result<Credentials, InvalidId> {
385 let client = self.clients.get_client(id)?;
386 Ok(client.get_credentials())
387 }
388
389 fn with_all_clients(&self, f: &mut dyn FnMut(ClientId)) {
390 for client in self.clients.all_clients_id() {
391 f(client)
392 }
393 }
394
395 fn with_all_objects_for(
396 &self,
397 client_id: InnerClientId,
398 f: &mut dyn FnMut(ObjectId),
399 ) -> Result<(), InvalidId> {
400 let client = self.clients.get_client(client_id)?;
401 for object in client.all_objects() {
402 f(object)
403 }
404 Ok(())
405 }
406
407 fn object_for_protocol_id(
408 &self,
409 client_id: InnerClientId,
410 interface: &'static Interface,
411 protocol_id: u32,
412 ) -> Result<ObjectId, InvalidId> {
413 let client = self.clients.get_client(client_id)?;
414 let object = client.object_for_protocol_id(protocol_id)?;
415 if same_interface(interface, object.interface) {
416 Ok(ObjectId { id: object })
417 } else {
418 Err(InvalidId)
419 }
420 }
421
422 fn get_object_data_any(
423 &self,
424 id: InnerObjectId,
425 ) -> Result<Arc<dyn std::any::Any + Send + Sync>, InvalidId> {
426 self.clients
427 .get_client(id.client_id.clone())?
428 .get_object_data(id)
429 .map(|arc| arc.into_any_arc())
430 }
431
432 fn send_event(&mut self, msg: Message<ObjectId, RawFd>) -> Result<(), InvalidId> {
433 self.clients
434 .get_client_mut(msg.sender_id.id.client_id.clone())?
435 .send_event(msg, Some(&mut self.pending_destructors))
436 }
437
438 fn post_error(&mut self, object_id: InnerObjectId, error_code: u32, message: CString) {
439 if let Ok(client) = self.clients.get_client_mut(object_id.client_id.clone()) {
440 client.post_error(object_id, error_code, message)
441 }
442 }
443
444 fn kill_client(&mut self, client_id: InnerClientId, reason: DisconnectReason) {
445 if let Ok(client) = self.clients.get_client_mut(client_id) {
446 client.kill(reason)
447 }
448 }
449 fn global_info(&self, id: InnerGlobalId) -> Result<GlobalInfo, InvalidId> {
450 self.registry.get_info(id)
451 }
452
453 fn flush(&mut self, client: Option<ClientId>) -> std::io::Result<()> {
454 self.flush(client)
455 }
456}