1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "test.h" 16 #include "../profiling/pthread_everywhere.h" 17 18 #include <vector> 19 20 #include "../internal/multi_thread_gemm.h" 21 22 namespace gemmlowp { 23 24 class Thread { 25 public: 26 Thread(BlockingCounter* blocking_counter, int number_of_times_to_decrement) 27 : blocking_counter_(blocking_counter), 28 number_of_times_to_decrement_(number_of_times_to_decrement), 29 finished_(false), 30 made_the_last_decrement_(false) { 31 pthread_create(&thread_, nullptr, ThreadFunc, this); 32 } 33 34 ~Thread() { Join(); } 35 36 bool Join() const { 37 if (!finished_) { 38 pthread_join(thread_, nullptr); 39 } 40 return made_the_last_decrement_; 41 } 42 43 private: 44 Thread(const Thread& other) = delete; 45 46 void ThreadFunc() { 47 for (int i = 0; i < number_of_times_to_decrement_; i++) { 48 Check(!made_the_last_decrement_); 49 made_the_last_decrement_ = blocking_counter_->DecrementCount(); 50 } 51 finished_ = true; 52 } 53 54 static void* ThreadFunc(void* ptr) { 55 static_cast<Thread*>(ptr)->ThreadFunc(); 56 return nullptr; 57 } 58 59 BlockingCounter* const blocking_counter_; 60 const int number_of_times_to_decrement_; 61 pthread_t thread_; 62 bool finished_; 63 bool made_the_last_decrement_; 64 }; 65 66 void test_blocking_counter(BlockingCounter* blocking_counter, int num_threads, 67 int num_decrements_per_thread, 68 int num_decrements_to_wait_for) { 69 std::vector<Thread*> threads; 70 blocking_counter->Reset(num_decrements_to_wait_for); 71 for (int i = 0; i < num_threads; i++) { 72 threads.push_back(new Thread(blocking_counter, num_decrements_per_thread)); 73 } 74 blocking_counter->Wait(); 75 76 int num_threads_that_made_the_last_decrement = 0; 77 for (int i = 0; i < num_threads; i++) { 78 if (threads[i]->Join()) { 79 num_threads_that_made_the_last_decrement++; 80 } 81 delete threads[i]; 82 } 83 Check(num_threads_that_made_the_last_decrement == 1); 84 } 85 86 void test_blocking_counter() { 87 BlockingCounter* blocking_counter = new BlockingCounter; 88 89 // repeating the entire test sequence ensures that we test 90 // non-monotonic changes. 91 for (int repeat = 1; repeat <= 2; repeat++) { 92 for (int num_threads = 1; num_threads <= 16; num_threads++) { 93 for (int num_decrements_per_thread = 1; 94 num_decrements_per_thread <= 64 * 1024; 95 num_decrements_per_thread *= 4) { 96 test_blocking_counter(blocking_counter, num_threads, 97 num_decrements_per_thread, 98 num_threads * num_decrements_per_thread); 99 } 100 } 101 } 102 delete blocking_counter; 103 } 104 105 } // end namespace gemmlowp 106 107 int main() { gemmlowp::test_blocking_counter(); } 108