Home | History | Annotate | Download | only in meta
      1 // Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include <unistd.h>
     16 #ifdef __APPLE__
     17 #include <sys/time.h>
     18 #endif
     19 
     20 #include <cstdint>
     21 #include <cstdlib>
     22 #include <ctime>
     23 #include <iomanip>
     24 #include <iostream>
     25 #include <map>
     26 #include <memory>
     27 #include <vector>
     28 
     29 #include "multi_thread_gemm.h"
     30 #include "quantized_mul_kernels.h"
     31 #include "single_thread_gemm.h"
     32 #include "streams.h"
     33 
     34 #define LHS_OFFSET (-127)
     35 #define RHS_OFFSET (-127)
     36 #define SUM_OFFSET (127)
     37 #define MUL_OFFSET (1)
     38 #define SHIFT (7)
     39 #define FLOAT_SCALE (0.333f)
     40 
     41 using namespace gemmlowp::meta;
     42 
     43 // Input, output & kernel setups.
     44 
     45 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, ColumnMajorWithSum,
     46                    QuantizedStaticPreprocessed, RowMajor>
     47     ParamsColumnMajor;
     48 
     49 typedef GemmParams<std::uint8_t, std::uint8_t, RowMajorWithSum, RowMajorWithSum,
     50                    QuantizedStaticPreprocessed, RowMajor>
     51     ParamsRowMajor;
     52 
     53 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, ColumnMajorWithSum,
     54                    QuantizedStaticPreprocessedAsFloat, RowMajor>
     55     ParamsColumnMajorAsFloat;
     56 
     57 typedef GemmParams<std::uint8_t, float, RowMajorWithSum, RowMajorWithSum,
     58                    QuantizedStaticPreprocessedAsFloat, RowMajor>
     59     ParamsRowMajorAsFloat;
     60 
     61 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, ColumnMajorWithSum,
     62                    QuantizedStaticPreprocessedAsInt32, RowMajor>
     63     ParamsColumnMajorAsInt32;
     64 
     65 typedef GemmParams<std::uint8_t, std::int32_t, RowMajorWithSum, RowMajorWithSum,
     66                    QuantizedStaticPreprocessedAsInt32, RowMajor>
     67     ParamsRowMajorAsInt32;
     68 
     69 typedef gemmlowp::WorkersPool Pool;
     70 typedef SimpleContext<gemmlowp::WorkersPool> Context;
     71 
     72 #ifdef LHS_PACK
     73 typedef GemmExecutorPackLHSCacheFriendly<> Executor;
     74 #else
     75 typedef GemmExecutorPackRHSCacheFriendly<> Executor;
     76 #endif
     77 
     78 // Testing helper functions.
     79 
     80 void prepare_test_data(std::uint8_t* data, std::int32_t rows, std::int32_t cols,
     81                        std::int32_t seed, std::int32_t seed_2) {
     82   std::int32_t value = seed;
     83   for (int i = 0; i < rows * cols; ++i) {
     84     data[i] = static_cast<std::uint8_t>(value);
     85     value = ((value * seed_2) + seed) % 256;
     86   }
     87 }
     88 
     89 template <typename CLEAR_TYPE>
     90 void clear(int rows, int cols, CLEAR_TYPE* data) {
     91   for (int i = 0; i < rows * cols; ++i) {
     92     data[i] = 0;
     93   }
     94 }
     95 
     96 bool check_row_row(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
     97                    int cols, int depth) {
     98   int wrong = 0;
     99   int rounding = (1 << (SHIFT - 1));
    100   for (int i = 0; i < rows; ++i) {
    101     for (int j = 0; j < cols; ++j) {
    102       int expected = 0;
    103       for (int k = 0; k < depth; ++k) {
    104         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    105                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
    106       }
    107       expected += SUM_OFFSET * depth;
    108       expected *= MUL_OFFSET;
    109       expected += rounding;
    110       expected = (expected >> SHIFT);
    111       if (expected < 0) {
    112         expected = 0;
    113       } else if (expected > 255) {
    114         expected = 255;
    115       }
    116       expected = static_cast<int>(static_cast<std::uint8_t>(expected));
    117       int actual = static_cast<int>(results[i * cols + j]);
    118       if (actual != expected) {
    119         std::cout << "Wrong @" << i << "x" << j << " : " << actual
    120                   << " != " << expected << std::endl;
    121         wrong++;
    122       }
    123     }
    124   }
    125   if (wrong != 0) {
    126     std::cout << wrong << "/" << (rows * cols) << std::endl;
    127   }
    128   return wrong == 0;
    129 }
    130 
    131 bool check_row_col(std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* results, int rows,
    132                    int cols, int depth) {
    133   int wrong = 0;
    134   int rounding = (1 << (SHIFT - 1));
    135   for (int i = 0; i < rows; ++i) {
    136     for (int j = 0; j < cols; ++j) {
    137       int expected = 0;
    138       for (int k = 0; k < depth; ++k) {
    139         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    140                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
    141       }
    142       expected += SUM_OFFSET * depth;
    143       expected *= MUL_OFFSET;
    144       expected += rounding;
    145       expected = (expected >> SHIFT);
    146       if (expected < 0) {
    147         expected = 0;
    148       } else if (expected > 255) {
    149         expected = 255;
    150       }
    151       expected = static_cast<int>(static_cast<std::uint8_t>(expected));
    152       int actual = static_cast<int>(results[i * cols + j]);
    153       if (actual != expected) {
    154         wrong++;
    155       }
    156     }
    157   }
    158   return wrong == 0;
    159 }
    160 
    161 bool check_row_row_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
    162                      int cols, int depth) {
    163   int wrong = 0;
    164   for (int i = 0; i < rows; ++i) {
    165     for (int j = 0; j < cols; ++j) {
    166       int expected = 0;
    167       for (int k = 0; k < depth; ++k) {
    168         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    169                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
    170       }
    171       float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
    172       float actual = results[i * cols + j];
    173       if (actual != expected_float) {
    174         wrong++;
    175       }
    176     }
    177   }
    178   return wrong == 0;
    179 }
    180 
    181 bool check_row_col_f(std::uint8_t* lhs, std::uint8_t* rhs, float* results, int rows,
    182                      int cols, int depth) {
    183   int wrong = 0;
    184   for (int i = 0; i < rows; ++i) {
    185     for (int j = 0; j < cols; ++j) {
    186       int expected = 0;
    187       for (int k = 0; k < depth; ++k) {
    188         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    189                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
    190       }
    191       float expected_float = static_cast<float>(expected) * FLOAT_SCALE;
    192       float actual = results[i * cols + j];
    193       if (actual != expected_float) {
    194         wrong++;
    195       }
    196     }
    197   }
    198   return wrong == 0;
    199 }
    200 
    201 bool check_row_row_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
    202                        int cols, int depth) {
    203   int wrong = 0;
    204   for (int i = 0; i < rows; ++i) {
    205     for (int j = 0; j < cols; ++j) {
    206       int expected = 0;
    207       for (int k = 0; k < depth; ++k) {
    208         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    209                     (static_cast<int>(rhs[depth * j + k]) + RHS_OFFSET);
    210       }
    211       int actual = results[i * cols + j];
    212       if (actual != expected) {
    213         wrong++;
    214       }
    215     }
    216   }
    217   return wrong == 0;
    218 }
    219 
    220 bool check_row_col_i32(std::uint8_t* lhs, std::uint8_t* rhs, std::int32_t* results, int rows,
    221                        int cols, int depth) {
    222   int wrong = 0;
    223   for (int i = 0; i < rows; ++i) {
    224     for (int j = 0; j < cols; ++j) {
    225       int expected = 0;
    226       for (int k = 0; k < depth; ++k) {
    227         expected += (static_cast<int>(lhs[depth * i + k]) + LHS_OFFSET) *
    228                     (static_cast<int>(rhs[j + k * cols]) + RHS_OFFSET);
    229       }
    230       int actual = results[i * cols + j];
    231       if (actual != expected) {
    232         wrong++;
    233       }
    234     }
    235   }
    236   return wrong == 0;
    237 }
    238 
    239 template <typename PARAMS, typename RESULT_TYPE>
    240 void setup_params(std::uint8_t* lhs, std::uint8_t* rhs, RESULT_TYPE* result,
    241                   std::uint8_t* scratch, PARAMS* params) {
    242   params->lhs = lhs;
    243   params->rhs = rhs;
    244   params->result = result;
    245   params->scratch = scratch;
    246 
    247   params->left_stream.multiplicative_sum_offset = RHS_OFFSET;
    248   params->left_stream.additive_sum_offset = 0;
    249 
    250   params->right_stream.multiplicative_sum_offset = LHS_OFFSET;
    251   params->right_stream.additive_sum_offset = 0;
    252 }
    253 
    254 void setup_row_row(int m, int n, int k, ParamsRowMajor* params) {
    255   params->m = m;
    256   params->n = n;
    257   params->k = k;
    258   params->left_stream.count = k;
    259   params->left_stream.stride = k;
    260   params->left_stream.additive_sum_offset =
    261       SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
    262   params->right_stream.count = k;
    263   params->right_stream.stride = k;
    264   params->fused_kernel.kernel.count = k;
    265   params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
    266   params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
    267   params->fused_kernel.kernel.shift = -SHIFT;
    268   params->fused_kernel.output_stream.stride = n;
    269 }
    270 
    271 void setup_row_col(int m, int n, int k, ParamsColumnMajor* params) {
    272   params->m = m;
    273   params->n = n;
    274   params->k = k;
    275   params->left_stream.count = k;
    276   params->left_stream.stride = k;
    277   params->left_stream.additive_sum_offset =
    278       SUM_OFFSET * k + k * LHS_OFFSET * RHS_OFFSET;
    279   params->right_stream.count = k;
    280   params->right_stream.stride = n;
    281   params->fused_kernel.kernel.count = k;
    282   params->fused_kernel.kernel.multiplicative_offset = MUL_OFFSET;
    283   params->fused_kernel.kernel.rounding_offset = (1 << (SHIFT - 1));
    284   params->fused_kernel.kernel.shift = -SHIFT;
    285   params->fused_kernel.output_stream.stride = n;
    286 }
    287 
    288 void setup_row_row_f(int m, int n, int k, ParamsRowMajorAsFloat* params) {
    289   params->m = m;
    290   params->n = n;
    291   params->k = k;
    292   params->left_stream.count = k;
    293   params->left_stream.stride = k;
    294   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
    295   params->right_stream.count = k;
    296   params->right_stream.stride = k;
    297   params->fused_kernel.kernel.count = k;
    298   params->fused_kernel.kernel.scale = FLOAT_SCALE;
    299   params->fused_kernel.output_stream.stride = n * sizeof(float);
    300 }
    301 
    302 void setup_row_col_f(int m, int n, int k, ParamsColumnMajorAsFloat* params) {
    303   params->m = m;
    304   params->n = n;
    305   params->k = k;
    306   params->left_stream.count = k;
    307   params->left_stream.stride = k;
    308   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
    309   params->right_stream.count = k;
    310   params->right_stream.stride = n;
    311   params->fused_kernel.kernel.count = k;
    312   params->fused_kernel.kernel.scale = FLOAT_SCALE;
    313   params->fused_kernel.output_stream.stride = n * sizeof(float);
    314 }
    315 
    316 void setup_row_row_i32(int m, int n, int k, ParamsRowMajorAsInt32* params) {
    317   params->m = m;
    318   params->n = n;
    319   params->k = k;
    320   params->left_stream.count = k;
    321   params->left_stream.stride = k;
    322   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
    323   params->right_stream.count = k;
    324   params->right_stream.stride = k;
    325   params->fused_kernel.kernel.count = k;
    326   params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
    327 }
    328 
    329 void setup_row_col_i32(int m, int n, int k, ParamsColumnMajorAsInt32* params) {
    330   params->m = m;
    331   params->n = n;
    332   params->k = k;
    333   params->left_stream.count = k;
    334   params->left_stream.stride = k;
    335   params->left_stream.additive_sum_offset = k * LHS_OFFSET * RHS_OFFSET;
    336   params->right_stream.count = k;
    337   params->right_stream.stride = n;
    338   params->fused_kernel.kernel.count = k;
    339   params->fused_kernel.output_stream.stride = n * sizeof(std::int32_t);
    340 }
    341 
    342 int main() {
    343   ParamsRowMajor params_row;
    344   ParamsColumnMajor params_col;
    345   ParamsRowMajorAsFloat params_row_f;
    346   ParamsColumnMajorAsFloat params_col_f;
    347   ParamsRowMajorAsInt32 params_row_i32;
    348   ParamsColumnMajorAsInt32 params_col_i32;
    349 
    350   std::unique_ptr<std::uint8_t> lhs(new std::uint8_t[1024 * 1024]);
    351   std::unique_ptr<std::uint8_t> rhs(new std::uint8_t[1024 * 1024]);
    352   std::unique_ptr<std::uint8_t> result(new std::uint8_t[1024 * 1024]);
    353   std::unique_ptr<float> result_f(new float[1024 * 1024]);
    354   std::unique_ptr<std::int32_t> result_i32(new std::int32_t[1024 * 1024]);
    355   std::unique_ptr<std::uint8_t> scratch(new std::uint8_t[4048 * 1024]);
    356 
    357   setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), &params_row);
    358   setup_params(lhs.get(), rhs.get(), result.get(), scratch.get(), &params_col);
    359   setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
    360                &params_row_f);
    361   setup_params(lhs.get(), rhs.get(), result_f.get(), scratch.get(),
    362                &params_col_f);
    363   setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
    364                &params_row_i32);
    365   setup_params(lhs.get(), rhs.get(), result_i32.get(), scratch.get(),
    366                &params_col_i32);
    367 
    368   Pool pool;
    369   Context context(4, &pool);
    370 
    371   for (int i = 1; i < 16; ++i) {
    372     for (int j = 1; j < 16; ++j) {
    373       for (int k = 1; k < 24; ++k) {
    374         prepare_test_data(lhs.get(), i, k, 11, 13);
    375         prepare_test_data(rhs.get(), j, k, 13, 17);
    376 
    377         clear(i, j, result.get());
    378         setup_row_row(i, j, k, &params_row);
    379         Gemm<Executor, ParamsRowMajor, 2, 4, 8>(params_row);
    380         if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
    381           std::cout << "Row: " << i << "x" << j << "x" << k << " : ERROR"
    382                     << std::endl;
    383           std::cout << "Exiting." << std::endl;
    384           std::exit(1);
    385         }
    386 
    387         clear(i, j, result.get());
    388         setup_row_col(i, j, k, &params_col);
    389         Gemm<Executor, ParamsColumnMajor, 2, 4, 8>(params_col);
    390         if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
    391           std::cout << "Column: " << i << "x" << j << "x" << k << " : ERROR"
    392                     << std::endl;
    393           std::cout << "Exiting." << std::endl;
    394           std::exit(1);
    395         }
    396 
    397         clear(i, j, result_f.get());
    398         setup_row_row_f(i, j, k, &params_row_f);
    399         Gemm<Executor, ParamsRowMajorAsFloat, 2, 4, 8>(params_row_f);
    400         if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
    401           std::cout << "RowAsFloat: " << i << "x" << j << "x" << k << " : ERROR"
    402                     << std::endl;
    403           std::cout << "Exiting." << std::endl;
    404           std::exit(1);
    405         }
    406 
    407         clear(i, j, result_f.get());
    408         setup_row_col_f(i, j, k, &params_col_f);
    409         Gemm<Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(params_col_f);
    410         if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
    411           std::cout << "ColumnAsFloat: " << i << "x" << j << "x" << k
    412                     << " : ERROR" << std::endl;
    413           std::cout << "Exiting." << std::endl;
    414           std::exit(1);
    415         }
    416 
    417         clear(i, j, result_i32.get());
    418         setup_row_row_i32(i, j, k, &params_row_i32);
    419         Gemm<Executor, ParamsRowMajorAsInt32, 2, 4, 8>(params_row_i32);
    420         if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
    421                                k)) {
    422           std::cout << "RowAsInt32: " << i << "x" << j << "x" << k << " : ERROR"
    423                     << std::endl;
    424           std::cout << "Exiting." << std::endl;
    425           std::exit(1);
    426         }
    427 
    428         clear(i, j, result_i32.get());
    429         setup_row_col_i32(i, j, k, &params_col_i32);
    430         Gemm<Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(params_col_i32);
    431         if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
    432                                k)) {
    433           std::cout << "ColumnAsInt32: " << i << "x" << j << "x" << k
    434                     << " : ERROR" << std::endl;
    435           std::cout << "Exiting." << std::endl;
    436           std::exit(1);
    437         }
    438       }
    439     }
    440   }
    441 
    442   for (int i = 1; i < 1024; i += 211) {
    443     for (int j = 1; j < 1024; j += 211) {
    444       for (int k = 8; k < 1024; k += 111) {
    445         prepare_test_data(lhs.get(), i, k, 11, 13);
    446         prepare_test_data(rhs.get(), j, k, 13, 17);
    447 
    448         clear(i, j, result.get());
    449         setup_row_row(i, j, k, &params_row);
    450         MultiThreadGemm<Context, Executor, ParamsRowMajor, 2, 4, 8>(&context,
    451                                                                     params_row);
    452         if (!check_row_row(lhs.get(), rhs.get(), result.get(), i, j, k)) {
    453           std::cout << "Row(MT): " << i << "x" << j << "x" << k << " : ERROR"
    454                     << std::endl;
    455           std::cout << "Exiting." << std::endl;
    456           std::exit(1);
    457         }
    458 
    459         clear(i, j, result.get());
    460         setup_row_col(i, j, k, &params_col);
    461         MultiThreadGemm<Context, Executor, ParamsColumnMajor, 2, 4, 8>(
    462             &context, params_col);
    463         if (!check_row_col(lhs.get(), rhs.get(), result.get(), i, j, k)) {
    464           std::cout << "Column(MT): " << i << "x" << j << "x" << k << " : ERROR"
    465                     << std::endl;
    466           std::cout << "Exiting." << std::endl;
    467           std::exit(1);
    468         }
    469 
    470         clear(i, j, result_f.get());
    471         setup_row_row_f(i, j, k, &params_row_f);
    472         MultiThreadGemm<Context, Executor, ParamsRowMajorAsFloat, 2, 4, 8>(
    473             &context, params_row_f);
    474         if (!check_row_row_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
    475           std::cout << "RowAsFloat(MT): " << i << "x" << j << "x" << k
    476                     << " : ERROR" << std::endl;
    477           std::cout << "Exiting." << std::endl;
    478           std::exit(1);
    479         }
    480 
    481         clear(i, j, result_f.get());
    482         setup_row_col_f(i, j, k, &params_col_f);
    483         MultiThreadGemm<Context, Executor, ParamsColumnMajorAsFloat, 2, 4, 8>(
    484             &context, params_col_f);
    485         if (!check_row_col_f(lhs.get(), rhs.get(), result_f.get(), i, j, k)) {
    486           std::cout << "ColumnAsFloat(MT): " << i << "x" << j << "x" << k
    487                     << " : ERROR" << std::endl;
    488           std::cout << "Exiting." << std::endl;
    489           std::exit(1);
    490         }
    491 
    492         clear(i, j, result_i32.get());
    493         setup_row_row_i32(i, j, k, &params_row_i32);
    494         MultiThreadGemm<Context, Executor, ParamsRowMajorAsInt32, 2, 4, 8>(
    495             &context, params_row_i32);
    496         if (!check_row_row_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
    497                                k)) {
    498           std::cout << "RowAsInt32(MT): " << i << "x" << j << "x" << k
    499                     << " : ERROR" << std::endl;
    500           std::cout << "Exiting." << std::endl;
    501           std::exit(1);
    502         }
    503 
    504         clear(i, j, result_i32.get());
    505         setup_row_col_i32(i, j, k, &params_col_i32);
    506         MultiThreadGemm<Context, Executor, ParamsColumnMajorAsInt32, 2, 4, 8>(
    507             &context, params_col_i32);
    508         if (!check_row_col_i32(lhs.get(), rhs.get(), result_i32.get(), i, j,
    509                                k)) {
    510           std::cout << "ColumnAsInt32(MT): " << i << "x" << j << "x" << k
    511                     << " : ERROR" << std::endl;
    512           std::cout << "Exiting." << std::endl;
    513           std::exit(1);
    514         }
    515       }
    516     }
    517   }
    518 
    519   std::cout << "OK." << std::endl;
    520   return 0;
    521 }
    522