1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/platform/test_benchmark.h" 17 18 #include <cstdio> 19 #include <cstdlib> 20 21 #include <algorithm> 22 #include <vector> 23 #include "tensorflow/core/lib/strings/str_util.h" 24 #include "tensorflow/core/platform/env.h" 25 #include "tensorflow/core/platform/logging.h" 26 #include "tensorflow/core/util/reporter.h" 27 28 namespace tensorflow { 29 namespace testing { 30 31 static std::vector<Benchmark*>* all_benchmarks = nullptr; 32 static std::string label; 33 static int64 bytes_processed; 34 static int64 items_processed; 35 static int64 accum_time = 0; 36 static int64 start_time = 0; 37 static Env* env; 38 39 Benchmark::Benchmark(const char* name, void (*fn)(int)) 40 : name_(name), num_args_(0), fn0_(fn) { 41 args_.push_back(std::make_pair(-1, -1)); 42 Register(); 43 } 44 45 Benchmark::Benchmark(const char* name, void (*fn)(int, int)) 46 : name_(name), num_args_(1), fn1_(fn) { 47 Register(); 48 } 49 50 Benchmark::Benchmark(const char* name, void (*fn)(int, int, int)) 51 : name_(name), num_args_(2), fn2_(fn) { 52 Register(); 53 } 54 55 Benchmark* Benchmark::Arg(int x) { 56 CHECK_EQ(num_args_, 1); 57 args_.push_back(std::make_pair(x, -1)); 58 return this; 59 } 60 61 Benchmark* Benchmark::ArgPair(int x, int y) { 62 CHECK_EQ(num_args_, 2); 63 args_.push_back(std::make_pair(x, y)); 64 return this; 65 } 66 67 namespace { 68 69 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) { 70 CHECK_GE(lo, 0); 71 CHECK_GE(hi, lo); 72 73 // Add "lo" 74 dst->push_back(lo); 75 76 // Now space out the benchmarks in multiples of "mult" 77 for (int32 i = 1; i < kint32max / mult; i *= mult) { 78 if (i >= hi) break; 79 if (i > lo) { 80 dst->push_back(i); 81 } 82 } 83 // Add "hi" (if different from "lo") 84 if (hi != lo) { 85 dst->push_back(hi); 86 } 87 } 88 89 } // namespace 90 91 Benchmark* Benchmark::Range(int lo, int hi) { 92 std::vector<int> args; 93 AddRange(&args, lo, hi, 8); 94 for (int arg : args) { 95 Arg(arg); 96 } 97 return this; 98 } 99 100 Benchmark* Benchmark::RangePair(int lo1, int hi1, int lo2, int hi2) { 101 std::vector<int> args1; 102 std::vector<int> args2; 103 AddRange(&args1, lo1, hi1, 8); 104 AddRange(&args2, lo2, hi2, 8); 105 for (int arg1 : args1) { 106 for (int arg2 : args2) { 107 ArgPair(arg1, arg2); 108 } 109 } 110 return this; 111 } 112 113 void Benchmark::Run(const char* pattern) { 114 if (!all_benchmarks) return; 115 116 // Converts "all" into the wildcard '.*'. Currently pattern isn't 117 // specified by clients, but we keep this here to match the internal 118 // Google implementation, should we ever enable user-specified 119 // pattern specification. 120 if (StringPiece(pattern) == "all") { 121 pattern = ".*"; 122 } 123 124 // Compute name width. 125 int width = 10; 126 string name; 127 for (auto b : *all_benchmarks) { 128 name = b->name_; 129 for (auto arg : b->args_) { 130 name.resize(b->name_.size()); 131 if (arg.first >= 0) { 132 strings::StrAppend(&name, "/", arg.first); 133 if (arg.second >= 0) { 134 strings::StrAppend(&name, "/", arg.second); 135 } 136 } 137 138 // TODO(vrv): Check against 'pattern' using a regex before 139 // computing the width, if we start allowing clients to pass in 140 // a custom pattern. 141 width = std::max<int>(width, name.size()); 142 } 143 } 144 145 printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations"); 146 printf("%s\n", string(width + 22, '-').c_str()); 147 for (auto b : *all_benchmarks) { 148 name = b->name_; 149 for (auto arg : b->args_) { 150 name.resize(b->name_.size()); 151 if (arg.first >= 0) { 152 strings::StrAppend(&name, "/", arg.first); 153 if (arg.second >= 0) { 154 strings::StrAppend(&name, "/", arg.second); 155 } 156 } 157 158 // TODO(vrv): Match 'name' against 'pattern' using a regex 159 // before continuing, if we start allowing clients to pass in a 160 // custom pattern. 161 162 int iters; 163 double seconds; 164 b->Run(arg.first, arg.second, &iters, &seconds); 165 166 char buf[100]; 167 std::string full_label = label; 168 if (bytes_processed > 0) { 169 snprintf(buf, sizeof(buf), " %.1fMB/s", 170 (bytes_processed * 1e-6) / seconds); 171 full_label += buf; 172 } 173 if (items_processed > 0) { 174 snprintf(buf, sizeof(buf), " %.1fM items/s", 175 (items_processed * 1e-6) / seconds); 176 full_label += buf; 177 } 178 printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(), 179 seconds * 1e9 / iters, iters, full_label.c_str()); 180 181 TestReporter reporter(name); 182 Status s = reporter.Initialize(); 183 if (!s.ok()) { 184 LOG(ERROR) << s.ToString(); 185 exit(EXIT_FAILURE); 186 } 187 s = reporter.Benchmark(iters, 0.0, seconds, 188 items_processed * 1e-6 / seconds); 189 if (!s.ok()) { 190 LOG(ERROR) << s.ToString(); 191 exit(EXIT_FAILURE); 192 } 193 s = reporter.Close(); 194 if (!s.ok()) { 195 LOG(ERROR) << s.ToString(); 196 exit(EXIT_FAILURE); 197 } 198 } 199 } 200 } 201 202 void Benchmark::Register() { 203 if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>; 204 all_benchmarks->push_back(this); 205 } 206 207 void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) { 208 env = Env::Default(); 209 static const int64 kMinIters = 100; 210 static const int64 kMaxIters = 1000000000; 211 static const double kMinTime = 0.5; 212 int64 iters = kMinIters; 213 while (true) { 214 accum_time = 0; 215 start_time = env->NowMicros(); 216 bytes_processed = -1; 217 items_processed = -1; 218 label.clear(); 219 if (fn0_) { 220 (*fn0_)(iters); 221 } else if (fn1_) { 222 (*fn1_)(iters, arg1); 223 } else { 224 (*fn2_)(iters, arg1, arg2); 225 } 226 StopTiming(); 227 const double seconds = accum_time * 1e-6; 228 if (seconds >= kMinTime || iters >= kMaxIters) { 229 *run_count = iters; 230 *run_seconds = seconds; 231 return; 232 } 233 234 // Update number of iterations. Overshoot by 40% in an attempt 235 // to succeed the next time. 236 double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9); 237 multiplier = std::min(10.0, multiplier); 238 if (multiplier <= 1.0) multiplier *= 2.0; 239 iters = std::max<int64>(multiplier * iters, iters + 1); 240 iters = std::min(iters, kMaxIters); 241 } 242 } 243 244 // TODO(vrv): Add support for running a subset of benchmarks by having 245 // RunBenchmarks take in a spec (and maybe other options such as 246 // benchmark_min_time, etc). 247 void RunBenchmarks() { Benchmark::Run("all"); } 248 void SetLabel(const std::string& l) { label = l; } 249 void BytesProcessed(int64 n) { bytes_processed = n; } 250 void ItemsProcessed(int64 n) { items_processed = n; } 251 void StartTiming() { 252 if (start_time == 0) start_time = env->NowMicros(); 253 } 254 void StopTiming() { 255 if (start_time != 0) { 256 accum_time += (env->NowMicros() - start_time); 257 start_time = 0; 258 } 259 } 260 void UseRealTime() {} 261 262 } // namespace testing 263 } // namespace tensorflow 264