aboutsummaryrefslogtreecommitdiff
path: root/core/src/async_util/condwait.rs
blob: b96d979aa78c83e87e77ca17eb27a93933d1d52d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
use super::CondVar;
use crate::async_runtime::lock::Mutex;

/// CondWait is a wrapper struct for [`CondVar`] with a [`Mutex`]-guarded
/// boolean flag.
///
/// [`CondWait::wait`] parks the caller until the flag has been set by
/// [`CondWait::signal`] or [`CondWait::broadcast`]. Waking up does not clear
/// the flag: once it is set, later calls to `wait` return without parking on
/// the condition variable until [`CondWait::reset`] sets it back to `false`.
///
/// # Example
///
///```
/// use std::sync::Arc;
///
/// use karyon_core::async_util::CondWait;
/// use karyon_core::async_runtime::spawn;
///
///  async {
///     let cond_wait = Arc::new(CondWait::new());
///     let task = spawn({
///         let cond_wait = cond_wait.clone();
///         async move {
///             cond_wait.wait().await;
///             // ...
///         }
///     });
///
///     cond_wait.signal().await;
///  };
///
/// ```
///
pub struct CondWait {
    /// Condition variable that tasks in `wait` park on while the flag is unset.
    condvar: CondVar,
    /// The flag checked by `wait`: `true` lets waiters proceed. It is set by
    /// `signal`/`broadcast` and cleared only by `reset`.
    w: Mutex<bool>,
}

impl CondWait {
    /// Builds a `CondWait` whose flag starts cleared (`false`).
    pub fn new() -> Self {
        let condvar = CondVar::new();
        let w = Mutex::new(false);
        Self { condvar, w }
    }

    /// Parks the calling task until the flag is set.
    ///
    /// If the flag is already `true` this returns as soon as the lock is
    /// acquired. Otherwise it waits on the condition variable and re-checks
    /// the flag after every wakeup. The flag itself is never modified here.
    pub async fn wait(&self) {
        let mut guard = self.w.lock().await;
        loop {
            if *guard {
                break;
            }
            guard = self.condvar.wait(guard).await;
        }
    }

    /// Sets the flag, then signals a single waiting task.
    pub async fn signal(&self) {
        self.store_flag(true).await;
        self.condvar.signal();
    }

    /// Sets the flag, then signals every waiting task.
    pub async fn broadcast(&self) {
        self.store_flag(true).await;
        self.condvar.broadcast();
    }

    /// Clears the flag so that subsequent `wait` calls park again.
    pub async fn reset(&self) {
        self.store_flag(false).await;
    }

    /// Writes `value` into the flag. The mutex guard is dropped before this
    /// returns, so callers notify the condvar without holding the lock.
    async fn store_flag(&self, value: bool) {
        let mut flag = self.w.lock().await;
        *flag = value;
    }
}

impl Default for CondWait {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use crate::async_runtime::{block_on, spawn};

    use super::*;

    #[test]
    fn test_cond_wait() {
        block_on(async {
            let cw = Arc::new(CondWait::new());
            let hits = Arc::new(AtomicUsize::new(0));

            // Spawns a task that parks on `cw` and bumps `hits` once released.
            let spawn_waiter = || {
                let cw = Arc::clone(&cw);
                let hits = Arc::clone(&hits);
                spawn(async move {
                    cw.wait().await;
                    hits.fetch_add(1, Ordering::Relaxed);
                    // do something
                })
            };

            // Round one: a single signal releases the single waiter.
            let first = spawn_waiter();
            cw.signal().await;
            let _ = first.await;

            // Clear the flag before the next round.
            cw.reset().await;
            assert_eq!(hits.load(Ordering::Relaxed), 1);

            // Round two: a broadcast releases every waiter.
            let second = spawn_waiter();
            let third = spawn_waiter();
            cw.broadcast().await;

            let _ = second.await;
            let _ = third.await;
            assert_eq!(hits.load(Ordering::Relaxed), 3);
        });
    }
}