Эх сурвалжийг харах

First version of VersionedParker.

Jing Yang 4 жил өмнө
parent
commit
42ef630d55
2 өөрчлөгдсөн 262 нэмэгдсэн , 0 устгасан
  1. 2 0
      src/lib.rs
  2. 260 0
      src/versioned_parker/mod.rs

+ 2 - 0
src/lib.rs

@@ -1,4 +1,6 @@
 //! A collection of synchronization utils for concurrent programming.
 mod carrier;
+mod versioned_parker;
 
 pub use carrier::{Carrier, CarrierRef};
+pub use versioned_parker::{VersionedGuard, VersionedParker};

+ 260 - 0
src/versioned_parker/mod.rs

@@ -0,0 +1,260 @@
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Condvar, Mutex, MutexGuard, WaitTimeoutResult};
+use std::time::Duration;
+
+///
+/// ```
+/// use more_sync::VersionedParker;
+///
+/// let versioned_parker = VersionedParker::new(0);
+/// let mut guard = versioned_parker.lock();
+///
+/// let parker_clone = versioned_parker.clone();
+/// std::thread::spawn(move || {
+///     parker_clone.notify_one_mutate(|i| *i = 16);
+///     // Version is 1, try_notify_all() should fail.
+///     assert!(!parker_clone.try_notify_all(0));
+/// });
+///
+/// guard.wait();
+/// assert_eq!(guard.notify_count(), 1);
+/// assert_eq!(*guard, 16);
+/// ```
+#[derive(Default, Clone)]
+pub struct VersionedParker<T> {
+    inner: Arc<Inner<T>>,
+}
+
+#[derive(Default)]
+struct Inner<T> {
+    version: AtomicUsize,
+    data: Mutex<T>,
+    condvar: Condvar,
+}
+
+impl<T> Inner<T> {
+    fn version(&self) -> usize {
+        self.version.load(Ordering::Acquire)
+    }
+}
+
+impl<T> VersionedParker<T> {
+    pub fn new(data: T) -> Self {
+        Self {
+            inner: Arc::new(Inner {
+                version: AtomicUsize::new(0),
+                data: Mutex::new(data),
+                condvar: Condvar::new(),
+            }),
+        }
+    }
+
+    fn version(&self) -> usize {
+        self.inner.version()
+    }
+
+    fn do_notify(
+        &self,
+        expected_version: Option<usize>,
+        mutate: fn(&mut T),
+        notify: fn(&Condvar),
+    ) -> bool {
+        let mut guard = self.inner.data.lock().unwrap();
+        if expected_version
+            .map(|v| v == self.version())
+            .unwrap_or(true)
+        {
+            self.inner.version.fetch_add(1, Ordering::AcqRel);
+            mutate(guard.deref_mut());
+            notify(&self.inner.condvar);
+            return true;
+        }
+        return false;
+    }
+
+    pub fn notify_one(&self) {
+        self.do_notify(None, |_| {}, Condvar::notify_one);
+    }
+
+    pub fn notify_one_mutate(&self, mutate: fn(&mut T)) {
+        self.do_notify(None, mutate, Condvar::notify_one);
+    }
+
+    pub fn try_notify_one(&self, expected_version: usize) -> bool {
+        self.do_notify(Some(expected_version), |_| {}, Condvar::notify_one)
+    }
+
+    pub fn try_notify_one_mutate(
+        &self,
+        expected_version: usize,
+        mutate: fn(&mut T),
+    ) -> bool {
+        self.do_notify(Some(expected_version), mutate, Condvar::notify_one)
+    }
+
+    pub fn notify_all(&self) {
+        self.do_notify(None, |_| {}, Condvar::notify_all);
+    }
+
+    pub fn notify_all_mutate(&self, mutate: fn(&mut T)) {
+        self.do_notify(None, mutate, Condvar::notify_all);
+    }
+
+    pub fn try_notify_all(&self, expected_version: usize) -> bool {
+        self.do_notify(Some(expected_version), |_| {}, Condvar::notify_all)
+    }
+
+    pub fn try_notify_all_mutate(
+        &self,
+        expected_version: usize,
+        mutate: fn(&mut T),
+    ) -> bool {
+        self.do_notify(Some(expected_version), mutate, Condvar::notify_all)
+    }
+
+    pub fn lock(&self) -> VersionedGuard<T> {
+        let guard = self.inner.data.lock().unwrap();
+        VersionedGuard {
+            parker: self.inner.as_ref(),
+            guard: Some(guard),
+            notify_count: 0,
+        }
+    }
+}
+
+pub struct VersionedGuard<'a, T> {
+    parker: &'a Inner<T>,
+    guard: Option<MutexGuard<'a, T>>,
+    notify_count: usize,
+}
+
+impl<'a, T> VersionedGuard<'a, T> {
+    pub fn version(&self) -> usize {
+        self.parker.version()
+    }
+
+    pub fn notified(&self) -> bool {
+        self.notify_count != 0
+    }
+
+    pub fn notify_count(&self) -> usize {
+        self.notify_count
+    }
+
+    pub fn wait(&mut self) {
+        let guard = self.guard.take().unwrap();
+        let version = self.parker.version();
+
+        self.guard = Some(self.parker.condvar.wait(guard).unwrap());
+        self.notify_count = self.parker.version() - version;
+    }
+
+    pub fn wait_timeout(&mut self, timeout: Duration) -> WaitTimeoutResult {
+        let guard = self.guard.take().unwrap();
+        let version = self.parker.version();
+        let (guard_result, wait_result) =
+            self.parker.condvar.wait_timeout(guard, timeout).unwrap();
+
+        self.guard = Some(guard_result);
+        self.notify_count = self.parker.version() - version;
+
+        wait_result
+    }
+}
+
+impl<'a, T> Deref for VersionedGuard<'a, T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.guard.as_deref().unwrap()
+    }
+}
+
+impl<'a, T> DerefMut for VersionedGuard<'a, T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.guard.as_deref_mut().unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_basics() {
+        let versioned_parker = VersionedParker::new(0);
+        assert_eq!(versioned_parker.version(), 0);
+
+        versioned_parker.notify_one();
+        assert_eq!(versioned_parker.version(), 1);
+
+        versioned_parker.notify_one_mutate(|i| *i = 32);
+        let mut guard = versioned_parker.lock();
+        assert_eq!(versioned_parker.version(), 2);
+        assert_eq!(*guard, 32);
+
+        let parker_clone = versioned_parker.clone();
+        std::thread::spawn(move || parker_clone.notify_one_mutate(|i| *i = 64));
+        guard.wait();
+
+        assert_eq!(guard.notify_count(), 1);
+        assert_eq!(*guard, 64);
+    }
+
+    #[test]
+    fn test_multiple_notify() {
+        let versioned_parker = VersionedParker::new(0);
+        let mut guard = versioned_parker.lock();
+
+        let parker_clone = versioned_parker.clone();
+        std::thread::spawn(move || {
+            parker_clone.notify_all();
+            parker_clone.notify_all_mutate(|i| *i = 128);
+            parker_clone.notify_one_mutate(|i| *i = 256);
+            parker_clone.notify_one_mutate(|i| *i = 512);
+        });
+
+        guard.wait();
+        let expected_value = match guard.notify_count() {
+            1 => 0,
+            2 => 128,
+            3 => 256,
+            4 => 512,
+            _ => panic!("notify count should not be larger than 3"),
+        };
+        assert_eq!(*guard, expected_value);
+    }
+
+    #[test]
+    fn test_try_notify() {
+        let versioned_parker = VersionedParker::new(0);
+        let mut guard = versioned_parker.lock();
+
+        let parker_clone = versioned_parker.clone();
+        std::thread::spawn(move || {
+            assert!(parker_clone.try_notify_one(0));
+            assert!(!parker_clone.try_notify_all(0));
+        });
+
+        guard.wait();
+        assert_eq!(guard.notify_count(), 1);
+        assert_eq!(*guard, 0);
+    }
+
+    #[test]
+    fn test_try_notify_mutate() {
+        let versioned_parker = VersionedParker::new(0);
+        let mut guard = versioned_parker.lock();
+
+        let parker_clone = versioned_parker.clone();
+        std::thread::spawn(move || {
+            assert!(parker_clone.try_notify_one_mutate(0, |i| *i = 1024));
+            assert!(!parker_clone.try_notify_all_mutate(0, |i| *i = 2048));
+        });
+
+        guard.wait();
+        assert_eq!(guard.notify_count(), 1);
+        assert_eq!(*guard, 1024);
+    }
+}