Home | History | Annotate | Download | only in test
      1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "test.h"
     16 #include "../profiling/pthread_everywhere.h"
     17 
     18 #include <vector>
     19 
     20 #include "../internal/multi_thread_gemm.h"
     21 
     22 namespace gemmlowp {
     23 
     24 class Thread {
     25  public:
     26   Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
     27       : blocking_counter_(blocking_counter),
     28         number_of_times_to_decrement_(number_of_times_to_decrement),
     29         finished_(false),
     30         made_the_last_decrement_(false) {
     31     pthread_create(&thread_, nullptr, ThreadFunc, this);
     32   }
     33 
     34   ~Thread() { Join(); }
     35 
     36   bool Join() const {
     37     if (!finished_) {
     38       pthread_join(thread_, nullptr);
     39     }
     40     return made_the_last_decrement_;
     41   }
     42 
     43  private:
     44   Thread(const Thread& other) = delete;
     45 
     46   void ThreadFunc() {
     47     for (int i = 0; i < number_of_times_to_decrement_; i++) {
     48       Check(!made_the_last_decrement_);
     49       made_the_last_decrement_ = blocking_counter_->DecrementCount();
     50     }
     51     finished_ = true;
     52   }
     53 
     54   static void* ThreadFunc(void* ptr) {
     55     static_cast<Thread*>(ptr)->ThreadFunc();
     56     return nullptr;
     57   }
     58 
     59   BlockingCounter* const blocking_counter_;
     60   const int number_of_times_to_decrement_;
     61   pthread_t thread_;
     62   bool finished_;
     63   bool made_the_last_decrement_;
     64 };
     65 
     66 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
     67                            int num_decrements_per_thread,
     68                            int num_decrements_to_wait_for) {
     69   std::vector<Thread*> threads;
     70   blocking_counter->Reset(num_decrements_to_wait_for);
     71   for (int i = 0; i < num_threads; i++) {
     72     threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
     73   }
     74   blocking_counter->Wait();
     75 
     76   int num_threads_that_made_the_last_decrement = 0;
     77   for (int i = 0; i < num_threads; i++) {
     78     if (threads[i]->Join()) {
     79       num_threads_that_made_the_last_decrement++;
     80     }
     81     delete threads[i];
     82   }
     83   Check(num_threads_that_made_the_last_decrement == 1);
     84 }
     85 
     86 void test_blocking_counter() {
     87   BlockingCounter* blocking_counter = new BlockingCounter;
     88 
     89   // repeating the entire test sequence ensures that we test
     90   // non-monotonic changes.
     91   for (int repeat = 1; repeat <= 2; repeat++) {
     92     for (int num_threads = 1; num_threads <= 16; num_threads++) {
     93       for (int num_decrements_per_thread = 1;
     94            num_decrements_per_thread <= 64 * 1024;
     95            num_decrements_per_thread *= 4) {
     96         test_blocking_counter(blocking_counter, num_threads,
     97                               num_decrements_per_thread,
     98                               num_threads * num_decrements_per_thread);
     99       }
    100     }
    101   }
    102   delete blocking_counter;
    103 }
    104 
    105 }  // end namespace gemmlowp
    106 
    107 int main() { gemmlowp::test_blocking_counter(); }
    108