Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 /* Header-only library for various helpers of test harness
     18  * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used.
     19  */
     20 #ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
     21 #define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
     22 
     23 #include <gmock/gmock-matchers.h>
     24 #include <gtest/gtest.h>
     25 
     26 #include <cmath>
     27 #include <functional>
     28 #include <map>
     29 #include <tuple>
     30 #include <vector>
     31 
     32 namespace test_helper {
     33 
     34 constexpr const size_t gMaximumNumberOfErrorMessages = 10;
     35 
     36 // TODO: Figure out the build dependency to make including "CpuOperationUtils.h" work.
     37 inline void convertFloat16ToFloat32(const _Float16* input, std::vector<float>* output) {
     38     for (size_t i = 0; i < output->size(); ++i) {
     39         (*output)[i] = static_cast<float>(input[i]);
     40     }
     41 }
     42 
     43 // This class is a workaround for two issues our code relies on:
     44 // 1. sizeof(bool) is implementation defined.
     45 // 2. vector<bool> does not allow direct pointer access via the data() method.
     46 class bool8 {
     47    public:
     48     bool8() : mValue() {}
     49     /* implicit */ bool8(bool value) : mValue(value) {}
     50     inline operator bool() const { return mValue != 0; }
     51 
     52    private:
     53     uint8_t mValue;
     54 };
     55 
     56 static_assert(sizeof(bool8) == 1, "size of bool8 must be 8 bits");
     57 
     58 typedef std::map<int, std::vector<uint32_t>> OperandDimensions;
     59 typedef std::map<int, std::vector<float>> Float32Operands;
     60 typedef std::map<int, std::vector<int32_t>> Int32Operands;
     61 typedef std::map<int, std::vector<uint8_t>> Quant8AsymmOperands;
     62 typedef std::map<int, std::vector<int16_t>> Quant16SymmOperands;
     63 typedef std::map<int, std::vector<_Float16>> Float16Operands;
     64 typedef std::map<int, std::vector<bool8>> Bool8Operands;
     65 typedef std::map<int, std::vector<int8_t>> Quant8ChannelOperands;
     66 typedef std::map<int, std::vector<uint16_t>> Quant16AsymmOperands;
     67 typedef std::map<int, std::vector<int8_t>> Quant8SymmOperands;
     68 struct MixedTyped {
     69     static constexpr size_t kNumTypes = 9;
     70     OperandDimensions operandDimensions;
     71     Float32Operands float32Operands;
     72     Int32Operands int32Operands;
     73     Quant8AsymmOperands quant8AsymmOperands;
     74     Quant16SymmOperands quant16SymmOperands;
     75     Float16Operands float16Operands;
     76     Bool8Operands bool8Operands;
     77     Quant8ChannelOperands quant8ChannelOperands;
     78     Quant16AsymmOperands quant16AsymmOperands;
     79     Quant8SymmOperands quant8SymmOperands;
     80 };
     81 typedef std::pair<MixedTyped, MixedTyped> MixedTypedExampleType;
     82 
     83 // Mixed-typed examples
     84 typedef struct {
     85     MixedTypedExampleType operands;
     86     // Specifies the RANDOM_MULTINOMIAL distribution tolerance.
     87     // If set to greater than zero, the input is compared as log-probabilities
     88     // to the output and must be within this tolerance to pass.
     89     float expectedMultinomialDistributionTolerance = 0.0;
     90 } MixedTypedExample;
     91 
     92 // Go through all index-value pairs of a given input type
     93 template <typename T>
     94 inline void for_each(const std::map<int, std::vector<T>>& idx_and_data,
     95                      std::function<void(int, const std::vector<T>&)> execute) {
     96     for (auto& i : idx_and_data) {
     97         execute(i.first, i.second);
     98     }
     99 }
    100 
    101 // non-const variant of for_each
    102 template <typename T>
    103 inline void for_each(std::map<int, std::vector<T>>& idx_and_data,
    104                      std::function<void(int, std::vector<T>&)> execute) {
    105     for (auto& i : idx_and_data) {
    106         execute(i.first, i.second);
    107     }
    108 }
    109 
    110 // Go through all index-value pairs of a given input type
    111 template <typename T>
    112 inline void for_each(const std::map<int, std::vector<T>>& golden,
    113                      std::map<int, std::vector<T>>& test,
    114                      std::function<void(int, const std::vector<T>&, std::vector<T>&)> execute) {
    115     for_each<T>(golden, [&test, &execute](int index, const std::vector<T>& g) {
    116         auto& t = test[index];
    117         execute(index, g, t);
    118     });
    119 }
    120 
    121 // Go through all index-value pairs of a given input type
    122 template <typename T>
    123 inline void for_each(
    124         const std::map<int, std::vector<T>>& golden, const std::map<int, std::vector<T>>& test,
    125         std::function<void(int, const std::vector<T>&, const std::vector<T>&)> execute) {
    126     for_each<T>(golden, [&test, &execute](int index, const std::vector<T>& g) {
    127         auto t = test.find(index);
    128         ASSERT_NE(t, test.end());
    129         execute(index, g, t->second);
    130     });
    131 }
    132 
    133 // internal helper for for_all
    134 template <typename T>
    135 inline void for_all_internal(std::map<int, std::vector<T>>& idx_and_data,
    136                              std::function<void(int, void*, size_t)> execute_this) {
    137     for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) {
    138         execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T));
    139     });
    140 }
    141 
    142 // Go through all index-value pairs of all input types
    143 // expects a functor that takes (int index, void *raw data, size_t sz)
    144 inline void for_all(MixedTyped& idx_and_data,
    145                     std::function<void(int, void*, size_t)> execute_this) {
    146     for_all_internal(idx_and_data.float32Operands, execute_this);
    147     for_all_internal(idx_and_data.int32Operands, execute_this);
    148     for_all_internal(idx_and_data.quant8AsymmOperands, execute_this);
    149     for_all_internal(idx_and_data.quant16SymmOperands, execute_this);
    150     for_all_internal(idx_and_data.float16Operands, execute_this);
    151     for_all_internal(idx_and_data.bool8Operands, execute_this);
    152     for_all_internal(idx_and_data.quant8ChannelOperands, execute_this);
    153     for_all_internal(idx_and_data.quant16AsymmOperands, execute_this);
    154     for_all_internal(idx_and_data.quant8SymmOperands, execute_this);
    155     static_assert(9 == MixedTyped::kNumTypes,
    156                   "Number of types in MixedTyped changed, but for_all function wasn't updated");
    157 }
    158 
    159 // Const variant of internal helper for for_all
    160 template <typename T>
    161 inline void for_all_internal(const std::map<int, std::vector<T>>& idx_and_data,
    162                              std::function<void(int, const void*, size_t)> execute_this) {
    163     for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) {
    164         execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T));
    165     });
    166 }
    167 
    168 // Go through all index-value pairs (const variant)
    169 // expects a functor that takes (int index, const void *raw data, size_t sz)
    170 inline void for_all(const MixedTyped& idx_and_data,
    171                     std::function<void(int, const void*, size_t)> execute_this) {
    172     for_all_internal(idx_and_data.float32Operands, execute_this);
    173     for_all_internal(idx_and_data.int32Operands, execute_this);
    174     for_all_internal(idx_and_data.quant8AsymmOperands, execute_this);
    175     for_all_internal(idx_and_data.quant16SymmOperands, execute_this);
    176     for_all_internal(idx_and_data.float16Operands, execute_this);
    177     for_all_internal(idx_and_data.bool8Operands, execute_this);
    178     for_all_internal(idx_and_data.quant8ChannelOperands, execute_this);
    179     for_all_internal(idx_and_data.quant16AsymmOperands, execute_this);
    180     for_all_internal(idx_and_data.quant8SymmOperands, execute_this);
    181     static_assert(
    182             9 == MixedTyped::kNumTypes,
    183             "Number of types in MixedTyped changed, but const for_all function wasn't updated");
    184 }
    185 
    186 // Helper template - resize test output per golden
    187 template <typename T>
    188 inline void resize_accordingly_(const std::map<int, std::vector<T>>& golden,
    189                                 std::map<int, std::vector<T>>& test) {
    190     for_each<T>(golden, test,
    191                 [](int, const std::vector<T>& g, std::vector<T>& t) { t.resize(g.size()); });
    192 }
    193 
    194 template <>
    195 inline void resize_accordingly_<uint32_t>(const OperandDimensions& golden,
    196                                           OperandDimensions& test) {
    197     for_each<uint32_t>(
    198             golden, test,
    199             [](int, const std::vector<uint32_t>& g, std::vector<uint32_t>& t) { t = g; });
    200 }
    201 
    202 inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) {
    203     resize_accordingly_(golden.operandDimensions, test.operandDimensions);
    204     resize_accordingly_(golden.float32Operands, test.float32Operands);
    205     resize_accordingly_(golden.int32Operands, test.int32Operands);
    206     resize_accordingly_(golden.quant8AsymmOperands, test.quant8AsymmOperands);
    207     resize_accordingly_(golden.quant16SymmOperands, test.quant16SymmOperands);
    208     resize_accordingly_(golden.float16Operands, test.float16Operands);
    209     resize_accordingly_(golden.bool8Operands, test.bool8Operands);
    210     resize_accordingly_(golden.quant8ChannelOperands, test.quant8ChannelOperands);
    211     resize_accordingly_(golden.quant16AsymmOperands, test.quant16AsymmOperands);
    212     resize_accordingly_(golden.quant8SymmOperands, test.quant8SymmOperands);
    213     static_assert(9 == MixedTyped::kNumTypes,
    214                   "Number of types in MixedTyped changed, but resize_accordingly function wasn't "
    215                   "updated");
    216 }
    217 
    218 template <typename T>
    219 void filter_internal(const std::map<int, std::vector<T>>& golden,
    220                      std::map<int, std::vector<T>>* filtered, std::function<bool(int)> is_ignored) {
    221     for_each<T>(golden, [filtered, &is_ignored](int index, const std::vector<T>& m) {
    222         auto& g = *filtered;
    223         if (!is_ignored(index)) g[index] = m;
    224     });
    225 }
    226 
    227 inline MixedTyped filter(const MixedTyped& golden,
    228                          std::function<bool(int)> is_ignored) {
    229     MixedTyped filtered;
    230     filter_internal(golden.operandDimensions, &filtered.operandDimensions, is_ignored);
    231     filter_internal(golden.float32Operands, &filtered.float32Operands, is_ignored);
    232     filter_internal(golden.int32Operands, &filtered.int32Operands, is_ignored);
    233     filter_internal(golden.quant8AsymmOperands, &filtered.quant8AsymmOperands, is_ignored);
    234     filter_internal(golden.quant16SymmOperands, &filtered.quant16SymmOperands, is_ignored);
    235     filter_internal(golden.float16Operands, &filtered.float16Operands, is_ignored);
    236     filter_internal(golden.bool8Operands, &filtered.bool8Operands, is_ignored);
    237     filter_internal(golden.quant8ChannelOperands, &filtered.quant8ChannelOperands, is_ignored);
    238     filter_internal(golden.quant16AsymmOperands, &filtered.quant16AsymmOperands, is_ignored);
    239     filter_internal(golden.quant8SymmOperands, &filtered.quant8SymmOperands, is_ignored);
    240     static_assert(9 == MixedTyped::kNumTypes,
    241                   "Number of types in MixedTyped changed, but compare function wasn't updated");
    242     return filtered;
    243 }
    244 
    245 // Compare results
    246 template <typename T>
    247 void compare_(const std::map<int, std::vector<T>>& golden,
    248               const std::map<int, std::vector<T>>& test, std::function<void(T, T)> cmp) {
    249     for_each<T>(golden, test, [&cmp](int index, const std::vector<T>& g, const std::vector<T>& t) {
    250         for (unsigned int i = 0; i < g.size(); i++) {
    251             SCOPED_TRACE(testing::Message()
    252                          << "When comparing output " << index << " element " << i);
    253             cmp(g[i], t[i]);
    254         }
    255     });
    256 }
    257 
    258 // TODO: Allow passing accuracy criteria from spec.
    259 // Currently we only need relaxed accuracy criteria on mobilenet tests, so we return the quant8
    260 // tolerance simply based on the current test name.
    261 inline int getQuant8AllowedError() {
    262     const ::testing::TestInfo* const testInfo =
    263             ::testing::UnitTest::GetInstance()->current_test_info();
    264     const std::string testCaseName = testInfo->test_case_name();
    265     const std::string testName = testInfo->name();
    266     // We relax the quant8 precision for all tests with mobilenet:
    267     // - CTS/VTS GeneratedTest and DynamicOutputShapeTest with mobilenet
    268     // - VTS CompilationCachingTest and CompilationCachingSecurityTest except for TOCTOU tests
    269     if (testName.find("mobilenet") != std::string::npos ||
    270         (testCaseName.find("CompilationCaching") != std::string::npos &&
    271          testName.find("TOCTOU") == std::string::npos)) {
    272         return 2;
    273     } else {
    274         return 1;
    275     }
    276 }
    277 
    278 inline void compare(const MixedTyped& golden, const MixedTyped& test,
    279                     float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
    280     int quant8AllowedError = getQuant8AllowedError();
    281     for_each<uint32_t>(
    282             golden.operandDimensions, test.operandDimensions,
    283             [](int index, const std::vector<uint32_t>& g, const std::vector<uint32_t>& t) {
    284                 SCOPED_TRACE(testing::Message()
    285                              << "When comparing dimensions for output " << index);
    286                 EXPECT_EQ(g, t);
    287             });
    288     size_t totalNumberOfErrors = 0;
    289     compare_<float>(golden.float32Operands, test.float32Operands,
    290                     [&totalNumberOfErrors, fpAtol, fpRtol](float expected, float actual) {
    291                         // Compute the range based on both absolute tolerance and relative tolerance
    292                         float fpRange = fpAtol + fpRtol * std::abs(expected);
    293                         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    294                             EXPECT_NEAR(expected, actual, fpRange);
    295                         }
    296                         if (std::abs(expected - actual) > fpRange) {
    297                             totalNumberOfErrors++;
    298                         }
    299                     });
    300     compare_<int32_t>(golden.int32Operands, test.int32Operands,
    301                       [&totalNumberOfErrors](int32_t expected, int32_t actual) {
    302                           if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    303                               EXPECT_EQ(expected, actual);
    304                           }
    305                           if (expected != actual) {
    306                               totalNumberOfErrors++;
    307                           }
    308                       });
    309     compare_<uint8_t>(golden.quant8AsymmOperands, test.quant8AsymmOperands,
    310                       [&totalNumberOfErrors, quant8AllowedError](uint8_t expected, uint8_t actual) {
    311                           if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    312                               EXPECT_NEAR(expected, actual, quant8AllowedError);
    313                           }
    314                           if (std::abs(expected - actual) > quant8AllowedError) {
    315                               totalNumberOfErrors++;
    316                           }
    317                       });
    318     compare_<int16_t>(golden.quant16SymmOperands, test.quant16SymmOperands,
    319                       [&totalNumberOfErrors](int16_t expected, int16_t actual) {
    320                           if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    321                               EXPECT_NEAR(expected, actual, 1);
    322                           }
    323                           if (std::abs(expected - actual) > 1) {
    324                               totalNumberOfErrors++;
    325                           }
    326                       });
    327     compare_<_Float16>(golden.float16Operands, test.float16Operands,
    328                        [&totalNumberOfErrors, fpAtol, fpRtol](_Float16 expected, _Float16 actual) {
    329                            // Compute the range based on both absolute tolerance and relative
    330                            // tolerance
    331                            float fpRange = fpAtol + fpRtol * std::abs(static_cast<float>(expected));
    332                            if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    333                                EXPECT_NEAR(expected, actual, fpRange);
    334                            }
    335                            if (std::abs(static_cast<float>(expected - actual)) > fpRange) {
    336                                totalNumberOfErrors++;
    337                            }
    338                        });
    339     compare_<bool8>(golden.bool8Operands, test.bool8Operands,
    340                     [&totalNumberOfErrors](bool expected, bool actual) {
    341                         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    342                             EXPECT_EQ(expected, actual);
    343                         }
    344                         if (expected != actual) {
    345                             totalNumberOfErrors++;
    346                         }
    347                     });
    348     compare_<int8_t>(golden.quant8ChannelOperands, test.quant8ChannelOperands,
    349                      [&totalNumberOfErrors, &quant8AllowedError](int8_t expected, int8_t actual) {
    350                          if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    351                              EXPECT_NEAR(expected, actual, quant8AllowedError);
    352                          }
    353                          if (std::abs(static_cast<int>(expected) - static_cast<int>(actual)) >
    354                              quant8AllowedError) {
    355                              totalNumberOfErrors++;
    356                          }
    357                      });
    358     compare_<uint16_t>(golden.quant16AsymmOperands, test.quant16AsymmOperands,
    359                        [&totalNumberOfErrors](int16_t expected, int16_t actual) {
    360                            if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    361                                EXPECT_NEAR(expected, actual, 1);
    362                            }
    363                            if (std::abs(expected - actual) > 1) {
    364                                totalNumberOfErrors++;
    365                            }
    366                        });
    367     compare_<int8_t>(golden.quant8SymmOperands, test.quant8SymmOperands,
    368                      [&totalNumberOfErrors, quant8AllowedError](int8_t expected, int8_t actual) {
    369                          if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    370                              EXPECT_NEAR(expected, actual, quant8AllowedError);
    371                          }
    372                          if (std::abs(static_cast<int>(expected) - static_cast<int>(actual)) >
    373                              quant8AllowedError) {
    374                              totalNumberOfErrors++;
    375                          }
    376                      });
    377 
    378     static_assert(9 == MixedTyped::kNumTypes,
    379                   "Number of types in MixedTyped changed, but compare function wasn't updated");
    380     EXPECT_EQ(size_t{0}, totalNumberOfErrors);
    381 }
    382 
    383 // Calculates the expected probability from the unnormalized log-probability of
    384 // each class in the input and compares it to the actual ocurrence of that class
    385 // in the output.
    386 inline void expectMultinomialDistributionWithinTolerance(const MixedTyped& test,
    387                                                          const MixedTypedExample& example) {
    388     // TODO: These should be parameters but aren't currently preserved in the example.
    389     const int kBatchSize = 1;
    390     const int kNumClasses = 1024;
    391     const int kNumSamples = 128;
    392 
    393     std::vector<int32_t> output = test.int32Operands.at(0);
    394     std::vector<int> class_counts;
    395     class_counts.resize(kNumClasses);
    396     for (int index : output) {
    397         class_counts[index]++;
    398     }
    399     std::vector<float> input;
    400     Float32Operands float32Operands = example.operands.first.float32Operands;
    401     if (!float32Operands.empty()) {
    402         input = example.operands.first.float32Operands.at(0);
    403     } else {
    404         std::vector<_Float16> inputFloat16 = example.operands.first.float16Operands.at(0);
    405         input.resize(inputFloat16.size());
    406         convertFloat16ToFloat32(inputFloat16.data(), &input);
    407     }
    408     for (int b = 0; b < kBatchSize; ++b) {
    409         float probability_sum = 0;
    410         const int batch_index = kBatchSize * b;
    411         for (int i = 0; i < kNumClasses; ++i) {
    412             probability_sum += expf(input[batch_index + i]);
    413         }
    414         for (int i = 0; i < kNumClasses; ++i) {
    415             float probability =
    416                     static_cast<float>(class_counts[i]) / static_cast<float>(kNumSamples);
    417             float probability_expected = expf(input[batch_index + i]) / probability_sum;
    418             EXPECT_THAT(probability,
    419                         ::testing::FloatNear(probability_expected,
    420                                              example.expectedMultinomialDistributionTolerance));
    421         }
    422     }
    423 }
    424 
    425 };  // namespace test_helper
    426 
    427 #endif  // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
    428