Home | History | Annotate | Download | only in test
      1 // Copyright 2015 Google Inc. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
#include <unistd.h>
#ifdef __APPLE__
#include <sys/time.h>
#endif

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "test.h"

#if defined(__arm__) && !defined(GEMMLOWP_NEON)
#warning "Building without NEON support on ARM, check your compiler setup!"
#endif
     34 
     35 double time() {
     36 #ifdef __APPLE__
     37   timeval t;
     38   gettimeofday(&t, nullptr);
     39   return t.tv_sec + 1e-6 * t.tv_usec;
     40 #else
     41   timespec t;
     42   clock_gettime(CLOCK_REALTIME, &t);
     43   return t.tv_sec + 1e-9 * t.tv_nsec;
     44 #endif
     45 }
     46 
     47 const std::int32_t MIN_WORKING_SET_SIZE = 2 * 1024 * 1024;
     48 const double MIN_OPS = 1000.0 * 1000000.0;
     49 
     50 struct WorkingSet {
     51   WorkingSet() : lhs(nullptr), rhs(nullptr), result(nullptr) {}
     52 
     53   void init(std::int32_t n, std::int32_t m, std::int32_t k) {
     54     lhs = new std::uint8_t[n * k];
     55     rhs = new std::uint8_t[k * m];
     56     result = new std::uint8_t[m * n];
     57   }
     58 
     59   std::uint8_t* lhs;
     60   std::uint8_t* rhs;
     61   std::uint8_t* result;
     62 };
     63 
     64 struct Shape {
     65   std::int32_t n;
     66   std::int32_t m;
     67   std::int32_t k;
     68 
     69   std::int32_t repetitions;
     70   std::int32_t current_set;
     71   std::vector<WorkingSet> working_sets;
     72 
     73   Shape(std::int32_t n, std::int32_t m, std::int32_t k)
     74       : n(n), m(m), k(k), repetitions(1), current_set(0), working_sets() {}
     75 
     76   void init() {
     77     const std::int32_t size = n * k + k * m + n * m;
     78     const std::int32_t count = MIN_WORKING_SET_SIZE / size + 1;
     79     const double ops = static_cast<double>(n) * static_cast<double>(m) *
     80                        static_cast<double>(k);
     81     for (int i = 0; i < count; ++i) {
     82       working_sets.push_back(WorkingSet());
     83       working_sets.back().init(n, m, k);
     84     }
     85     current_set = 0;
     86     repetitions = MIN_OPS / ops + 20;
     87   }
     88 
     89   WorkingSet& working_set() { return working_sets[current_set]; }
     90 
     91   void next_working_set() {
     92     current_set = (current_set + 1) % working_sets.size();
     93   }
     94 };
     95 
     96 double run_gemm(std::int32_t n, std::int32_t m, std::int32_t k,
     97                 std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* result) {
     98   gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
     99       true, false, false, m, n, k, rhs, -100, k, lhs, -100, k, result, 10000,
    100       10, 3, m, gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
    101   return static_cast<double>(n * m * k * 2);
    102 }
    103 
    104 double run_gemms(std::vector<Shape>* shapes) {
    105   double ops = 0.0;
    106   for (auto& shape : *shapes) {
    107     ops += run_gemm(shape.n, shape.m, shape.k, shape.working_set().lhs,
    108                     shape.working_set().rhs, shape.working_set().result);
    109   }
    110   return ops;
    111 }
    112 
    113 void print_summary(std::vector<double>* times, bool full) {
    114   std::sort(times->begin(), times->end());
    115 
    116   double sum_times = 0;
    117   double sum_times_trimmed = 0;
    118   int count_times_trimmed = 0;
    119   const float trim_ratio = 0.25;
    120   const size_t count_trimmed = times->size() * trim_ratio;
    121   double sum_times_best = 0;
    122   int count_times_best = 0;
    123   const float best_ratio = 0.1;
    124   const size_t count_best = times->size() * best_ratio;
    125 
    126   for (size_t i = 0; i < times->size(); i++) {
    127     sum_times += (*times)[i];
    128     if (i >= count_trimmed && i < times->size() - count_trimmed) {
    129       sum_times_trimmed += (*times)[i];
    130       count_times_trimmed++;
    131     }
    132     if (i < count_best) {
    133       sum_times_best += (*times)[i];
    134       count_times_best++;
    135     }
    136   }
    137 
    138   const double min_latency = times->front();
    139   const double max_latency = times->back();
    140   const double mean_latency = sum_times / times->size();
    141   const double trimmed_mean_latency = sum_times_trimmed / count_times_trimmed;
    142   const double best_mean_latency = sum_times_best / count_times_best;
    143 
    144   if (full) {
    145     std::cout << "Graph latency (over " << times->size()
    146               << " iterations):" << std::endl;
    147     std::cout << "  Best:             " << min_latency << "s" << std::endl;
    148     std::cout << "  Worst:            " << max_latency << "s" << std::endl;
    149     std::cout << "  Mean:             " << mean_latency << "s" << std::endl;
    150     std::cout << "  " << 100 * trim_ratio
    151               << "% trimmed mean: " << trimmed_mean_latency << "s" << std::endl;
    152     std::cout << "  Mean of " << 100 * best_ratio
    153               << "% best: " << best_mean_latency << "s" << std::endl;
    154   } else {
    155     std::cout << (mean_latency * 1000.0) << std::endl;
    156   }
    157 }
    158 
    159 void time_all(std::vector<Shape>* shapes, std::int32_t repetitions,
    160               double max_time) {
    161   std::vector<double> times;
    162   double ops = 0.0;
    163   double sum_time = 0.0;
    164 
    165   while (sum_time < max_time) {
    166     double start = time();
    167 
    168     for (int i = 0; i < repetitions; ++i) {
    169       ops += run_gemms(shapes);
    170     }
    171     double delta_time = (time() - start);
    172     times.push_back(delta_time / repetitions);
    173     sum_time += delta_time;
    174   }
    175 
    176   print_summary(&times, true);
    177 }
    178 
    179 void time_one(Shape* shape, double max_time) {
    180   std::vector<double> times;
    181   double ops = 0.0;
    182   double sum_time = 0.0;
    183 
    184   std::cout << std::setprecision(6) << std::fixed << shape->n << ", "
    185             << shape->m << ", " << shape->k << ", " << std::flush;
    186 
    187   while (sum_time < max_time) {
    188     double start = time();
    189 
    190     for (int i = 0; i < shape->repetitions; ++i) {
    191       ops += run_gemm(shape->n, shape->m, shape->k, shape->working_set().lhs,
    192                       shape->working_set().rhs, shape->working_set().result);
    193       shape->next_working_set();
    194     }
    195     double delta_time = (time() - start);
    196     times.push_back(delta_time / shape->repetitions);
    197     sum_time += delta_time;
    198   }
    199 
    200   print_summary(&times, false);
    201 }
    202 
    203 int main() {
    204   std::vector<Shape> googlenet_gemms;
    205   googlenet_gemms.push_back(Shape(12544, 64, 147));
    206   googlenet_gemms.push_back(Shape(3136, 64, 64));
    207   googlenet_gemms.push_back(Shape(3136, 192, 576));
    208   googlenet_gemms.push_back(Shape(784, 64, 192));
    209   googlenet_gemms.push_back(Shape(784, 96, 192));
    210   googlenet_gemms.push_back(Shape(784, 128, 864));
    211   googlenet_gemms.push_back(Shape(784, 16, 192));
    212   googlenet_gemms.push_back(Shape(784, 32, 400));
    213   googlenet_gemms.push_back(Shape(784, 32, 192));
    214   googlenet_gemms.push_back(Shape(784, 128, 256));
    215   googlenet_gemms.push_back(Shape(784, 128, 256));
    216   googlenet_gemms.push_back(Shape(784, 192, 1152));
    217   googlenet_gemms.push_back(Shape(784, 32, 256));
    218   googlenet_gemms.push_back(Shape(784, 96, 800));
    219   googlenet_gemms.push_back(Shape(784, 64, 256));
    220   googlenet_gemms.push_back(Shape(196, 192, 480));
    221   googlenet_gemms.push_back(Shape(196, 96, 480));
    222   googlenet_gemms.push_back(Shape(196, 204, 864));
    223   googlenet_gemms.push_back(Shape(196, 16, 480));
    224   googlenet_gemms.push_back(Shape(196, 48, 400));
    225   googlenet_gemms.push_back(Shape(196, 64, 480));
    226   googlenet_gemms.push_back(Shape(196, 160, 508));
    227   googlenet_gemms.push_back(Shape(196, 112, 508));
    228   googlenet_gemms.push_back(Shape(196, 224, 1008));
    229   googlenet_gemms.push_back(Shape(196, 24, 508));
    230   googlenet_gemms.push_back(Shape(196, 64, 600));
    231   googlenet_gemms.push_back(Shape(196, 64, 508));
    232   googlenet_gemms.push_back(Shape(196, 128, 512));
    233   googlenet_gemms.push_back(Shape(196, 128, 512));
    234   googlenet_gemms.push_back(Shape(196, 256, 1152));
    235   googlenet_gemms.push_back(Shape(196, 24, 512));
    236   googlenet_gemms.push_back(Shape(196, 64, 600));
    237   googlenet_gemms.push_back(Shape(196, 64, 512));
    238   googlenet_gemms.push_back(Shape(196, 112, 512));
    239   googlenet_gemms.push_back(Shape(196, 144, 512));
    240   googlenet_gemms.push_back(Shape(196, 288, 1296));
    241   googlenet_gemms.push_back(Shape(196, 32, 512));
    242   googlenet_gemms.push_back(Shape(196, 64, 800));
    243   googlenet_gemms.push_back(Shape(196, 64, 512));
    244   googlenet_gemms.push_back(Shape(196, 256, 528));
    245   googlenet_gemms.push_back(Shape(196, 160, 528));
    246   googlenet_gemms.push_back(Shape(196, 320, 1440));
    247   googlenet_gemms.push_back(Shape(196, 32, 528));
    248   googlenet_gemms.push_back(Shape(196, 128, 800));
    249   googlenet_gemms.push_back(Shape(196, 128, 528));
    250   googlenet_gemms.push_back(Shape(49, 256, 832));
    251   googlenet_gemms.push_back(Shape(49, 160, 832));
    252   googlenet_gemms.push_back(Shape(49, 320, 1440));
    253   googlenet_gemms.push_back(Shape(49, 48, 832));
    254   googlenet_gemms.push_back(Shape(49, 128, 1200));
    255   googlenet_gemms.push_back(Shape(49, 128, 832));
    256   googlenet_gemms.push_back(Shape(49, 384, 832));
    257   googlenet_gemms.push_back(Shape(49, 192, 832));
    258   googlenet_gemms.push_back(Shape(49, 384, 1728));
    259   googlenet_gemms.push_back(Shape(49, 48, 832));
    260   googlenet_gemms.push_back(Shape(49, 128, 1200));
    261   googlenet_gemms.push_back(Shape(49, 128, 832));
    262   googlenet_gemms.push_back(Shape(16, 128, 508));
    263   googlenet_gemms.push_back(Shape(1, 1024, 2048));
    264   googlenet_gemms.push_back(Shape(1, 1008, 1024));
    265   googlenet_gemms.push_back(Shape(16, 128, 528));
    266   googlenet_gemms.push_back(Shape(1, 1024, 2048));
    267   googlenet_gemms.push_back(Shape(1, 1008, 1024));
    268   googlenet_gemms.push_back(Shape(1, 1008, 1024));
    269 
    270   for (auto& shape : googlenet_gemms) {
    271     shape.init();
    272   }
    273 
    274   std::vector<Shape> small_gemms;
    275   small_gemms.push_back(Shape(29232, 16, 25));
    276   small_gemms.push_back(Shape(7308, 6, 400));
    277   small_gemms.push_back(Shape(203, 3002, 216));
    278 
    279   for (auto& shape : small_gemms) {
    280     shape.init();
    281   }
    282 
    283   std::vector<Shape> others;
    284   others.push_back(Shape(100, 100, 100));
    285   others.push_back(Shape(1000, 1000, 1000));
    286   others.push_back(Shape(2000, 1000, 1000));
    287 
    288   for (auto& shape : others) {
    289     shape.init();
    290   }
    291 
    292   gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(4);
    293 
    294   std::cout << "Warmup run." << std::endl;
    295   time_all(&googlenet_gemms, 10, 1.0);
    296   time_all(&small_gemms, 50, 1.0);
    297 
    298   std::cout << "Timing all." << std::endl;
    299   time_all(&googlenet_gemms, 10, 20.0);
    300   time_all(&small_gemms, 50, 10.0);
    301 
    302   std::cout << "Timing separate." << std::endl;
    303 
    304   for (auto& shape : googlenet_gemms) {
    305     time_one(&shape, 0.10);
    306   }
    307 
    308   for (auto& shape : small_gemms) {
    309     time_one(&shape, 0.10);
    310   }
    311 
    312   for (auto& shape : others) {
    313     time_one(&shape, 0.10);
    314   }
    315 
    316   return 0;
    317 }
    318