Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "base/observer_list.h"
      6 #include "base/observer_list_threadsafe.h"
      7 
      8 #include <vector>
      9 
     10 #include "base/memory/ref_counted.h"
     11 #include "base/message_loop.h"
     12 #include "base/threading/platform_thread.h"
     13 #include "testing/gtest/include/gtest/gtest.h"
     14 
     15 using base::PlatformThread;
     16 using base::Time;
     17 
     18 namespace {
     19 
     20 class Foo {
     21  public:
     22   virtual void Observe(int x) = 0;
     23   virtual ~Foo() {}
     24 };
     25 
     26 class Adder : public Foo {
     27  public:
     28   explicit Adder(int scaler) : total(0), scaler_(scaler) {}
     29   virtual void Observe(int x) {
     30     total += x * scaler_;
     31   }
     32   virtual ~Adder() { }
     33   int total;
     34  private:
     35   int scaler_;
     36 };
     37 
     38 class Disrupter : public Foo {
     39  public:
     40   Disrupter(ObserverList<Foo>* list, Foo* doomed)
     41       : list_(list), doomed_(doomed) { }
     42   virtual ~Disrupter() { }
     43   virtual void Observe(int x) {
     44     list_->RemoveObserver(doomed_);
     45   }
     46  private:
     47   ObserverList<Foo>* list_;
     48   Foo* doomed_;
     49 };
     50 
     51 class ThreadSafeDisrupter : public Foo {
     52  public:
     53   ThreadSafeDisrupter(ObserverListThreadSafe<Foo>* list, Foo* doomed)
     54       : list_(list), doomed_(doomed) { }
     55   virtual ~ThreadSafeDisrupter() { }
     56   virtual void Observe(int x) {
     57     list_->RemoveObserver(doomed_);
     58   }
     59  private:
     60   ObserverListThreadSafe<Foo>* list_;
     61   Foo* doomed_;
     62 };
     63 
     64 class AddInObserve : public Foo {
     65  public:
     66   explicit AddInObserve(ObserverList<Foo>* observer_list)
     67       : added(false),
     68         observer_list(observer_list),
     69         adder(1) {
     70   }
     71   virtual void Observe(int x) {
     72     if (!added) {
     73       added = true;
     74       observer_list->AddObserver(&adder);
     75     }
     76   }
     77 
     78   bool added;
     79   ObserverList<Foo>* observer_list;
     80   Adder adder;
     81 };
     82 
     83 
     84 class ObserverListThreadSafeTest : public testing::Test {
     85 };
     86 
     87 static const int kThreadRunTime = 2000;  // ms to run the multi-threaded test.
     88 
     89 // A thread for use in the ThreadSafeObserver test
     90 // which will add and remove itself from the notification
     91 // list repeatedly.
     92 class AddRemoveThread : public PlatformThread::Delegate,
     93                         public Foo {
     94  public:
     95   AddRemoveThread(ObserverListThreadSafe<Foo>* list, bool notify)
     96       : list_(list),
     97         in_list_(false),
     98         start_(Time::Now()),
     99         count_observes_(0),
    100         count_addtask_(0),
    101         do_notifies_(notify) {
    102     factory_ = new ScopedRunnableMethodFactory<AddRemoveThread>(this);
    103   }
    104 
    105   virtual ~AddRemoveThread() {
    106     delete factory_;
    107   }
    108 
    109   void ThreadMain() {
    110     loop_ = new MessageLoop();  // Fire up a message loop.
    111     loop_->PostTask(
    112         FROM_HERE, factory_->NewRunnableMethod(&AddRemoveThread::AddTask));
    113     loop_->Run();
    114     //LOG(ERROR) << "Loop 0x" << std::hex << loop_ << " done. " <<
    115     //    count_observes_ << ", " << count_addtask_;
    116     delete loop_;
    117     loop_ = reinterpret_cast<MessageLoop*>(0xdeadbeef);
    118     delete this;
    119   }
    120 
    121   // This task just keeps posting to itself in an attempt
    122   // to race with the notifier.
    123   void AddTask() {
    124     count_addtask_++;
    125 
    126     if ((Time::Now() - start_).InMilliseconds() > kThreadRunTime) {
    127       VLOG(1) << "DONE!";
    128       return;
    129     }
    130 
    131     if (!in_list_) {
    132       list_->AddObserver(this);
    133       in_list_ = true;
    134     }
    135 
    136     if (do_notifies_) {
    137       list_->Notify(&Foo::Observe, 10);
    138     }
    139 
    140     loop_->PostDelayedTask(FROM_HERE,
    141       factory_->NewRunnableMethod(&AddRemoveThread::AddTask), 0);
    142   }
    143 
    144   void Quit() {
    145     loop_->PostTask(FROM_HERE, new MessageLoop::QuitTask());
    146   }
    147 
    148   virtual void Observe(int x) {
    149     count_observes_++;
    150 
    151     // If we're getting called after we removed ourselves from
    152     // the list, that is very bad!
    153     DCHECK(in_list_);
    154 
    155     // This callback should fire on the appropriate thread
    156     EXPECT_EQ(loop_, MessageLoop::current());
    157 
    158     list_->RemoveObserver(this);
    159     in_list_ = false;
    160   }
    161 
    162  private:
    163   ObserverListThreadSafe<Foo>* list_;
    164   MessageLoop* loop_;
    165   bool in_list_;        // Are we currently registered for notifications.
    166                         // in_list_ is only used on |this| thread.
    167   Time start_;          // The time we started the test.
    168 
    169   int count_observes_;  // Number of times we observed.
    170   int count_addtask_;   // Number of times thread AddTask was called
    171   bool do_notifies_;    // Whether these threads should do notifications.
    172 
    173   ScopedRunnableMethodFactory<AddRemoveThread>* factory_;
    174 };
    175 
    176 TEST(ObserverListTest, BasicTest) {
    177   ObserverList<Foo> observer_list;
    178   Adder a(1), b(-1), c(1), d(-1);
    179   Disrupter evil(&observer_list, &c);
    180 
    181   observer_list.AddObserver(&a);
    182   observer_list.AddObserver(&b);
    183 
    184   FOR_EACH_OBSERVER(Foo, observer_list, Observe(10));
    185 
    186   observer_list.AddObserver(&evil);
    187   observer_list.AddObserver(&c);
    188   observer_list.AddObserver(&d);
    189 
    190   FOR_EACH_OBSERVER(Foo, observer_list, Observe(10));
    191 
    192   EXPECT_EQ(a.total, 20);
    193   EXPECT_EQ(b.total, -20);
    194   EXPECT_EQ(c.total, 0);
    195   EXPECT_EQ(d.total, -10);
    196 }
    197 
    198 TEST(ObserverListThreadSafeTest, BasicTest) {
    199   MessageLoop loop;
    200 
    201   scoped_refptr<ObserverListThreadSafe<Foo> > observer_list(
    202       new ObserverListThreadSafe<Foo>);
    203   Adder a(1);
    204   Adder b(-1);
    205   Adder c(1);
    206   Adder d(-1);
    207   ThreadSafeDisrupter evil(observer_list.get(), &c);
    208 
    209   observer_list->AddObserver(&a);
    210   observer_list->AddObserver(&b);
    211 
    212   observer_list->Notify(&Foo::Observe, 10);
    213   loop.RunAllPending();
    214 
    215   observer_list->AddObserver(&evil);
    216   observer_list->AddObserver(&c);
    217   observer_list->AddObserver(&d);
    218 
    219   observer_list->Notify(&Foo::Observe, 10);
    220   loop.RunAllPending();
    221 
    222   EXPECT_EQ(a.total, 20);
    223   EXPECT_EQ(b.total, -20);
    224   EXPECT_EQ(c.total, 0);
    225   EXPECT_EQ(d.total, -10);
    226 }
    227 
    228 class FooRemover : public Foo {
    229  public:
    230   explicit FooRemover(ObserverListThreadSafe<Foo>* list) : list_(list) {}
    231   virtual ~FooRemover() {}
    232 
    233   void AddFooToRemove(Foo* foo) {
    234     foos_.push_back(foo);
    235   }
    236 
    237   virtual void Observe(int x) {
    238     std::vector<Foo*> tmp;
    239     tmp.swap(foos_);
    240     for (std::vector<Foo*>::iterator it = tmp.begin();
    241          it != tmp.end(); ++it) {
    242       list_->RemoveObserver(*it);
    243     }
    244   }
    245 
    246  private:
    247   const scoped_refptr<ObserverListThreadSafe<Foo> > list_;
    248   std::vector<Foo*> foos_;
    249 };
    250 
    251 TEST(ObserverListThreadSafeTest, RemoveMultipleObservers) {
    252   MessageLoop loop;
    253   scoped_refptr<ObserverListThreadSafe<Foo> > observer_list(
    254       new ObserverListThreadSafe<Foo>);
    255 
    256   FooRemover a(observer_list);
    257   Adder b(1);
    258 
    259   observer_list->AddObserver(&a);
    260   observer_list->AddObserver(&b);
    261 
    262   a.AddFooToRemove(&a);
    263   a.AddFooToRemove(&b);
    264 
    265   observer_list->Notify(&Foo::Observe, 1);
    266   loop.RunAllPending();
    267 }
    268 
    269 // A test driver for a multi-threaded notification loop.  Runs a number
    270 // of observer threads, each of which constantly adds/removes itself
    271 // from the observer list.  Optionally, if cross_thread_notifies is set
    272 // to true, the observer threads will also trigger notifications to
    273 // all observers.
    274 static void ThreadSafeObserverHarness(int num_threads,
    275                                       bool cross_thread_notifies) {
    276   MessageLoop loop;
    277 
    278   const int kMaxThreads = 15;
    279   num_threads = num_threads > kMaxThreads ? kMaxThreads : num_threads;
    280 
    281   scoped_refptr<ObserverListThreadSafe<Foo> > observer_list(
    282       new ObserverListThreadSafe<Foo>);
    283   Adder a(1);
    284   Adder b(-1);
    285   Adder c(1);
    286   Adder d(-1);
    287 
    288   observer_list->AddObserver(&a);
    289   observer_list->AddObserver(&b);
    290 
    291   AddRemoveThread* threaded_observer[kMaxThreads];
    292   base::PlatformThreadHandle threads[kMaxThreads];
    293   for (int index = 0; index < num_threads; index++) {
    294     threaded_observer[index] = new AddRemoveThread(observer_list.get(), false);
    295     EXPECT_TRUE(PlatformThread::Create(0,
    296                 threaded_observer[index], &threads[index]));
    297   }
    298 
    299   Time start = Time::Now();
    300   while (true) {
    301     if ((Time::Now() - start).InMilliseconds() > kThreadRunTime)
    302       break;
    303 
    304     observer_list->Notify(&Foo::Observe, 10);
    305 
    306     loop.RunAllPending();
    307   }
    308 
    309   for (int index = 0; index < num_threads; index++) {
    310     threaded_observer[index]->Quit();
    311     PlatformThread::Join(threads[index]);
    312   }
    313 }
    314 
    315 TEST(ObserverListThreadSafeTest, CrossThreadObserver) {
    316   // Use 7 observer threads.  Notifications only come from
    317   // the main thread.
    318   ThreadSafeObserverHarness(7, false);
    319 }
    320 
    321 TEST(ObserverListThreadSafeTest, CrossThreadNotifications) {
    322   // Use 3 observer threads.  Notifications will fire from
    323   // the main thread and all 3 observer threads.
    324   ThreadSafeObserverHarness(3, true);
    325 }
    326 
    327 TEST(ObserverListTest, Existing) {
    328   ObserverList<Foo> observer_list(ObserverList<Foo>::NOTIFY_EXISTING_ONLY);
    329   Adder a(1);
    330   AddInObserve b(&observer_list);
    331 
    332   observer_list.AddObserver(&a);
    333   observer_list.AddObserver(&b);
    334 
    335   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    336 
    337   EXPECT_TRUE(b.added);
    338   // B's adder should not have been notified because it was added during
    339   // notificaiton.
    340   EXPECT_EQ(0, b.adder.total);
    341 
    342   // Notify again to make sure b's adder is notified.
    343   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    344   EXPECT_EQ(1, b.adder.total);
    345 }
    346 
    347 class AddInClearObserve : public Foo {
    348  public:
    349   explicit AddInClearObserve(ObserverList<Foo>* list)
    350       : list_(list), added_(false), adder_(1) {}
    351 
    352   virtual void Observe(int /* x */) {
    353     list_->Clear();
    354     list_->AddObserver(&adder_);
    355     added_ = true;
    356   }
    357 
    358   bool added() const { return added_; }
    359   const Adder& adder() const { return adder_; }
    360 
    361  private:
    362   ObserverList<Foo>* const list_;
    363 
    364   bool added_;
    365   Adder adder_;
    366 };
    367 
    368 TEST(ObserverListTest, ClearNotifyAll) {
    369   ObserverList<Foo> observer_list;
    370   AddInClearObserve a(&observer_list);
    371 
    372   observer_list.AddObserver(&a);
    373 
    374   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    375   EXPECT_TRUE(a.added());
    376   EXPECT_EQ(1, a.adder().total)
    377       << "Adder should observe once and have sum of 1.";
    378 }
    379 
    380 TEST(ObserverListTest, ClearNotifyExistingOnly) {
    381   ObserverList<Foo> observer_list(ObserverList<Foo>::NOTIFY_EXISTING_ONLY);
    382   AddInClearObserve a(&observer_list);
    383 
    384   observer_list.AddObserver(&a);
    385 
    386   FOR_EACH_OBSERVER(Foo, observer_list, Observe(1));
    387   EXPECT_TRUE(a.added());
    388   EXPECT_EQ(0, a.adder().total)
    389       << "Adder should not observe, so sum should still be 0.";
    390 }
    391 
    392 }  // namespace
    393