Home | History | Annotate | Download | only in test
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "test.h"
     16 
     17 #include <pthread.h>
     18 #include <vector>
     19 
     20 #include "../internal/multi_thread_gemm.h"
     21 
     22 namespace gemmlowp {
     23 
     24 class Thread {
     25  public:
     26   Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement)
     27       : blocking_counter_(blocking_counter),
     28         number_of_times_to_decrement_(number_of_times_to_decrement),
     29         made_the_last_decrement_(false) {
     30     pthread_create(&thread_, nullptr, ThreadFunc, this);
     31   }
     32 
     33   ~Thread() { Join(); }
     34 
     35   bool Join() const {
     36     pthread_join(thread_, nullptr);
     37     return made_the_last_decrement_;
     38   }
     39 
     40  private:
     41   Thread(const Thread& other) = delete;
     42 
     43   void ThreadFunc() {
     44     for (int i = 0; i < number_of_times_to_decrement_; i++) {
     45       Check(!made_the_last_decrement_);
     46       made_the_last_decrement_ = blocking_counter_->DecrementCount();
     47     }
     48   }
     49 
     50   static void* ThreadFunc(void* ptr) {
     51     static_cast<Thread*>(ptr)->ThreadFunc();
     52     return nullptr;
     53   }
     54 
     55   BlockingCounter* const blocking_counter_;
     56   const int number_of_times_to_decrement_;
     57   pthread_t thread_;
     58   bool made_the_last_decrement_;
     59 };
     60 
     61 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads,
     62                            int num_decrements_per_thread,
     63                            int num_decrements_to_wait_for) {
     64   std::vector<Thread*> threads;
     65   blocking_counter->Reset(num_decrements_to_wait_for);
     66   for (int i = 0; i < num_threads; i++) {
     67     threads.push_back(new Thread(blocking_counter, num_decrements_per_thread));
     68   }
     69   blocking_counter->Wait();
     70 
     71   int num_threads_that_made_the_last_decrement = 0;
     72   for (int i = 0; i < num_threads; i++) {
     73     if (threads[i]->Join()) {
     74       num_threads_that_made_the_last_decrement++;
     75     }
     76     delete threads[i];
     77   }
     78   Check(num_threads_that_made_the_last_decrement == 1);
     79 }
     80 
     81 void test_blocking_counter() {
     82   BlockingCounter* blocking_counter = new BlockingCounter;
     83 
     84   // repeating the entire test sequence ensures that we test
     85   // non-monotonic changes.
     86   for (int repeat = 1; repeat <= 2; repeat++) {
     87     for (int num_threads = 1; num_threads <= 16; num_threads++) {
     88       for (int num_decrements_per_thread = 1;
     89            num_decrements_per_thread <= 64 * 1024;
     90            num_decrements_per_thread *= 4) {
     91         test_blocking_counter(blocking_counter, num_threads,
     92                               num_decrements_per_thread,
     93                               num_threads * num_decrements_per_thread);
     94       }
     95     }
     96   }
     97   delete blocking_counter;
     98 }
     99 
    100 }  // end namespace gemmlowp
    101 
    102 int main() { gemmlowp::test_blocking_counter(); }
    103