1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "LSHProjection.h" 18 19 #include "NeuralNetworksWrapper.h" 20 #include "gmock/gmock-generated-matchers.h" 21 #include "gmock/gmock-matchers.h" 22 #include "gtest/gtest.h" 23 24 using ::testing::FloatNear; 25 using ::testing::Matcher; 26 27 namespace android { 28 namespace nn { 29 namespace wrapper { 30 31 using ::testing::ElementsAre; 32 33 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ 34 ACTION(Hash, float) \ 35 ACTION(Input, int) \ 36 ACTION(Weight, float) 37 38 // For all output and intermediate states 39 #define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, int) 40 41 class LSHProjectionOpModel { 42 public: 43 LSHProjectionOpModel(LSHProjectionType type, std::initializer_list<uint32_t> hash_shape, 44 std::initializer_list<uint32_t> input_shape, 45 std::initializer_list<uint32_t> weight_shape) 46 : type_(type) { 47 std::vector<uint32_t> inputs; 48 49 OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape); 50 inputs.push_back(model_.addOperand(&HashTy)); 51 OperandType InputTy(Type::TENSOR_INT32, input_shape); 52 inputs.push_back(model_.addOperand(&InputTy)); 53 OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape); 54 inputs.push_back(model_.addOperand(&WeightTy)); 55 56 OperandType TypeParamTy(Type::INT32, {}); 57 inputs.push_back(model_.addOperand(&TypeParamTy)); 58 59 std::vector<uint32_t> outputs; 60 61 auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t { 62 uint32_t sz = 1; 63 for (uint32_t d : dims) { 64 sz *= d; 65 } 66 return sz; 67 }; 68 69 uint32_t outShapeDimension = 0; 70 if (type == LSHProjectionType_SPARSE || type == LSHProjectionType_SPARSE_DEPRECATED) { 71 auto it = hash_shape.begin(); 72 Output_.insert(Output_.end(), *it, 0.f); 73 outShapeDimension = *it; 74 } else { 75 Output_.insert(Output_.end(), multiAll(hash_shape), 0.f); 76 outShapeDimension = multiAll(hash_shape); 77 } 78 79 OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension}); 80 outputs.push_back(model_.addOperand(&OutputTy)); 81 82 model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs); 83 model_.identifyInputsAndOutputs(inputs, outputs); 84 85 model_.finish(); 86 } 87 88 #define DefineSetter(X, T) \ 89 void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); } 90 91 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); 92 93 #undef DefineSetter 94 95 const std::vector<int>& GetOutput() const { return Output_; } 96 97 void Invoke() { 98 ASSERT_TRUE(model_.isValid()); 99 100 Compilation compilation(&model_); 101 compilation.finish(); 102 Execution execution(&compilation); 103 104 #define SetInputOrWeight(X, T) \ 105 ASSERT_EQ( \ 106 execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), sizeof(T) * X##_.size()), \ 107 Result::NO_ERROR); 108 109 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); 110 111 #undef SetInputOrWeight 112 113 #define SetOutput(X, T) \ 114 ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \ 115 sizeof(T) * X##_.size()), \ 116 Result::NO_ERROR); 117 118 FOR_ALL_OUTPUT_TENSORS(SetOutput); 119 120 #undef SetOutput 121 122 ASSERT_EQ(execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)), 123 Result::NO_ERROR); 124 125 ASSERT_EQ(execution.compute(), Result::NO_ERROR); 126 } 127 128 private: 129 Model model_; 130 LSHProjectionType type_; 131 132 std::vector<float> Hash_; 133 std::vector<int> Input_; 134 std::vector<float> Weight_; 135 std::vector<int> Output_; 136 }; // namespace wrapper 137 138 TEST(LSHProjectionOpTest2, DenseWithThreeInputs) { 139 LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3}); 140 141 m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 142 m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 143 m.SetWeight({0.12, 0.34, 0.56}); 144 145 m.Invoke(); 146 147 EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0)); 148 } 149 150 TEST(LSHProjectionOpTest2, SparseDeprecatedWithTwoInputs) { 151 LSHProjectionOpModel m(LSHProjectionType_SPARSE_DEPRECATED, {4, 2}, {3, 2}, {0}); 152 153 m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 154 m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 155 156 m.Invoke(); 157 158 EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0)); 159 } 160 161 TEST(LSHProjectionOpTest2, SparseWithTwoInputs) { 162 LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {0}); 163 164 m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 165 m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 166 167 m.Invoke(); 168 169 EXPECT_THAT(m.GetOutput(), ElementsAre(1, 6, 10, 12)); 170 } 171 172 } // namespace wrapper 173 } // namespace nn 174 } // namespace android 175