Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "LSHProjection.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 #include "gmock/gmock-generated-matchers.h"
     21 #include "gmock/gmock-matchers.h"
     22 #include "gtest/gtest.h"
     23 
     24 using ::testing::FloatNear;
     25 using ::testing::Matcher;
     26 
     27 namespace android {
     28 namespace nn {
     29 namespace wrapper {
     30 
     31 using ::testing::ElementsAre;
     32 
     33 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
     34     ACTION(Hash, float)                          \
     35     ACTION(Input, int)                           \
     36     ACTION(Weight, float)
     37 
     38 // For all output and intermediate states
     39 #define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, int)
     40 
     41 class LSHProjectionOpModel {
     42    public:
     43     LSHProjectionOpModel(LSHProjectionType type, std::initializer_list<uint32_t> hash_shape,
     44                          std::initializer_list<uint32_t> input_shape,
     45                          std::initializer_list<uint32_t> weight_shape)
     46         : type_(type) {
     47         std::vector<uint32_t> inputs;
     48 
     49         OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
     50         inputs.push_back(model_.addOperand(&HashTy));
     51         OperandType InputTy(Type::TENSOR_INT32, input_shape);
     52         inputs.push_back(model_.addOperand(&InputTy));
     53         OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
     54         inputs.push_back(model_.addOperand(&WeightTy));
     55 
     56         OperandType TypeParamTy(Type::INT32, {});
     57         inputs.push_back(model_.addOperand(&TypeParamTy));
     58 
     59         std::vector<uint32_t> outputs;
     60 
     61         auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
     62             uint32_t sz = 1;
     63             for (uint32_t d : dims) {
     64                 sz *= d;
     65             }
     66             return sz;
     67         };
     68 
     69         uint32_t outShapeDimension = 0;
     70         if (type == LSHProjectionType_SPARSE || type == LSHProjectionType_SPARSE_DEPRECATED) {
     71             auto it = hash_shape.begin();
     72             Output_.insert(Output_.end(), *it, 0.f);
     73             outShapeDimension = *it;
     74         } else {
     75             Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
     76             outShapeDimension = multiAll(hash_shape);
     77         }
     78 
     79         OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
     80         outputs.push_back(model_.addOperand(&OutputTy));
     81 
     82         model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
     83         model_.identifyInputsAndOutputs(inputs, outputs);
     84 
     85         model_.finish();
     86     }
     87 
     88 #define DefineSetter(X, T) \
     89     void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
     90 
     91     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
     92 
     93 #undef DefineSetter
     94 
     95     const std::vector<int>& GetOutput() const { return Output_; }
     96 
     97     void Invoke() {
     98         ASSERT_TRUE(model_.isValid());
     99 
    100         Compilation compilation(&model_);
    101         compilation.finish();
    102         Execution execution(&compilation);
    103 
    104 #define SetInputOrWeight(X, T)                                                                     \
    105     ASSERT_EQ(                                                                                     \
    106             execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), sizeof(T) * X##_.size()), \
    107             Result::NO_ERROR);
    108 
    109         FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
    110 
    111 #undef SetInputOrWeight
    112 
    113 #define SetOutput(X, T)                                                     \
    114     ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
    115                                   sizeof(T) * X##_.size()),                 \
    116               Result::NO_ERROR);
    117 
    118         FOR_ALL_OUTPUT_TENSORS(SetOutput);
    119 
    120 #undef SetOutput
    121 
    122         ASSERT_EQ(execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
    123                   Result::NO_ERROR);
    124 
    125         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    126     }
    127 
    128    private:
    129     Model model_;
    130     LSHProjectionType type_;
    131 
    132     std::vector<float> Hash_;
    133     std::vector<int> Input_;
    134     std::vector<float> Weight_;
    135     std::vector<int> Output_;
    136 };  // namespace wrapper
    137 
    138 TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
    139     LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
    140 
    141     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
    142     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
    143     m.SetWeight({0.12, 0.34, 0.56});
    144 
    145     m.Invoke();
    146 
    147     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
    148 }
    149 
    150 TEST(LSHProjectionOpTest2, SparseDeprecatedWithTwoInputs) {
    151     LSHProjectionOpModel m(LSHProjectionType_SPARSE_DEPRECATED, {4, 2}, {3, 2}, {0});
    152 
    153     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
    154     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
    155 
    156     m.Invoke();
    157 
    158     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
    159 }
    160 
    161 TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
    162     LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {0});
    163 
    164     m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
    165     m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
    166 
    167     m.Invoke();
    168 
    169     EXPECT_THAT(m.GetOutput(), ElementsAre(1, 6, 10, 12));
    170 }
    171 
    172 }  // namespace wrapper
    173 }  // namespace nn
    174 }  // namespace android
    175