Home | History | Annotate | Download | only in tests
      1 /*
 * Test program that illustrates how to annotate a smart pointer
 * implementation. In a multithreaded program the following questions are
 * relevant when working with smart pointers:
 * - whether or not the objects pointed at are shared across threads.
 * - whether or not the methods of the objects pointed at are thread-safe.
 * - whether or not the smart pointer objects are shared across threads.
 * - whether or not the smart pointer object itself is thread-safe.
      9  *
 * Most smart pointer implementations are not thread-safe
 * (e.g. boost::shared_ptr<>, tr1::shared_ptr<> and the smart_ptr<>
 * implementation below). This means that it is not safe to modify a smart
 * pointer object that is shared across threads without proper synchronization.
     14  *
     15  * Even for non-thread-safe smart pointers it is possible to have different
     16  * threads access the same object via smart pointers without triggering data
     17  * races on the smart pointer objects.
     18  *
 * A smart pointer implementation guarantees that the destructor of the object
 * pointed at is invoked after the last smart pointer that points to that
 * object has been destroyed or reset. Data race detection tools cannot detect
 * this ordering on their own when a smart pointer tracks references without
 * invoking synchronization operations that these tools recognize; such smart
 * pointers need explicit annotations.
     25  */
     26 
     27 
#include <algorithm>   // std::max()
#include <cassert>     // assert()
#include <climits>     // PTHREAD_STACK_MIN
#include <iostream>    // std::cerr
#include <stdlib.h>    // atoi(), abort()
#include <vector>
#ifdef _WIN32
#include <process.h>   // _beginthreadex()
#include <windows.h>   // CRITICAL_SECTION, Sleep()
#else
#include <pthread.h>   // pthread_mutex_t
#include <time.h>      // nanosleep()
#endif
#include "unified_annotations.h"
     40 
     41 
     42 static bool s_enable_annotations;
     43 
     44 
     45 #ifdef _WIN32
     46 
     47 class AtomicInt32
     48 {
     49 public:
     50   AtomicInt32(const int value = 0) : m_value(value) { }
     51   ~AtomicInt32() { }
     52   LONG operator++() { return InterlockedIncrement(&m_value); }
     53   LONG operator--() { return InterlockedDecrement(&m_value); }
     54 
     55 private:
     56   volatile LONG m_value;
     57 };
     58 
     59 class Mutex
     60 {
     61 public:
     62   Mutex() : m_mutex()
     63   { InitializeCriticalSection(&m_mutex); }
     64   ~Mutex()
     65   { DeleteCriticalSection(&m_mutex); }
     66   void Lock()
     67   { EnterCriticalSection(&m_mutex); }
     68   void Unlock()
     69   { LeaveCriticalSection(&m_mutex); }
     70 
     71 private:
     72   CRITICAL_SECTION m_mutex;
     73 };
     74 
     75 class Thread
     76 {
     77 public:
     78   Thread() : m_thread(INVALID_HANDLE_VALUE) { }
     79   ~Thread() { }
     80   void Create(void* (*pf)(void*), void* arg)
     81   {
     82     WrapperArgs* wrapper_arg_p = new WrapperArgs(pf, arg);
     83     m_thread = reinterpret_cast<HANDLE>(_beginthreadex(NULL, 0, wrapper,
     84 						       wrapper_arg_p, 0, NULL));
     85   }
     86   void Join()
     87   { WaitForSingleObject(m_thread, INFINITE); }
     88 
     89 private:
     90   struct WrapperArgs
     91   {
     92     WrapperArgs(void* (*pf)(void*), void* arg) : m_pf(pf), m_arg(arg) { }
     93 
     94     void* (*m_pf)(void*);
     95     void* m_arg;
     96   };
     97   static unsigned int __stdcall wrapper(void* arg)
     98   {
     99     WrapperArgs* wrapper_arg_p = reinterpret_cast<WrapperArgs*>(arg);
    100     WrapperArgs wa = *wrapper_arg_p;
    101     delete wrapper_arg_p;
    102     return reinterpret_cast<unsigned>((wa.m_pf)(wa.m_arg));
    103   }
    104   HANDLE m_thread;
    105 };
    106 
    107 #else // _WIN32
    108 
    109 class AtomicInt32
    110 {
    111 public:
    112   AtomicInt32(const int value = 0) : m_value(value) { }
    113   ~AtomicInt32() { }
    114   int operator++() { return __sync_add_and_fetch(&m_value, 1); }
    115   int operator--() { return __sync_sub_and_fetch(&m_value, 1); }
    116 private:
    117   volatile int m_value;
    118 };
    119 
    120 class Mutex
    121 {
    122 public:
    123   Mutex() : m_mutex()
    124   { pthread_mutex_init(&m_mutex, NULL); }
    125   ~Mutex()
    126   { pthread_mutex_destroy(&m_mutex); }
    127   void Lock()
    128   { pthread_mutex_lock(&m_mutex); }
    129   void Unlock()
    130   { pthread_mutex_unlock(&m_mutex); }
    131 
    132 private:
    133   pthread_mutex_t m_mutex;
    134 };
    135 
    136 class Thread
    137 {
    138 public:
    139   Thread() : m_tid() { }
    140   ~Thread() { }
    141   void Create(void* (*pf)(void*), void* arg)
    142   {
    143     pthread_attr_t attr;
    144     pthread_attr_init(&attr);
    145     pthread_attr_setstacksize(&attr, PTHREAD_STACK_MIN + 4096);
    146     pthread_create(&m_tid, &attr, pf, arg);
    147     pthread_attr_destroy(&attr);
    148   }
    149   void Join()
    150   { pthread_join(m_tid, NULL); }
    151 private:
    152   pthread_t m_tid;
    153 };
    154 
    155 #endif // !defined(_WIN32)
    156 
    157 
    158 template<class T>
    159 class smart_ptr
    160 {
    161 public:
    162   typedef AtomicInt32 counter_t;
    163 
    164   template <typename Q> friend class smart_ptr;
    165 
    166   explicit smart_ptr()
    167     : m_ptr(NULL), m_count_ptr(NULL)
    168   { }
    169 
    170   explicit smart_ptr(T* const pT)
    171     : m_ptr(NULL), m_count_ptr(NULL)
    172   {
    173     set(pT, pT ? new counter_t(0) : NULL);
    174   }
    175 
    176   template <typename Q>
    177   explicit smart_ptr(Q* const q)
    178     : m_ptr(NULL), m_count_ptr(NULL)
    179   {
    180     set(q, q ? new counter_t(0) : NULL);
    181   }
    182 
    183   ~smart_ptr()
    184   {
    185     set(NULL, NULL);
    186   }
    187 
    188   smart_ptr(const smart_ptr<T>& sp)
    189     : m_ptr(NULL), m_count_ptr(NULL)
    190   {
    191     set(sp.m_ptr, sp.m_count_ptr);
    192   }
    193 
    194   template <typename Q>
    195   smart_ptr(const smart_ptr<Q>& sp)
    196     : m_ptr(NULL), m_count_ptr(NULL)
    197   {
    198     set(sp.m_ptr, sp.m_count_ptr);
    199   }
    200 
    201   smart_ptr& operator=(const smart_ptr<T>& sp)
    202   {
    203     set(sp.m_ptr, sp.m_count_ptr);
    204     return *this;
    205   }
    206 
    207   smart_ptr& operator=(T* const p)
    208   {
    209     set(p, p ? new counter_t(0) : NULL);
    210     return *this;
    211   }
    212 
    213   template <typename Q>
    214   smart_ptr& operator=(Q* const q)
    215   {
    216     set(q, q ? new counter_t(0) : NULL);
    217     return *this;
    218   }
    219 
    220   T* operator->() const
    221   {
    222     assert(m_ptr);
    223     return m_ptr;
    224   }
    225 
    226   T& operator*() const
    227   {
    228     assert(m_ptr);
    229     return *m_ptr;
    230   }
    231 
    232 private:
    233   void set(T* const pT, counter_t* const count_ptr)
    234   {
    235     if (m_ptr != pT)
    236     {
    237       if (m_count_ptr)
    238       {
    239 	if (s_enable_annotations)
    240 	  U_ANNOTATE_HAPPENS_BEFORE(m_count_ptr);
    241 	if (--(*m_count_ptr) == 0)
    242 	{
    243 	  if (s_enable_annotations)
    244 	    U_ANNOTATE_HAPPENS_AFTER(m_count_ptr);
    245 	  delete m_ptr;
    246 	  m_ptr = NULL;
    247 	  delete m_count_ptr;
    248 	  m_count_ptr = NULL;
    249 	}
    250       }
    251       m_ptr = pT;
    252       m_count_ptr = count_ptr;
    253       if (count_ptr)
    254 	++(*m_count_ptr);
    255     }
    256   }
    257 
    258   T*         m_ptr;
    259   counter_t* m_count_ptr;
    260 };
    261 
    262 class counter
    263 {
    264 public:
    265   counter()
    266     : m_mutex(), m_count()
    267   { }
    268   ~counter()
    269   {
    270     // Data race detection tools that do not recognize the
    271     // ANNOTATE_HAPPENS_BEFORE() / ANNOTATE_HAPPENS_AFTER() annotations in the
    272     // smart_ptr<> implementation will report that the assignment below
    273     // triggers a data race.
    274     m_count = -1;
    275   }
    276   int get() const
    277   {
    278     int result;
    279     m_mutex.Lock();
    280     result = m_count;
    281     m_mutex.Unlock();
    282     return result;
    283   }
    284   int post_increment()
    285   {
    286     int result;
    287     m_mutex.Lock();
    288     result = m_count++;
    289     m_mutex.Unlock();
    290     return result;
    291   }
    292 
    293 private:
    294   mutable Mutex m_mutex;
    295   int           m_count;
    296 };
    297 
    298 static void* thread_func(void* arg)
    299 {
    300   smart_ptr<counter>* pp = reinterpret_cast<smart_ptr<counter>*>(arg);
    301   (*pp)->post_increment();
    302   *pp = NULL;
    303   delete pp;
    304   return NULL;
    305 }
    306 
    307 int main(int argc, char** argv)
    308 {
    309   const int nthreads = std::max(argc > 1 ? atoi(argv[1]) : 1, 1);
    310   const int iterations = std::max(argc > 2 ? atoi(argv[2]) : 1, 1);
    311   s_enable_annotations = argc > 3 ? !!atoi(argv[3]) : true;
    312 
    313   for (int j = 0; j < iterations; ++j)
    314   {
    315     std::vector<Thread> T(nthreads);
    316     smart_ptr<counter> p(new counter);
    317     p->post_increment();
    318     for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
    319       q->Create(thread_func, new smart_ptr<counter>(p));
    320     {
    321       // Avoid that counter.m_mutex introduces a false ordering on the
    322       // counter.m_count accesses.
    323       const timespec delay = { 0, 100 * 1000 * 1000 };
    324       nanosleep(&delay, 0);
    325     }
    326     p = NULL;
    327     for (std::vector<Thread>::iterator q = T.begin(); q != T.end(); q++)
    328       q->Join();
    329   }
    330   std::cerr << "Done.\n";
    331   return 0;
    332 }
    333