Home | History | Annotate | Download | only in unittest
      1 /*
      2   This file is part of ThreadSanitizer, a dynamic data race detector.
      3 
      4   Copyright (C) 2008-2009 Google Inc
      5      opensource (at) google.com
      6 
      7   This program is free software; you can redistribute it and/or
      8   modify it under the terms of the GNU General Public License as
      9   published by the Free Software Foundation; either version 2 of the
     10   License, or (at your option) any later version.
     11 
     12   This program is distributed in the hope that it will be useful, but
     13   WITHOUT ANY WARRANTY; without even the implied warranty of
     14   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     15   General Public License for more details.
     16 
     17   You should have received a copy of the GNU General Public License
     18   along with this program; if not, write to the Free Software
     19   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
     20   02111-1307, USA.
     21 
     22   The GNU General Public License is contained in the file COPYING.
     23 */
     24 
     25 // Author: Konstantin Serebryany <opensource (at) google.com>
     26 //
     27 // Here we define a few simple classes that wrap threading primitives.
     28 //
     29 // We need this to create unit tests for ThreadSanitizer (or similar tools)
     30 // that will work with different threading frameworks.
     31 //
// Note that some of the methods defined here are annotated with
     33 // ANNOTATE_* macros defined in dynamic_annotations.h.
     34 //
     35 // DISCLAIMER: the classes defined in this header file
     36 // are NOT intended for general use -- only for unit tests.
     37 
     38 #ifndef THREAD_WRAPPERS_H
     39 #define THREAD_WRAPPERS_H
     40 
#include <assert.h>
#include <limits.h>   // INT_MAX
#include <stdint.h>   // intptr_t
#include <stdio.h>
#include <stdlib.h>   // exit
#include <time.h>

#include <queue>
#include <string>
#include <vector>
     47 
     48 #include "dynamic_annotations.h"
     49 
     50 using namespace std;
     51 
     52 #ifdef NDEBUG
     53 # error "Pleeease, do not define NDEBUG"
     54 #endif
     55 
     56 #ifdef WIN32
     57 # define CHECK(x) do { if (!(x)) { \
     58    fprintf(stderr, "Assertion failed: %s (%s:%d) %s\n", \
     59           __FUNCTION__, __FILE__, __LINE__, #x); \
     60    exit(1); }} while (0)
     61 #else
     62 # define CHECK assert
     63 #endif
     64 
     65 /// Just a boolean condition. Used by Mutex::LockWhen and similar.
     66 class Condition {
     67  public:
     68   typedef bool (*func_t)(void*);
     69 
     70   template <typename T>
     71   Condition(bool (*func)(T*), T* arg)
     72   : func_(reinterpret_cast<func_t>(func)), arg_(arg) {}
     73 
     74   Condition(bool (*func)())
     75   : func_(reinterpret_cast<func_t>(func)), arg_(NULL) {}
     76 
     77   bool Eval() { return func_(arg_); }
     78  private:
     79   func_t func_;
     80   void *arg_;
     81 };
     82 
// Declare platform-specific types, constants and functions {{{1
     84 static int AtomicIncrement(volatile int *value, int increment);
     85 static int GetTimeInMs();
     86 
     87 class CondVar;
     88 class MyThread;
     89 class Mutex;
     90 //}}}
     91 
// Include platform-specific header with declarations.
     93 #ifndef WIN32
     94 // Include pthread primitives (Linux, Mac)
     95 #include "thread_wrappers_pthread.h"
     96 #else
     97 // Include Windows primitives
     98 #include "thread_wrappers_win.h"
     99 #endif
    100 
// Define cross-platform types and synchronization primitives {{{1
    102 /// Just a message queue.
    103 class ProducerConsumerQueue {
    104  public:
    105   ProducerConsumerQueue(int unused) {
    106     //ANNOTATE_PCQ_CREATE(this);
    107   }
    108   ~ProducerConsumerQueue() {
    109     CHECK(q_.empty());
    110     //ANNOTATE_PCQ_DESTROY(this);
    111   }
    112 
    113   // Put.
    114   void Put(void *item) {
    115     mu_.Lock();
    116       q_.push(item);
    117       ANNOTATE_CONDVAR_SIGNAL(&mu_); // LockWhen in Get()
    118       //ANNOTATE_PCQ_PUT(this);
    119     mu_.Unlock();
    120   }
    121 
    122   // Get.
    123   // Blocks if the queue is empty.
    124   void *Get() {
    125     mu_.LockWhen(Condition(IsQueueNotEmpty, &q_));
    126       void * item;
    127       bool ok = TryGetInternal(&item);
    128       CHECK(ok);
    129     mu_.Unlock();
    130     return item;
    131   }
    132 
    133   // If queue is not empty,
    134   // remove an element from queue, put it into *res and return true.
    135   // Otherwise return false.
    136   bool TryGet(void **res) {
    137     mu_.Lock();
    138       bool ok = TryGetInternal(res);
    139     mu_.Unlock();
    140     return ok;
    141   }
    142 
    143  private:
    144   Mutex mu_;
    145   std::queue<void*> q_; // protected by mu_
    146 
    147   // Requires mu_
    148   bool TryGetInternal(void ** item_ptr) {
    149     if (q_.empty())
    150       return false;
    151     *item_ptr = q_.front();
    152     q_.pop();
    153     //ANNOTATE_PCQ_GET(this);
    154     return true;
    155   }
    156 
    157   static bool IsQueueNotEmpty(std::queue<void*> * queue) {
    158      return !queue->empty();
    159   }
    160 };
    161 
    162 /// Function pointer with zero, one or two parameters.
    163 struct Closure {
    164   typedef void (*F0)();
    165   typedef void (*F1)(void *arg1);
    166   typedef void (*F2)(void *arg1, void *arg2);
    167   int  n_params;
    168   void *f;
    169   void *param1;
    170   void *param2;
    171 
    172   void Execute() {
    173     if (n_params == 0) {
    174       (F0(f))();
    175     } else if (n_params == 1) {
    176       (F1(f))(param1);
    177     } else {
    178       CHECK(n_params == 2);
    179       (F2(f))(param1, param2);
    180     }
    181     delete this;
    182   }
    183 };
    184 
    185 static Closure *NewCallback(void (*f)()) {
    186   Closure *res = new Closure;
    187   res->n_params = 0;
    188   res->f = (void*)(f);
    189   res->param1 = NULL;
    190   res->param2 = NULL;
    191   return res;
    192 }
    193 
    194 template <class P1>
    195 Closure *NewCallback(void (*f)(P1), P1 p1) {
    196   CHECK(sizeof(P1) <= sizeof(void*));
    197   Closure *res = new Closure;
    198   res->n_params = 1;
    199   res->f = (void*)(f);
    200   res->param1 = (void*)(intptr_t)p1;
    201   res->param2 = NULL;
    202   return res;
    203 }
    204 
    205 template <class P1, class P2>
    206 Closure *NewCallback(void (*f)(P1, P2), P1 p1, P2 p2) {
    207   CHECK(sizeof(P1) <= sizeof(void*));
    208   CHECK(sizeof(P2) <= sizeof(void*));
    209   Closure *res = new Closure;
    210   res->n_params = 2;
    211   res->f = (void*)(f);
    212   res->param1 = (void*)p1;
    213   res->param2 = (void*)p2;
    214   return res;
    215 }
    216 
    217 /*! A thread pool that uses ProducerConsumerQueue.
    218   Usage:
    219   {
    220     ThreadPool pool(n_workers);
    221     pool.StartWorkers();
    222     pool.Add(NewCallback(func_with_no_args));
    223     pool.Add(NewCallback(func_with_one_arg, arg));
    224     pool.Add(NewCallback(func_with_two_args, arg1, arg2));
    225     ... // more calls to pool.Add()
    226 
    227     // the ~ThreadPool() is called: we wait workers to finish
    228     // and then join all threads in the pool.
    229   }
    230 */
    231 class ThreadPool {
    232  public:
    233   //! Create n_threads threads, but do not start.
    234   explicit ThreadPool(int n_threads)
    235     : queue_(INT_MAX) {
    236     for (int i = 0; i < n_threads; i++) {
    237       MyThread *thread = new MyThread(&ThreadPool::Worker, this);
    238       workers_.push_back(thread);
    239     }
    240   }
    241 
    242   //! Start all threads.
    243   void StartWorkers() {
    244     for (size_t i = 0; i < workers_.size(); i++) {
    245       workers_[i]->Start();
    246     }
    247   }
    248 
    249   //! Add a closure.
    250   void Add(Closure *closure) {
    251     queue_.Put(closure);
    252   }
    253 
    254   int num_threads() { return workers_.size();}
    255 
    256   //! Wait workers to finish, then join all threads.
    257   ~ThreadPool() {
    258     for (size_t i = 0; i < workers_.size(); i++) {
    259       Add(NULL);
    260     }
    261     for (size_t i = 0; i < workers_.size(); i++) {
    262       workers_[i]->Join();
    263       delete workers_[i];
    264     }
    265   }
    266  private:
    267   std::vector<MyThread*>   workers_;
    268   ProducerConsumerQueue  queue_;
    269 
    270   static void *Worker(void *p) {
    271     ThreadPool *pool = reinterpret_cast<ThreadPool*>(p);
    272     while (true) {
    273       Closure *closure = reinterpret_cast<Closure*>(pool->queue_.Get());
    274       if(closure == NULL) {
    275         return NULL;
    276       }
    277       closure->Execute();
    278     }
    279   }
    280 };
    281 
    282 class MutexLock {  // Scoped Mutex Locker/Unlocker
    283  public:
    284   MutexLock(Mutex *mu)
    285     : mu_(mu) {
    286     mu_->Lock();
    287   }
    288   ~MutexLock() {
    289     mu_->Unlock();
    290   }
    291  private:
    292   Mutex *mu_;
    293 };
    294 
    295 class BlockingCounter {
    296  public:
    297   explicit BlockingCounter(int initial_count) :
    298     count_(initial_count) {}
    299   bool DecrementCount() {
    300     MutexLock lock(&mu_);
    301     count_--;
    302     return count_ == 0;
    303   }
    304   void Wait() {
    305     mu_.LockWhen(Condition(&IsZero, &count_));
    306     mu_.Unlock();
    307   }
    308  private:
    309   static bool IsZero(int *arg) { return *arg == 0; }
    310   Mutex mu_;
    311   int count_;
    312 };
    313 
    314 //}}}
    315 
    316 #endif // THREAD_WRAPPERS_H
    317 // vim:shiftwidth=2:softtabstop=2:expandtab:foldmethod=marker
    318