Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "LSHProjection.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 #include "gmock/gmock-generated-matchers.h"
     21 #include "gmock/gmock-matchers.h"
     22 #include "gtest/gtest.h"
     23 
     24 using ::testing::FloatNear;
     25 using ::testing::Matcher;
     26 
     27 namespace android {
     28 namespace nn {
     29 namespace wrapper {
     30 
     31 using ::testing::ElementsAre;
     32 
     33 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
     34   ACTION(Hash, float)                            \
     35   ACTION(Input, int)                             \
     36   ACTION(Weight, float)
     37 
     38 // For all output and intermediate states
     39 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
     40   ACTION(Output, int)
     41 
     42 class LSHProjectionOpModel {
     43  public:
     44   LSHProjectionOpModel(LSHProjectionType type,
     45                        std::initializer_list<uint32_t> hash_shape,
     46                        std::initializer_list<uint32_t> input_shape,
     47                        std::initializer_list<uint32_t> weight_shape)
     48       : type_(type) {
     49     std::vector<uint32_t> inputs;
     50 
     51     OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
     52     inputs.push_back(model_.addOperand(&HashTy));
     53     OperandType InputTy(Type::TENSOR_INT32, input_shape);
     54     inputs.push_back(model_.addOperand(&InputTy));
     55     OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
     56     inputs.push_back(model_.addOperand(&WeightTy));
     57 
     58     OperandType TypeParamTy(Type::INT32, {});
     59     inputs.push_back(model_.addOperand(&TypeParamTy));
     60 
     61     std::vector<uint32_t> outputs;
     62 
     63     auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
     64       uint32_t sz = 1;
     65       for (uint32_t d : dims) {
     66         sz *= d;
     67       }
     68       return sz;
     69     };
     70 
     71     uint32_t outShapeDimension = 0;
     72     if (type == LSHProjectionType_SPARSE) {
     73       auto it = hash_shape.begin();
     74       Output_.insert(Output_.end(), *it, 0.f);
     75       outShapeDimension = *it;
     76     } else {
     77       Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
     78       outShapeDimension = multiAll(hash_shape);
     79     }
     80 
     81     OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
     82     outputs.push_back(model_.addOperand(&OutputTy));
     83 
     84     model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
     85     model_.identifyInputsAndOutputs(inputs, outputs);
     86 
     87     model_.finish();
     88   }
     89 
     90 #define DefineSetter(X, T)                       \
     91   void Set##X(const std::vector<T> &f) {         \
     92     X##_.insert(X##_.end(), f.begin(), f.end()); \
     93   }
     94 
     95   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
     96 
     97 #undef DefineSetter
     98 
     99   const std::vector<int> &GetOutput() const { return Output_; }
    100 
    101   void Invoke() {
    102     ASSERT_TRUE(model_.isValid());
    103 
    104     Compilation compilation(&model_);
    105     compilation.finish();
    106     Execution execution(&compilation);
    107 
    108 #define SetInputOrWeight(X, T)                                             \
    109     ASSERT_EQ(execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), \
    110                                  sizeof(T) * X##_.size()),                 \
    111               Result::NO_ERROR);
    112 
    113     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
    114 
    115 #undef SetInputOrWeight
    116 
    117 #define SetOutput(X, T)                                                   \
    118   ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
    119                                 sizeof(T) * X##_.size()),                 \
    120             Result::NO_ERROR);
    121 
    122     FOR_ALL_OUTPUT_TENSORS(SetOutput);
    123 
    124 #undef SetOutput
    125 
    126     ASSERT_EQ(
    127         execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
    128         Result::NO_ERROR);
    129 
    130     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    131   }
    132 
    133  private:
    134   Model model_;
    135   LSHProjectionType type_;
    136 
    137   std::vector<float> Hash_;
    138   std::vector<int> Input_;
    139   std::vector<float> Weight_;
    140   std::vector<int> Output_;
    141 };  // namespace wrapper
    142 
    143 TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
    144   LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
    145 
    146   m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
    147   m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
    148   m.SetWeight({0.12, 0.34, 0.56});
    149 
    150   m.Invoke();
    151 
    152   EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
    153 }
    154 
    155 TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
    156   LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {});
    157 
    158   m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
    159   m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
    160 
    161   m.Invoke();
    162 
    163   EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
    164 }
    165 
    166 }  // namespace wrapper
    167 }  // namespace nn
    168 }  // namespace android
    169