Home | History | Annotate | Download | only in default
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/platform/test_benchmark.h"
     17 
     18 #include <cstdio>
     19 #include <cstdlib>
     20 
     21 #include <algorithm>
     22 #include <vector>
     23 #include "tensorflow/core/lib/strings/str_util.h"
     24 #include "tensorflow/core/platform/env.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 #include "tensorflow/core/util/reporter.h"
     27 
     28 namespace tensorflow {
     29 namespace testing {
     30 
     31 static std::vector<Benchmark*>* all_benchmarks = nullptr;
     32 static std::string label;
     33 static int64 bytes_processed;
     34 static int64 items_processed;
     35 static int64 accum_time = 0;
     36 static int64 start_time = 0;
     37 static Env* env;
     38 
     39 Benchmark::Benchmark(const char* name, void (*fn)(int))
     40     : name_(name), num_args_(0), fn0_(fn) {
     41   args_.push_back(std::make_pair(-1, -1));
     42   Register();
     43 }
     44 
     45 Benchmark::Benchmark(const char* name, void (*fn)(int, int))
     46     : name_(name), num_args_(1), fn1_(fn) {
     47   Register();
     48 }
     49 
     50 Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
     51     : name_(name), num_args_(2), fn2_(fn) {
     52   Register();
     53 }
     54 
     55 Benchmark* Benchmark::Arg(int x) {
     56   CHECK_EQ(num_args_, 1);
     57   args_.push_back(std::make_pair(x, -1));
     58   return this;
     59 }
     60 
     61 Benchmark* Benchmark::ArgPair(int x, int y) {
     62   CHECK_EQ(num_args_, 2);
     63   args_.push_back(std::make_pair(x, y));
     64   return this;
     65 }
     66 
     67 namespace {
     68 
     69 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
     70   CHECK_GE(lo, 0);
     71   CHECK_GE(hi, lo);
     72 
     73   // Add "lo"
     74   dst->push_back(lo);
     75 
     76   // Now space out the benchmarks in multiples of "mult"
     77   for (int32 i = 1; i < kint32max / mult; i *= mult) {
     78     if (i >= hi) break;
     79     if (i > lo) {
     80       dst->push_back(i);
     81     }
     82   }
     83   // Add "hi" (if different from "lo")
     84   if (hi != lo) {
     85     dst->push_back(hi);
     86   }
     87 }
     88 
     89 }  // namespace
     90 
     91 Benchmark* Benchmark::Range(int lo, int hi) {
     92   std::vector<int> args;
     93   AddRange(&args, lo, hi, 8);
     94   for (int arg : args) {
     95     Arg(arg);
     96   }
     97   return this;
     98 }
     99 
    100 Benchmark* Benchmark::RangePair(int lo1, int hi1, int lo2, int hi2) {
    101   std::vector<int> args1;
    102   std::vector<int> args2;
    103   AddRange(&args1, lo1, hi1, 8);
    104   AddRange(&args2, lo2, hi2, 8);
    105   for (int arg1 : args1) {
    106     for (int arg2 : args2) {
    107       ArgPair(arg1, arg2);
    108     }
    109   }
    110   return this;
    111 }
    112 
    113 void Benchmark::Run(const char* pattern) {
    114   if (!all_benchmarks) return;
    115 
    116   // Converts "all" into the wildcard '.*'.  Currently pattern isn't
    117   // specified by clients, but we keep this here to match the internal
    118   // Google implementation, should we ever enable user-specified
    119   // pattern specification.
    120   if (StringPiece(pattern) == "all") {
    121     pattern = ".*";
    122   }
    123 
    124   // Compute name width.
    125   int width = 10;
    126   string name;
    127   for (auto b : *all_benchmarks) {
    128     name = b->name_;
    129     for (auto arg : b->args_) {
    130       name.resize(b->name_.size());
    131       if (arg.first >= 0) {
    132         strings::StrAppend(&name, "/", arg.first);
    133         if (arg.second >= 0) {
    134           strings::StrAppend(&name, "/", arg.second);
    135         }
    136       }
    137 
    138       // TODO(vrv): Check against 'pattern' using a regex before
    139       // computing the width, if we start allowing clients to pass in
    140       // a custom pattern.
    141       width = std::max<int>(width, name.size());
    142     }
    143   }
    144 
    145   printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
    146   printf("%s\n", string(width + 22, '-').c_str());
    147   for (auto b : *all_benchmarks) {
    148     name = b->name_;
    149     for (auto arg : b->args_) {
    150       name.resize(b->name_.size());
    151       if (arg.first >= 0) {
    152         strings::StrAppend(&name, "/", arg.first);
    153         if (arg.second >= 0) {
    154           strings::StrAppend(&name, "/", arg.second);
    155         }
    156       }
    157 
    158       // TODO(vrv): Match 'name' against 'pattern' using a regex
    159       // before continuing, if we start allowing clients to pass in a
    160       // custom pattern.
    161 
    162       int iters;
    163       double seconds;
    164       b->Run(arg.first, arg.second, &iters, &seconds);
    165 
    166       char buf[100];
    167       std::string full_label = label;
    168       if (bytes_processed > 0) {
    169         snprintf(buf, sizeof(buf), " %.1fMB/s",
    170                  (bytes_processed * 1e-6) / seconds);
    171         full_label += buf;
    172       }
    173       if (items_processed > 0) {
    174         snprintf(buf, sizeof(buf), " %.1fM items/s",
    175                  (items_processed * 1e-6) / seconds);
    176         full_label += buf;
    177       }
    178       printf("%-*s %10.0f %10d\t%s\n", width, name.c_str(),
    179              seconds * 1e9 / iters, iters, full_label.c_str());
    180 
    181       TestReporter reporter(name);
    182       Status s = reporter.Initialize();
    183       if (!s.ok()) {
    184         LOG(ERROR) << s.ToString();
    185         exit(EXIT_FAILURE);
    186       }
    187       s = reporter.Benchmark(iters, 0.0, seconds,
    188                              items_processed * 1e-6 / seconds);
    189       if (!s.ok()) {
    190         LOG(ERROR) << s.ToString();
    191         exit(EXIT_FAILURE);
    192       }
    193       s = reporter.Close();
    194       if (!s.ok()) {
    195         LOG(ERROR) << s.ToString();
    196         exit(EXIT_FAILURE);
    197       }
    198     }
    199   }
    200 }
    201 
    202 void Benchmark::Register() {
    203   if (!all_benchmarks) all_benchmarks = new std::vector<Benchmark*>;
    204   all_benchmarks->push_back(this);
    205 }
    206 
    207 void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
    208   env = Env::Default();
    209   static const int64 kMinIters = 100;
    210   static const int64 kMaxIters = 1000000000;
    211   static const double kMinTime = 0.5;
    212   int64 iters = kMinIters;
    213   while (true) {
    214     accum_time = 0;
    215     start_time = env->NowMicros();
    216     bytes_processed = -1;
    217     items_processed = -1;
    218     label.clear();
    219     if (fn0_) {
    220       (*fn0_)(iters);
    221     } else if (fn1_) {
    222       (*fn1_)(iters, arg1);
    223     } else {
    224       (*fn2_)(iters, arg1, arg2);
    225     }
    226     StopTiming();
    227     const double seconds = accum_time * 1e-6;
    228     if (seconds >= kMinTime || iters >= kMaxIters) {
    229       *run_count = iters;
    230       *run_seconds = seconds;
    231       return;
    232     }
    233 
    234     // Update number of iterations.  Overshoot by 40% in an attempt
    235     // to succeed the next time.
    236     double multiplier = 1.4 * kMinTime / std::max(seconds, 1e-9);
    237     multiplier = std::min(10.0, multiplier);
    238     if (multiplier <= 1.0) multiplier *= 2.0;
    239     iters = std::max<int64>(multiplier * iters, iters + 1);
    240     iters = std::min(iters, kMaxIters);
    241   }
    242 }
    243 
    244 // TODO(vrv): Add support for running a subset of benchmarks by having
    245 // RunBenchmarks take in a spec (and maybe other options such as
    246 // benchmark_min_time, etc).
    247 void RunBenchmarks() { Benchmark::Run("all"); }
    248 void SetLabel(const std::string& l) { label = l; }
    249 void BytesProcessed(int64 n) { bytes_processed = n; }
    250 void ItemsProcessed(int64 n) { items_processed = n; }
    251 void StartTiming() {
    252   if (start_time == 0) start_time = env->NowMicros();
    253 }
    254 void StopTiming() {
    255   if (start_time != 0) {
    256     accum_time += (env->NowMicros() - start_time);
    257     start_time = 0;
    258   }
    259 }
    260 void UseRealTime() {}
    261 
    262 }  // namespace testing
    263 }  // namespace tensorflow
    264