1 // Copyright 2015 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "test.h" 16 17 #include <pthread.h> 18 #include <vector> 19 20 #include "../internal/multi_thread_gemm.h" 21 22 namespace gemmlowp { 23 24 class Thread { 25 public: 26 Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) 27 : blocking_counter_(blocking_counter), 28 number_of_times_to_decrement_(number_of_times_to_decrement), 29 made_the_last_decrement_(false) { 30 pthread_create(&thread_, nullptr, ThreadFunc, this); 31 } 32 33 ~Thread() { Join(); } 34 35 bool Join() const { 36 pthread_join(thread_, nullptr); 37 return made_the_last_decrement_; 38 } 39 40 private: 41 Thread(const Thread& other) = delete; 42 43 void ThreadFunc() { 44 for (int i = 0; i < number_of_times_to_decrement_; i++) { 45 Check(!made_the_last_decrement_); 46 made_the_last_decrement_ = blocking_counter_->DecrementCount(); 47 } 48 } 49 50 static void* ThreadFunc(void* ptr) { 51 static_cast<Thread*>(ptr)->ThreadFunc(); 52 return nullptr; 53 } 54 55 BlockingCounter* const blocking_counter_; 56 const int number_of_times_to_decrement_; 57 pthread_t thread_; 58 bool made_the_last_decrement_; 59 }; 60 61 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, 62 int num_decrements_per_thread, 63 int num_decrements_to_wait_for) { 64 std::vector<Thread*> threads; 65 blocking_counter->Reset(num_decrements_to_wait_for); 66 for (int i = 0; i < num_threads; i++) { 67 threads.push_back(new Thread(blocking_counter, num_decrements_per_thread)); 68 } 69 blocking_counter->Wait(); 70 71 int num_threads_that_made_the_last_decrement = 0; 72 for (int i = 0; i < num_threads; i++) { 73 if (threads[i]->Join()) { 74 num_threads_that_made_the_last_decrement++; 75 } 76 delete threads[i]; 77 } 78 Check(num_threads_that_made_the_last_decrement == 1); 79 } 80 81 void test_blocking_counter() { 82 BlockingCounter* blocking_counter = new BlockingCounter; 83 84 // repeating the entire test sequence ensures that we test 85 // non-monotonic changes. 86 for (int repeat = 1; repeat <= 2; repeat++) { 87 for (int num_threads = 1; num_threads <= 16; num_threads++) { 88 for (int num_decrements_per_thread = 1; 89 num_decrements_per_thread <= 64 * 1024; 90 num_decrements_per_thread *= 4) { 91 test_blocking_counter(blocking_counter, num_threads, 92 num_decrements_per_thread, 93 num_threads * num_decrements_per_thread); 94 } 95 } 96 } 97 delete blocking_counter; 98 } 99 100 } // end namespace gemmlowp 101 102 int main() { gemmlowp::test_blocking_counter(); } 103