Home | History | Annotate | Download | only in include
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
/* Header-only library of helper utilities for the test harness.
     18  * See frameworks/ml/nn/runtime/test/TestGenerated.cpp for how this is used.
     19  */
     20 #ifndef ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
     21 #define ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
     22 
     23 #include <gtest/gtest.h>
     24 
     25 #include <cmath>
     26 #include <functional>
     27 #include <map>
     28 #include <tuple>
     29 #include <vector>
     30 
     31 namespace test_helper {
     32 
     33 constexpr const size_t gMaximumNumberOfErrorMessages = 10;
     34 
     35 typedef std::map<int, std::vector<float>> Float32Operands;
     36 typedef std::map<int, std::vector<int32_t>> Int32Operands;
     37 typedef std::map<int, std::vector<uint8_t>> Quant8Operands;
     38 typedef std::tuple<Float32Operands,  // ANEURALNETWORKS_TENSOR_FLOAT32
     39                    Int32Operands,    // ANEURALNETWORKS_TENSOR_INT32
     40                    Quant8Operands    // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
     41                    >
     42         MixedTyped;
     43 typedef std::pair<MixedTyped, MixedTyped> MixedTypedExampleType;
     44 
     45 template <typename T>
     46 struct MixedTypedIndex {};
     47 
     48 template <>
     49 struct MixedTypedIndex<float> {
     50     static constexpr size_t index = 0;
     51 };
     52 template <>
     53 struct MixedTypedIndex<int32_t> {
     54     static constexpr size_t index = 1;
     55 };
     56 template <>
     57 struct MixedTypedIndex<uint8_t> {
     58     static constexpr size_t index = 2;
     59 };
     60 
     61 // Go through all index-value pairs of a given input type
     62 template <typename T>
     63 inline void for_each(const MixedTyped& idx_and_data,
     64                      std::function<void(int, const std::vector<T>&)> execute) {
     65     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
     66         execute(i.first, i.second);
     67     }
     68 }
     69 
     70 // non-const variant of for_each
     71 template <typename T>
     72 inline void for_each(MixedTyped& idx_and_data,
     73                      std::function<void(int, std::vector<T>&)> execute) {
     74     for (auto& i : std::get<MixedTypedIndex<T>::index>(idx_and_data)) {
     75         execute(i.first, i.second);
     76     }
     77 }
     78 
     79 // internal helper for for_all
     80 template <typename T>
     81 inline void for_all_internal(
     82         MixedTyped& idx_and_data,
     83         std::function<void(int, void*, size_t)> execute_this) {
     84     for_each<T>(idx_and_data, [&execute_this](int idx, std::vector<T>& m) {
     85         execute_this(idx, static_cast<void*>(m.data()), m.size() * sizeof(T));
     86     });
     87 }
     88 
     89 // Go through all index-value pairs of all input types
     90 // expects a functor that takes (int index, void *raw data, size_t sz)
     91 inline void for_all(MixedTyped& idx_and_data,
     92                     std::function<void(int, void*, size_t)> execute_this) {
     93     for_all_internal<float>(idx_and_data, execute_this);
     94     for_all_internal<int32_t>(idx_and_data, execute_this);
     95     for_all_internal<uint8_t>(idx_and_data, execute_this);
     96 }
     97 
     98 // Const variant of internal helper for for_all
     99 template <typename T>
    100 inline void for_all_internal(
    101         const MixedTyped& idx_and_data,
    102         std::function<void(int, const void*, size_t)> execute_this) {
    103     for_each<T>(idx_and_data, [&execute_this](int idx, const std::vector<T>& m) {
    104         execute_this(idx, static_cast<const void*>(m.data()), m.size() * sizeof(T));
    105     });
    106 }
    107 
    108 // Go through all index-value pairs (const variant)
    109 // expects a functor that takes (int index, const void *raw data, size_t sz)
    110 inline void for_all(
    111         const MixedTyped& idx_and_data,
    112         std::function<void(int, const void*, size_t)> execute_this) {
    113     for_all_internal<float>(idx_and_data, execute_this);
    114     for_all_internal<int32_t>(idx_and_data, execute_this);
    115     for_all_internal<uint8_t>(idx_and_data, execute_this);
    116 }
    117 
    118 // Helper template - resize test output per golden
    119 template <typename ty, size_t tuple_index>
    120 void resize_accordingly_(const MixedTyped& golden, MixedTyped& test) {
    121     std::function<void(int, const std::vector<ty>&)> execute =
    122             [&test](int index, const std::vector<ty>& m) {
    123                 auto& t = std::get<tuple_index>(test);
    124                 t[index].resize(m.size());
    125             };
    126     for_each<ty>(golden, execute);
    127 }
    128 
    129 inline void resize_accordingly(const MixedTyped& golden, MixedTyped& test) {
    130     resize_accordingly_<float, 0>(golden, test);
    131     resize_accordingly_<int32_t, 1>(golden, test);
    132     resize_accordingly_<uint8_t, 2>(golden, test);
    133 }
    134 
    135 template <typename ty, size_t tuple_index>
    136 void filter_internal(const MixedTyped& golden, MixedTyped* filtered,
    137                      std::function<bool(int)> is_ignored) {
    138     for_each<ty>(golden,
    139                  [filtered, &is_ignored](int index, const std::vector<ty>& m) {
    140                      auto& g = std::get<tuple_index>(*filtered);
    141                      if (!is_ignored(index)) g[index] = m;
    142                  });
    143 }
    144 
    145 inline MixedTyped filter(const MixedTyped& golden,
    146                          std::function<bool(int)> is_ignored) {
    147     MixedTyped filtered;
    148     filter_internal<float, 0>(golden, &filtered, is_ignored);
    149     filter_internal<int32_t, 1>(golden, &filtered, is_ignored);
    150     filter_internal<uint8_t, 2>(golden, &filtered, is_ignored);
    151     return filtered;
    152 }
    153 
    154 // Compare results
    155 #define VECTOR_TYPE(x) \
    156     typename std::tuple_element<x, MixedTyped>::type::mapped_type
    157 #define VALUE_TYPE(x) VECTOR_TYPE(x)::value_type
    158 template <size_t tuple_index>
    159 void compare_(
    160         const MixedTyped& golden, const MixedTyped& test,
    161         std::function<void(VALUE_TYPE(tuple_index), VALUE_TYPE(tuple_index))>
    162                 cmp) {
    163     for_each<VALUE_TYPE(tuple_index)>(
    164             golden,
    165             [&test, &cmp](int index, const VECTOR_TYPE(tuple_index) & m) {
    166                 const auto& test_operands = std::get<tuple_index>(test);
    167                 const auto& test_ty = test_operands.find(index);
    168                 ASSERT_NE(test_ty, test_operands.end());
    169                 for (unsigned int i = 0; i < m.size(); i++) {
    170                     SCOPED_TRACE(testing::Message()
    171                                  << "When comparing element " << i);
    172                     cmp(m[i], test_ty->second[i]);
    173                 }
    174             });
    175 }
    176 #undef VALUE_TYPE
    177 #undef VECTOR_TYPE
    178 inline void compare(const MixedTyped& golden, const MixedTyped& test, float fpRange = 1e-5f) {
    179     size_t totalNumberOfErrors = 0;
    180     compare_<0>(golden, test, [&totalNumberOfErrors, fpRange](float g, float t) {
    181         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    182             EXPECT_NEAR(g, t, fpRange);
    183         }
    184         if (std::abs(g - t) > fpRange) {
    185             totalNumberOfErrors++;
    186         }
    187     });
    188     compare_<1>(golden, test, [&totalNumberOfErrors](int32_t g, int32_t t) {
    189         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    190             EXPECT_EQ(g, t);
    191         }
    192         if (g != t) {
    193             totalNumberOfErrors++;
    194         }
    195     });
    196     compare_<2>(golden, test, [&totalNumberOfErrors](uint8_t g, uint8_t t) {
    197         if (totalNumberOfErrors < gMaximumNumberOfErrorMessages) {
    198             EXPECT_NEAR(g, t, 1);
    199         }
    200         if (std::abs(g - t) > 1) {
    201             totalNumberOfErrors++;
    202         }
    203     });
    204     EXPECT_EQ(size_t{0}, totalNumberOfErrors);
    205 }
    206 
    207 };  // namespace test_helper
    208 
    209 #endif  // ANDROID_ML_NN_TOOLS_TEST_GENERATOR_TEST_HARNESS_H
    210