Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2016 Dmitry Vyukov <dvyukov (at) google.com>
      5 // Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      6 //
      7 // This Source Code Form is subject to the terms of the Mozilla
      8 // Public License v. 2.0. If a copy of the MPL was not distributed
      9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     10 
     11 #define EIGEN_USE_THREADS
     12 #include "main.h"
     13 #include <Eigen/CXX11/ThreadPool>
     14 
     15 // Visual studio doesn't implement a rand_r() function since its
     16 // implementation of rand() is already thread safe
     17 int rand_reentrant(unsigned int* s) {
     18 #ifdef EIGEN_COMP_MSVC_STRICT
     19   EIGEN_UNUSED_VARIABLE(s);
     20   return rand();
     21 #else
     22   return rand_r(s);
     23 #endif
     24 }
     25 
     26 static void test_basic_eventcount()
     27 {
     28   MaxSizeVector<EventCount::Waiter> waiters(1);
     29   waiters.resize(1);
     30   EventCount ec(waiters);
     31   EventCount::Waiter& w = waiters[0];
     32   ec.Notify(false);
     33   ec.Prewait(&w);
     34   ec.Notify(true);
     35   ec.CommitWait(&w);
     36   ec.Prewait(&w);
     37   ec.CancelWait(&w);
     38 }
     39 
     40 // Fake bounded counter-based queue.
     41 struct TestQueue {
     42   std::atomic<int> val_;
     43   static const int kQueueSize = 10;
     44 
     45   TestQueue() : val_() {}
     46 
     47   ~TestQueue() { VERIFY_IS_EQUAL(val_.load(), 0); }
     48 
     49   bool Push() {
     50     int val = val_.load(std::memory_order_relaxed);
     51     for (;;) {
     52       VERIFY_GE(val, 0);
     53       VERIFY_LE(val, kQueueSize);
     54       if (val == kQueueSize) return false;
     55       if (val_.compare_exchange_weak(val, val + 1, std::memory_order_relaxed))
     56         return true;
     57     }
     58   }
     59 
     60   bool Pop() {
     61     int val = val_.load(std::memory_order_relaxed);
     62     for (;;) {
     63       VERIFY_GE(val, 0);
     64       VERIFY_LE(val, kQueueSize);
     65       if (val == 0) return false;
     66       if (val_.compare_exchange_weak(val, val - 1, std::memory_order_relaxed))
     67         return true;
     68     }
     69   }
     70 
     71   bool Empty() { return val_.load(std::memory_order_relaxed) == 0; }
     72 };
     73 
     74 const int TestQueue::kQueueSize;
     75 
     76 // A number of producers send messages to a set of consumers using a set of
     77 // fake queues. Ensure that it does not crash, consumers don't deadlock and
     78 // number of blocked and unblocked threads match.
     79 static void test_stress_eventcount()
     80 {
     81   const int kThreads = std::thread::hardware_concurrency();
     82   static const int kEvents = 1 << 16;
     83   static const int kQueues = 10;
     84 
     85   MaxSizeVector<EventCount::Waiter> waiters(kThreads);
     86   waiters.resize(kThreads);
     87   EventCount ec(waiters);
     88   TestQueue queues[kQueues];
     89 
     90   std::vector<std::unique_ptr<std::thread>> producers;
     91   for (int i = 0; i < kThreads; i++) {
     92     producers.emplace_back(new std::thread([&ec, &queues]() {
     93       unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
     94       for (int j = 0; j < kEvents; j++) {
     95         unsigned idx = rand_reentrant(&rnd) % kQueues;
     96         if (queues[idx].Push()) {
     97           ec.Notify(false);
     98           continue;
     99         }
    100         EIGEN_THREAD_YIELD();
    101         j--;
    102       }
    103     }));
    104   }
    105 
    106   std::vector<std::unique_ptr<std::thread>> consumers;
    107   for (int i = 0; i < kThreads; i++) {
    108     consumers.emplace_back(new std::thread([&ec, &queues, &waiters, i]() {
    109       EventCount::Waiter& w = waiters[i];
    110       unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
    111       for (int j = 0; j < kEvents; j++) {
    112         unsigned idx = rand_reentrant(&rnd) % kQueues;
    113         if (queues[idx].Pop()) continue;
    114         j--;
    115         ec.Prewait(&w);
    116         bool empty = true;
    117         for (int q = 0; q < kQueues; q++) {
    118           if (!queues[q].Empty()) {
    119             empty = false;
    120             break;
    121           }
    122         }
    123         if (!empty) {
    124           ec.CancelWait(&w);
    125           continue;
    126         }
    127         ec.CommitWait(&w);
    128       }
    129     }));
    130   }
    131 
    132   for (int i = 0; i < kThreads; i++) {
    133     producers[i]->join();
    134     consumers[i]->join();
    135   }
    136 }
    137 
    138 void test_cxx11_eventcount()
    139 {
    140   CALL_SUBTEST(test_basic_eventcount());
    141   CALL_SUBTEST(test_stress_eventcount());
    142 }
    143