Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/message_loop.h"
      6 #include "base/observer_list.h"
      7 #include "base/observer_list_threadsafe.h"
      8 #include "base/platform_thread.h"
      9 #include "base/ref_counted.h"
     10 #include "testing/gtest/include/gtest/gtest.h"
     11 
     12 using base::Time;
     13 
     14 namespace {
     15 
     16 class ObserverListTest : public testing::Test {
     17 };
     18 
     19 class Foo {
     20  public:
     21   virtual void Observe(int x) = 0;
     22   virtual ~Foo() {}
     23 };
     24 
     25 class Adder : public Foo {
     26  public:
     27   explicit Adder(int scaler) : total(0), scaler_(scaler) {}
     28   virtual void Observe(int x) {
     29     total += x * scaler_;
     30   }
     31   virtual ~Adder() { }
     32   int total;
     33  private:
     34   int scaler_;
     35 };
     36 
     37 class Disrupter : public Foo {
     38  public:
     39   Disrupter(ObserverList<Foo>* list, Foo* doomed)
     40       : list_(list), doomed_(doomed) { }
     41   virtual ~Disrupter() { }
     42   virtual void Observe(int x) {
     43     list_->RemoveObserver(doomed_);
     44   }
     45  private:
     46   ObserverList<Foo>* list_;
     47   Foo* doomed_;
     48 };
     49 
     50 class ThreadSafeDisrupter : public Foo {
     51  public:
     52   ThreadSafeDisrupter(ObserverListThreadSafe<Foo>* list, Foo* doomed)
     53       : list_(list), doomed_(doomed) { }
     54   virtual ~ThreadSafeDisrupter() { }
     55   virtual void Observe(int x) {
     56     list_->RemoveObserver(doomed_);
     57   }
     58  private:
     59   ObserverListThreadSafe<Foo>* list_;
     60   Foo* doomed_;
     61 };
     62 
     63 class AddInObserve : public Foo {
     64  public:
     65   explicit AddInObserve(ObserverList<Foo>* observer_list)
     66       : added(false),
     67         observer_list(observer_list),
     68         adder(1) {
     69   }
     70   virtual void Observe(int x) {
     71     if (!added) {
     72       added = true;
     73       observer_list->AddObserver(&adder);
     74     }
     75   }
     76 
     77   bool added;
     78   ObserverList<Foo>* observer_list;
     79   Adder adder;
     80 };
     81 
     82 
     83 class ObserverListThreadSafeTest : public testing::Test {
     84 };
     85 
     86 static const int kThreadRunTime = 2000;  // ms to run the multi-threaded test.
     87 
     88 // A thread for use in the ThreadSafeObserver test
     89 // which will add and remove itself from the notification
     90 // list repeatedly.
     91 class AddRemoveThread : public PlatformThread::Delegate,
     92                         public Foo {
     93  public:
     94   AddRemoveThread(ObserverListThreadSafe<Foo>* list, bool notify)
     95       : list_(list),
     96         in_list_(false),
     97         start_(Time::Now()),
     98         count_observes_(0),
     99         count_addtask_(0),
    100         do_notifies_(notify) {
    101     factory_ = new ScopedRunnableMethodFactory<AddRemoveThread>(this);
    102   }
    103 
    104   virtual ~AddRemoveThread() {
    105     delete factory_;
    106   }
    107 
    108   void ThreadMain() {
    109     loop_ = new MessageLoop();  // Fire up a message loop.
    110     loop_->PostTask(FROM_HERE,
    111       factory_->NewRunnableMethod(&AddRemoveThread::AddTask));
    112     loop_->Run();
    113     //LOG(ERROR) << "Loop 0x" << std::hex << loop_ << " done. " <<
    114     //    count_observes_ << ", " << count_addtask_;
    115     delete loop_;
    116     loop_ = reinterpret_cast<MessageLoop*>(0xdeadbeef);
    117     delete this;
    118   }
    119 
    120   // This task just keeps posting to itself in an attempt
    121   // to race with the notifier.
    122   void AddTask() {
    123     count_addtask_++;
    124 
    125     if ((Time::Now() - start_).InMilliseconds() > kThreadRunTime) {
    126       LOG(INFO) << "DONE!";
    127       return;
    128     }
    129 
    130     if (!in_list_) {
    131       list_->AddObserver(this);
    132       in_list_ = true;
    133     }
    134 
    135     if (do_notifies_) {
    136       list_->Notify(&Foo::Observe, 10);
    137     }
    138 
    139     loop_->PostDelayedTask(FROM_HERE,
    140       factory_->NewRunnableMethod(&AddRemoveThread::AddTask), 0);
    141   }
    142 
    143   void Quit() {
    144     loop_->PostTask(FROM_HERE, new MessageLoop::QuitTask());
    145   }
    146 
    147   virtual void Observe(int x) {
    148     count_observes_++;
    149 
    150     // If we're getting called after we removed ourselves from
    151     // the list, that is very bad!
    152     DCHECK(in_list_);
    153 
    154     // This callback should fire on the appropriate thread
    155     EXPECT_EQ(loop_, MessageLoop::current());
    156 
    157     list_->RemoveObserver(this);
    158     in_list_ = false;
    159   }
    160 
    161  private:
    162   ObserverListThreadSafe<Foo>* list_;
    163   MessageLoop* loop_;
    164   bool in_list_;        // Are we currently registered for notifications.
    165                         // in_list_ is only used on |this| thread.
    166   Time start_;          // The time we started the test.
    167 
    168   int count_observes_;  // Number of times we observed.
    169   int count_addtask_;   // Number of times thread AddTask was called
    170   bool do_notifies_;    // Whether these threads should do notifications.
    171 
    172   ScopedRunnableMethodFactory<AddRemoveThread>* factory_;
    173 };
    174 
    175 }  // namespace
    176 
    177 TEST(ObserverListTest, BasicTest) {
    178   ObserverList<Foo> observer_list;
    179   Adder a(1), b(-1), c(1), d(-1);
    180   Disrupter evil(&observer_list, &c);
    181 
    182   observer_list.AddObserver(&a);
    183   observer_list.AddObserver(&b);
    184 
    185   FOR_EACH_OBSERVER(Foo, observer_list, Observe(10));
    186 
    187   observer_list.AddObserver(&evil);
    188   observer_list.AddObserver(&c);
    189   observer_list.AddObserver(&d);
    190 
    191   FOR_EACH_OBSERVER(Foo, observer_list, Observe(10));
    192 
    193   EXPECT_EQ(a.total, 20);
    194   EXPECT_EQ(b.total, -20);
    195   EXPECT_EQ(c.total, 0);
    196   EXPECT_EQ(d.total, -10);
    197 }
    198 
    199 TEST(ObserverListThreadSafeTest, BasicTest) {
    200   MessageLoop loop;
    201 
    202   scoped_refptr<ObserverListThreadSafe<Foo> > observer_list(
    203       new ObserverListThreadSafe<Foo>);
    204   Adder a(1);
    205   Adder b(-1);
    206   Adder c(1);
    207   Adder d(-1);
    208   ThreadSafeDisrupter evil(observer_list.get(), &c);
    209 
    210   observer_list->AddObserver(&a);
    211   observer_list->AddObserver(&b);
    212 
    213   observer_list->Notify(&Foo::Observe, 10);
    214   loop.RunAllPending();
    215 
    216   observer_list->AddObserver(&evil);
    217   observer_list->AddObserver(&c);
    218   observer_list->AddObserver(&d);
    219 
    220   observer_list->Notify(&Foo::Observe, 10);
    221   loop.RunAllPending();
    222 
    223   EXPECT_EQ(a.total, 20);
    224   EXPECT_EQ(b.total, -20);
    225   EXPECT_EQ(c.total, 0);
    226   EXPECT_EQ(d.total, -10);
    227 }
    228 
    229 
    230 // A test driver for a multi-threaded notification loop.  Runs a number
    231 // of observer threads, each of which constantly adds/removes itself
    232 // from the observer list.  Optionally, if cross_thread_notifies is set
    233 // to true, the observer threads will also trigger notifications to
    234 // all observers.
    235 static void ThreadSafeObserverHarness(int num_threads,
    236                                       bool cross_thread_notifies) {
    237   MessageLoop loop;
    238 
    239   const int kMaxThreads = 15;
    240   num_threads = num_threads > kMaxThreads ? kMaxThreads : num_threads;
    241 
    242   scoped_refptr<ObserverListThreadSafe<Foo> > observer_list(
    243       new ObserverListThreadSafe<Foo>);
    244   Adder a(1);
    245   Adder b(-1);
    246   Adder c(1);
    247   Adder d(-1);
    248 
    249   observer_list->AddObserver(&a);
    250   observer_list->AddObserver(&b);
    251 
    252   AddRemoveThread* threaded_observer[kMaxThreads];
    253   PlatformThreadHandle threads[kMaxThreads];
    254   for (int index = 0; index < num_threads; index++) {
    255     threaded_observer[index] = new AddRemoveThread(observer_list.get(), false);
    256     EXPECT_TRUE(PlatformThread::Create(0,
    257                 threaded_observer[index], &threads[index]));
    258   }
    259 
    260   Time start = Time::Now();
    261   while (true) {
    262     if ((Time::Now() - start).InMilliseconds() > kThreadRunTime)
    263       break;
    264 
    265     observer_list->Notify(&Foo::Observe, 10);
    266 
    267     loop.RunAllPending();
    268   }
    269 
    270   for (int index = 0; index < num_threads; index++) {
    271     threaded_observer[index]->Quit();
    272     PlatformThread::Join(threads[index]);
    273   }
    274 }
    275 
    276 TEST(ObserverListThreadSafeTest, CrossThreadObserver) {
    277   // Use 7 observer threads.  Notifications only come from
    278   // the main thread.
    279   ThreadSafeObserverHarness(7, false);
    280 }
    281 
    282 TEST(ObserverListThreadSafeTest, CrossThreadNotifications) {
    283   // Use 3 observer threads.  Notifications will fire from
    284   // the main thread and all 3 observer threads.
    285   ThreadSafeObserverHarness(3, true);
    286 }
    287 
    288 TEST(ObserverListTest, Existing) {
    289   ObserverList<Foo> observer_list(ObserverList<Foo>::NOTIFY_EXISTING_ONLY);
    290   Adder a(1);
    291   AddInObserve b(&observer_list);
    292 
    293   observer_list.AddObserver(&a);
    294   observer_list.AddObserver(&b);
    295 
    296   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    297 
    298   EXPECT_TRUE(b.added);
    299   // B's adder should not have been notified because it was added during
    300   // notificaiton.
    301   EXPECT_EQ(0, b.adder.total);
    302 
    303   // Notify again to make sure b's adder is notified.
    304   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    305   EXPECT_EQ(1, b.adder.total);
    306 }
    307