1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "LSHProjection.h" 18 19 #include "NeuralNetworksWrapper.h" 20 #include "gmock/gmock-generated-matchers.h" 21 #include "gmock/gmock-matchers.h" 22 #include "gtest/gtest.h" 23 24 using ::testing::FloatNear; 25 using ::testing::Matcher; 26 27 namespace android { 28 namespace nn { 29 namespace wrapper { 30 31 using ::testing::ElementsAre; 32 33 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \ 34 ACTION(Hash, float) \ 35 ACTION(Input, int) \ 36 ACTION(Weight, float) 37 38 // For all output and intermediate states 39 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \ 40 ACTION(Output, int) 41 42 class LSHProjectionOpModel { 43 public: 44 LSHProjectionOpModel(LSHProjectionType type, 45 std::initializer_list<uint32_t> hash_shape, 46 std::initializer_list<uint32_t> input_shape, 47 std::initializer_list<uint32_t> weight_shape) 48 : type_(type) { 49 std::vector<uint32_t> inputs; 50 51 OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape); 52 inputs.push_back(model_.addOperand(&HashTy)); 53 OperandType InputTy(Type::TENSOR_INT32, input_shape); 54 inputs.push_back(model_.addOperand(&InputTy)); 55 OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape); 56 inputs.push_back(model_.addOperand(&WeightTy)); 57 58 OperandType TypeParamTy(Type::INT32, {}); 59 inputs.push_back(model_.addOperand(&TypeParamTy)); 60 61 std::vector<uint32_t> outputs; 62 63 auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t { 64 uint32_t sz = 1; 65 for (uint32_t d : dims) { 66 sz *= d; 67 } 68 return sz; 69 }; 70 71 uint32_t outShapeDimension = 0; 72 if (type == LSHProjectionType_SPARSE) { 73 auto it = hash_shape.begin(); 74 Output_.insert(Output_.end(), *it, 0.f); 75 outShapeDimension = *it; 76 } else { 77 Output_.insert(Output_.end(), multiAll(hash_shape), 0.f); 78 outShapeDimension = multiAll(hash_shape); 79 } 80 81 OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension}); 82 outputs.push_back(model_.addOperand(&OutputTy)); 83 84 model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs); 85 model_.identifyInputsAndOutputs(inputs, outputs); 86 87 model_.finish(); 88 } 89 90 #define DefineSetter(X, T) \ 91 void Set##X(const std::vector<T> &f) { \ 92 X##_.insert(X##_.end(), f.begin(), f.end()); \ 93 } 94 95 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter); 96 97 #undef DefineSetter 98 99 const std::vector<int> &GetOutput() const { return Output_; } 100 101 void Invoke() { 102 ASSERT_TRUE(model_.isValid()); 103 104 Compilation compilation(&model_); 105 compilation.finish(); 106 Execution execution(&compilation); 107 108 #define SetInputOrWeight(X, T) \ 109 ASSERT_EQ(execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), \ 110 sizeof(T) * X##_.size()), \ 111 Result::NO_ERROR); 112 113 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight); 114 115 #undef SetInputOrWeight 116 117 #define SetOutput(X, T) \ 118 ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \ 119 sizeof(T) * X##_.size()), \ 120 Result::NO_ERROR); 121 122 FOR_ALL_OUTPUT_TENSORS(SetOutput); 123 124 #undef SetOutput 125 126 ASSERT_EQ( 127 execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)), 128 Result::NO_ERROR); 129 130 ASSERT_EQ(execution.compute(), Result::NO_ERROR); 131 } 132 133 private: 134 Model model_; 135 LSHProjectionType type_; 136 137 std::vector<float> Hash_; 138 std::vector<int> Input_; 139 std::vector<float> Weight_; 140 std::vector<int> Output_; 141 }; // namespace wrapper 142 143 TEST(LSHProjectionOpTest2, DenseWithThreeInputs) { 144 LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3}); 145 146 m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 147 m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 148 m.SetWeight({0.12, 0.34, 0.56}); 149 150 m.Invoke(); 151 152 EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0)); 153 } 154 155 TEST(LSHProjectionOpTest2, SparseWithTwoInputs) { 156 LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {}); 157 158 m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321}); 159 m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765}); 160 161 m.Invoke(); 162 163 EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0)); 164 } 165 166 } // namespace wrapper 167 } // namespace nn 168 } // namespace android 169