/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "HashtableLookup.h"

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

using ::testing::FloatNear;
using ::testing::Matcher;

namespace android {
namespace nn {
namespace wrapper {

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // namespace

using ::testing::ElementsAreArray;

#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Lookup, int)                          \
    ACTION(Key, int)                             \
    ACTION(Value, float)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(Output, float)              \
    ACTION(Hits, uint8_t)

class HashtableLookupOpModel {
  public:
    HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
                           std::initializer_list<uint32_t> key_shape,
                           std::initializer_list<uint32_t> value_shape) {
        auto it_vs = value_shape.begin();
        rows_ = *it_vs++;
        features_ = *it_vs;

        std::vector<uint32_t> inputs;

        // Input and weights
        OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
        inputs.push_back(model_.addOperand(&LookupTy));

        OperandType KeyTy(Type::TENSOR_INT32, key_shape);
        inputs.push_back(model_.addOperand(&KeyTy));

        OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
        inputs.push_back(model_.addOperand(&ValueTy));

        // Output and other intermediate state
        std::vector<uint32_t> outputs;

        std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
        out_dim.push_back(features_);

        OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
        outputs.push_back(model_.addOperand(&OutputOpndTy));

        OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
        outputs.push_back(model_.addOperand(&HitsOpndTy));

        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
        Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
        Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);

        model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        model_.finish();
    }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);

#define SetInputOrWeight(X, T)                                                \
    ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
                                 sizeof(T) * X##_.size()),                    \
              Result::NO_ERROR);
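        // Expands to one execution.setInput() call per input/weight tensor
        // (Lookup, Key, Value) listed in the X-macro above.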
        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X, T)                                                        \
    ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
                                  sizeof(T) * X##_.size()),                    \
              Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

#define DefineSetter(X, T)                           \
    void Set##X(const std::vector<T>& f) {           \
        X##_.insert(X##_.end(), f.begin(), f.end()); \
    }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

    void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
        for (uint32_t i = 0; i < rows_; i++) {
            for (uint32_t j = 0; j < features_; j++) {
                Value_[i * features_ + j] = function(i, j);
            }
        }
    }

    const std::vector<float>& GetOutput() const { return Output_; }
    const std::vector<uint8_t>& GetHits() const { return Hits_; }

  private:
    Model model_;
    uint32_t rows_;
    uint32_t features_;

#define DefineTensor(X, T) std::vector<T> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

// For each lookup id, the op searches the key tensor: on a hit the matching
// row of the value tensor is copied to the output and the corresponding Hits
// entry is 1; on a miss the output row is zero-filled and the entry is 0.
TEST(HashtableLookupOpTest, BlackBoxTest) {
    HashtableLookupOpModel m({4}, {3}, {3, 2});

    m.SetLookup({1234, -292, -11, 0});
    m.SetKey({-11, 0, 1234});
    m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });

    m.Invoke();

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
                                       2.0, 2.1,  // 2-nd item
                                       0, 0,      // Not found
                                       0.0, 0.1,  // 0-th item
                                       1.0, 1.1,  // 1-st item
                               })));
    EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
                                   1, 0, 1, 1,
                           }));
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android