Home | History | Annotate | Download | only in ThreadPool
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov (at) google.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
     11 #define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
     12 
     13 namespace Eigen {
     14 
// EventCount allows waiting for arbitrary predicates in non-blocking
// algorithms. Think of a condition variable, but the wait predicate does not
// need to be protected by a mutex. Usage:
     18 // Waiting thread does:
     19 //
     20 //   if (predicate)
     21 //     return act();
     22 //   EventCount::Waiter& w = waiters[my_index];
     23 //   ec.Prewait(&w);
     24 //   if (predicate) {
     25 //     ec.CancelWait(&w);
     26 //     return act();
     27 //   }
     28 //   ec.CommitWait(&w);
     29 //
     30 // Notifying thread does:
     31 //
     32 //   predicate = true;
     33 //   ec.Notify(true);
     34 //
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check has
// failed.
     38 //
// Algorithm outline:
// There are two main variables: predicate (managed by user) and state_.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// Waiting thread sets state_ then checks predicate, notifying thread sets
// predicate then checks state_. Due to seq_cst fences in between these
// operations it is guaranteed that either the waiter will see the predicate
// change and won't block, or the notifying thread will see the state_ change
// and will unblock the waiter, or both. But it can't happen that both threads
// don't see each other's changes, which would lead to deadlock.
class EventCount {
 public:
  class Waiter;

  // `waiters` is caller-owned storage for per-thread Waiter slots. The stack
  // packed into state_ stores indices into this vector, so it must outlive
  // the EventCount and must not reallocate (hence MaxSizeVector).
  EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
    // Indices must fit into kWaiterBits; the all-ones value (kStackMask) is
    // reserved as the "empty stack" sentinel.
    eigen_assert(waiters.size() < (1 << kWaiterBits) - 1);
    // Initialize epoch to something close to overflow to test overflow.
    state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
  }

  ~EventCount() {
    // Ensure there are no waiters: empty stack and zero pre-wait count.
    eigen_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling this function the thread must re-check the wait predicate
  // and call either CancelWait or CommitWait passing the same Waiter object.
  void Prewait(Waiter* w) {
    // Bump the pre-wait counter and snapshot state_; CommitWait/CancelWait
    // derive this waiter's modification epoch from w->epoch.
    w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
    // seq_cst fence pairs with the fence in Notify (Dekker-style protocol
    // described in the header comment): state_ write is ordered before the
    // subsequent predicate re-check.
    std::atomic_thread_fence(std::memory_order_seq_cst);
  }

  // CommitWait commits waiting: moves this thread from the pre-wait counter
  // onto the waiter stack and blocks in Park() until Notify signals it.
  void CommitWait(Waiter* w) {
    w->state = Waiter::kNotSignaled;
    // Modification epoch of this waiter: the epoch observed in Prewait plus
    // one increment per pre-waiter already counted at that time.
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either CancelWait or CommitWait, or is notified.
        EIGEN_THREAD_YIELD();
        state = state_.load(std::memory_order_seq_cst);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter and add it to the waiter list.
      eigen_assert((state & kWaiterMask) != 0);
      uint64_t newstate = state - kWaiterInc + kEpochInc;
      // Push our own slot index onto the intrusive stack in the low bits.
      newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
      if ((state & kStackMask) == kStackMask)
        w->next.store(nullptr, std::memory_order_relaxed);
      else
        w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
      // NOTE(review): single-order compare_exchange_weak(release) gives a
      // relaxed failure order, which is fine here because the loop retries on
      // the freshly loaded value.
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_release))
        break;
    }
    Park(w);
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait(Waiter* w) {
    // Modification epoch of this waiter (same computation as in CommitWait).
    uint64_t epoch =
        (w->epoch & kEpochMask) +
        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      if (int64_t((state & kEpochMask) - epoch) < 0) {
        // The preceding waiter has not decided on its fate. Wait until it
        // calls either CancelWait or CommitWait, or is notified.
        EIGEN_THREAD_YIELD();
        state = state_.load(std::memory_order_relaxed);
        continue;
      }
      // We've already been notified.
      if (int64_t((state & kEpochMask) - epoch) > 0) return;
      // Remove this thread from prewait counter (and bump the epoch so the
      // next queued waiter can proceed).
      eigen_assert((state & kWaiterMask) != 0);
      if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
                                       std::memory_order_relaxed))
        return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool all) {
    // Pairs with the seq_cst fence in Prewait: either we observe the waiter's
    // state_ change here, or the waiter observes our predicate change.
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      // Easy case: no waiters (empty stack and zero pre-wait count).
      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
        return;
      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      uint64_t newstate;
      if (all) {
        // Reset prewait counter and empty wait list.
        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
      } else if (waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kEpochInc - kWaiterInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        Waiter* wnext = w->next.load(std::memory_order_relaxed);
        uint64_t next = kStackMask;
        if (wnext != nullptr) next = wnext - &waiters_[0];
        // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
        // can't happen because a waiter is re-pushed onto the stack only after
        // it was in the pre-wait state which inevitably leads to epoch
        // increment.
        newstate = (state & kEpochMask) + next;
      }
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acquire)) {
        if (!all && waiters) return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        // notify-one: detach the popped waiter so Unpark wakes only it;
        // notify-all: leave the chain intact so Unpark walks the whole list.
        if (!all) w->next.store(nullptr, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  // Per-thread wait slot; instances live in the caller-provided waiters_
  // vector and are linked into an intrusive stack via `next`.
  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
    std::mutex mu;               // protects the Park/Unpark handshake on cv
    std::condition_variable cv;  // blocks the committed waiter in Park
    uint64_t epoch;              // state_ snapshot taken in Prewait
    unsigned state;              // kNotSignaled -> kWaiting -> kSignaled
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

 private:
  // State_ layout:
  // - low kStackBits is a stack of waiters committed wait
  //   (all-ones, i.e. kStackMask, means the stack is empty).
  // - next kWaiterBits is count of waiters in prewait state.
  // - next kEpochBits is modification counter.
  static const uint64_t kStackBits = 16;
  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
  static const uint64_t kWaiterBits = 16;
  static const uint64_t kWaiterShift = 16;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
  static const uint64_t kEpochBits = 32;
  static const uint64_t kEpochShift = 32;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  // Blocks the calling thread on w->cv until Unpark marks it kSignaled.
  // The loop tolerates spurious wakeups from condition_variable::wait.
  void Park(Waiter* w) {
    std::unique_lock<std::mutex> lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  // Marks every waiter on the intrusive list starting at `waiters` as
  // signaled, waking those that are already blocked in Park.
  void Unpark(Waiter* waiters) {
    Waiter* next = nullptr;
    for (Waiter* w = waiters; w; w = next) {
      // Read the link before signaling: once signaled, w may return from
      // CommitWait and be reused.
      next = w->next.load(std::memory_order_relaxed);
      unsigned state;
      {
        std::unique_lock<std::mutex> lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
    230 
    231 }  // namespace Eigen
    232 
    233 #endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H_
    234