1use std::cmp;
12use std::fmt;
13use std::ops;
14use std::sync::{mpsc, Arc};
15
16use crate::{EventSource, Poll, PostAction, Readiness, Token, TokenFactory};
17
18use super::ping::{make_ping, Ping, PingError, PingSource};
19
/// Upper bound on how many messages a single `process_events` call drains from
/// the channel before handing control back to the event loop.
const MAX_EVENTS_CHECK: usize = 1_024;
21
/// An event delivered by a [`Channel`] event source.
#[derive(Debug)]
pub enum Event<T> {
    /// A message that was sent through one of the channel's senders.
    Msg(T),
    /// The channel is disconnected (every sender has been dropped); no further
    /// messages will be delivered.
    Closed,
}
33
34#[derive(Debug)]
35struct PingOnDrop(Ping);
36
37impl ops::Deref for PingOnDrop {
38 type Target = Ping;
39
40 fn deref(&self) -> &Ping {
41 &self.0
42 }
43}
44
45impl Drop for PingOnDrop {
46 fn drop(&mut self) {
47 self.0.ping();
48 }
49}
50
51#[derive(Debug)]
55pub struct Sender<T> {
56 sender: mpsc::Sender<T>,
57 ping: PingOnDrop,
60}
61
62impl<T> Clone for Sender<T> {
63 #[cfg_attr(feature = "nightly_coverage", coverage(off))]
64 fn clone(&self) -> Sender<T> {
65 Sender {
66 sender: self.sender.clone(),
67 ping: PingOnDrop(self.ping.clone()),
68 }
69 }
70}
71
72impl<T> Sender<T> {
73 pub fn send(&self, t: T) -> Result<(), mpsc::SendError<T>> {
78 self.sender.send(t).map(|()| self.ping.ping())
79 }
80}
81
82#[derive(Debug)]
86pub struct SyncSender<T> {
87 sender: mpsc::SyncSender<T>,
88 ping: Arc<PingOnDrop>,
91}
92
93impl<T> Clone for SyncSender<T> {
94 #[cfg_attr(feature = "nightly_coverage", coverage(off))]
95 fn clone(&self) -> SyncSender<T> {
96 SyncSender {
97 sender: self.sender.clone(),
98 ping: self.ping.clone(),
99 }
100 }
101}
102
103impl<T> SyncSender<T> {
104 pub fn send(&self, t: T) -> Result<(), mpsc::SendError<T>> {
114 let ret = self.try_send(t);
115 match ret {
116 Ok(()) => Ok(()),
117 Err(mpsc::TrySendError::Full(t)) => self.sender.send(t).map(|()| self.ping.ping()),
118 Err(mpsc::TrySendError::Disconnected(t)) => Err(mpsc::SendError(t)),
119 }
120 }
121
122 pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError<T>> {
129 let ret = self.sender.try_send(t);
130 if let Ok(()) | Err(mpsc::TrySendError::Full(_)) = ret {
131 self.ping.ping();
132 }
133 ret
134 }
135}
136
137#[derive(Debug)]
141pub struct Channel<T> {
142 receiver: mpsc::Receiver<T>,
143 source: PingSource,
144 ping: Ping,
145 capacity: usize,
146}
147
148unsafe impl<T: Send> Send for Channel<T> {}
154
155impl<T> Channel<T> {
156 pub fn recv(&self) -> Result<T, mpsc::RecvError> {
162 self.receiver.recv()
163 }
164
165 pub fn try_recv(&self) -> Result<T, mpsc::TryRecvError> {
171 self.receiver.try_recv()
172 }
173}
174
175pub fn channel<T>() -> (Sender<T>, Channel<T>) {
177 let (sender, receiver) = mpsc::channel();
178 let (ping, source) = make_ping().expect("Failed to create a Ping.");
179 (
180 Sender {
181 sender,
182 ping: PingOnDrop(ping.clone()),
183 },
184 Channel {
185 receiver,
186 ping,
187 source,
188 capacity: usize::MAX,
189 },
190 )
191}
192
193pub fn sync_channel<T>(bound: usize) -> (SyncSender<T>, Channel<T>) {
195 let (sender, receiver) = mpsc::sync_channel(bound);
196 let (ping, source) = make_ping().expect("Failed to create a Ping.");
197 (
198 SyncSender {
199 sender,
200 ping: Arc::new(PingOnDrop(ping.clone())),
201 },
202 Channel {
203 receiver,
204 source,
205 ping,
206 capacity: bound,
207 },
208 )
209}
210
211impl<T> EventSource for Channel<T> {
212 type Event = Event<T>;
213 type Metadata = ();
214 type Ret = ();
215 type Error = ChannelError;
216
217 fn process_events<C>(
218 &mut self,
219 readiness: Readiness,
220 token: Token,
221 mut callback: C,
222 ) -> Result<PostAction, Self::Error>
223 where
224 C: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
225 {
226 let receiver = &self.receiver;
227 let capacity = self.capacity;
228 let mut clear_readiness = false;
229 let mut disconnected = false;
230
231 let action = self
232 .source
233 .process_events(readiness, token, |(), &mut ()| {
234 let max = cmp::min(capacity.saturating_add(1), MAX_EVENTS_CHECK);
236 for _ in 0..max {
237 match receiver.try_recv() {
238 Ok(val) => callback(Event::Msg(val), &mut ()),
239 Err(mpsc::TryRecvError::Empty) => {
240 clear_readiness = true;
241 break;
242 }
243 Err(mpsc::TryRecvError::Disconnected) => {
244 callback(Event::Closed, &mut ());
245 disconnected = true;
246 break;
247 }
248 }
249 }
250 })
251 .map_err(ChannelError)?;
252
253 if disconnected {
254 Ok(PostAction::Remove)
255 } else if clear_readiness {
256 Ok(action)
257 } else {
258 self.ping.ping();
260 Ok(PostAction::Continue)
261 }
262 }
263
264 fn register(&mut self, poll: &mut Poll, token_factory: &mut TokenFactory) -> crate::Result<()> {
265 self.source.register(poll, token_factory)
266 }
267
268 fn reregister(
269 &mut self,
270 poll: &mut Poll,
271 token_factory: &mut TokenFactory,
272 ) -> crate::Result<()> {
273 self.source.reregister(poll, token_factory)
274 }
275
276 fn unregister(&mut self, poll: &mut Poll) -> crate::Result<()> {
277 self.source.unregister(poll)
278 }
279}
280
281#[derive(Debug)]
283pub struct ChannelError(PingError);
284
285impl fmt::Display for ChannelError {
286 #[cfg_attr(feature = "nightly_coverage", coverage(off))]
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 fmt::Display::fmt(&self.0, f)
289 }
290}
291
292impl std::error::Error for ChannelError {
293 #[cfg_attr(feature = "nightly_coverage", coverage(off))]
294 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
295 Some(&self.0)
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn basic_channel() {
        let mut event_loop = crate::EventLoop::try_new().unwrap();
        let handle = event_loop.handle();

        let (tx, rx) = channel::<()>();

        // (message seen, closed seen)
        let mut state = (false, false);

        let _channel_token = handle
            .insert_source(rx, move |evt, &mut (), state: &mut (bool, bool)| match evt {
                Event::Msg(()) => state.0 = true,
                Event::Closed => state.1 = true,
            })
            .unwrap();

        // Nothing has been sent yet.
        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (false, false));

        tx.send(()).unwrap();
        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (true, false));

        // Dropping the only sender disconnects the channel.
        drop(tx);
        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (true, true));
    }

    #[test]
    fn basic_sync_channel() {
        let mut event_loop = crate::EventLoop::try_new().unwrap();
        let handle = event_loop.handle();

        let (tx, rx) = sync_channel::<()>(2);

        // (messages received, closed seen)
        let mut state = (0, false);

        let _channel_token = handle
            .insert_source(rx, move |evt, &mut (), state: &mut (u32, bool)| match evt {
                Event::Msg(()) => state.0 += 1,
                Event::Closed => state.1 = true,
            })
            .unwrap();

        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (0, false));

        // Fill the two-slot buffer; a third non-blocking send must fail.
        tx.send(()).unwrap();
        tx.send(()).unwrap();
        assert!(tx.try_send(()).is_err());

        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (2, false));

        // One last message, then disconnect: both must be observed in a single pass.
        tx.send(()).unwrap();
        drop(tx);

        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (3, true));
    }

    #[test]
    fn test_more_than_1024() {
        let mut event_loop = crate::EventLoop::try_new().unwrap();
        let handle = event_loop.handle();

        let (tx, rx) = channel::<()>();
        let mut state = (0u32, false);

        handle
            .insert_source(rx, move |evt, &mut (), state: &mut (u32, bool)| match evt {
                Event::Msg(()) => state.0 += 1,
                Event::Closed => state.1 = true,
            })
            .unwrap();

        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (0, false));

        // Queue one message more than a single pass is allowed to drain.
        (0..=MAX_EVENTS_CHECK).for_each(|_| tx.send(()).unwrap());

        // The first dispatch stops at the per-pass limit...
        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (MAX_EVENTS_CHECK as u32, false));

        // ...and the self-ping lets the next dispatch pick up the remainder.
        event_loop
            .dispatch(Some(Duration::ZERO), &mut state)
            .unwrap();
        assert_eq!(state, (MAX_EVENTS_CHECK as u32 + 1, false));
    }
}