Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "HashtableLookup.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 #include "gmock/gmock-matchers.h"
     21 #include "gtest/gtest.h"
     22 
     23 using ::testing::FloatNear;
     24 using ::testing::Matcher;
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace wrapper {
     29 
     30 namespace {
     31 
     32 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
     33                                            float max_abs_error=1.e-6) {
     34   std::vector<Matcher<float>> matchers;
     35   matchers.reserve(values.size());
     36   for (const float& v : values) {
     37     matchers.emplace_back(FloatNear(v, max_abs_error));
     38   }
     39   return matchers;
     40 }
     41 
     42 }  // namespace
     43 
     44 using ::testing::ElementsAreArray;
     45 
     46 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION)     \
     47   ACTION(Lookup, int)                                \
     48   ACTION(Key, int)                                   \
     49   ACTION(Value, float)
     50 
     51 // For all output and intermediate states
     52 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
     53   ACTION(Output, float)                \
     54   ACTION(Hits, uint8_t)
     55 
     56 class HashtableLookupOpModel {
     57  public:
     58     HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
     59                            std::initializer_list<uint32_t> key_shape,
     60                            std::initializer_list<uint32_t> value_shape) {
     61     auto it_vs = value_shape.begin();
     62     rows_ = *it_vs++;
     63     features_ = *it_vs;
     64 
     65     std::vector<uint32_t> inputs;
     66 
     67     // Input and weights
     68     OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
     69     inputs.push_back(model_.addOperand(&LookupTy));
     70 
     71     OperandType KeyTy(Type::TENSOR_INT32, key_shape);
     72     inputs.push_back(model_.addOperand(&KeyTy));
     73 
     74     OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
     75     inputs.push_back(model_.addOperand(&ValueTy));
     76 
     77     // Output and other intermediate state
     78     std::vector<uint32_t> outputs;
     79 
     80     std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
     81     out_dim.push_back(features_);
     82 
     83     OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
     84     outputs.push_back(model_.addOperand(&OutputOpndTy));
     85 
     86     OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
     87     outputs.push_back(model_.addOperand(&HitsOpndTy));
     88 
     89     auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
     90         uint32_t sz = 1;
     91         for (uint32_t d : dims) { sz *= d; }
     92         return sz;
     93     };
     94 
     95     Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
     96     Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
     97     Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
     98 
     99     model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
    100     model_.identifyInputsAndOutputs(inputs, outputs);
    101 
    102     model_.finish();
    103   }
    104 
    105   void Invoke() {
    106     ASSERT_TRUE(model_.isValid());
    107 
    108     Compilation compilation(&model_);
    109     compilation.finish();
    110     Execution execution(&compilation);
    111 
    112 #define SetInputOrWeight(X, T)                                             \
    113   ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
    114                                sizeof(T) * X##_.size()),                   \
    115             Result::NO_ERROR);
    116 
    117     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
    118 
    119 #undef SetInputOrWeight
    120 
    121 #define SetOutput(X, T)                                                     \
    122   ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
    123                                sizeof(T) * X##_.size()),                    \
    124             Result::NO_ERROR);
    125 
    126     FOR_ALL_OUTPUT_TENSORS(SetOutput);
    127 
    128 #undef SetOutput
    129 
    130     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    131   }
    132 
    133 #define DefineSetter(X, T)                       \
    134   void Set##X(const std::vector<T>& f) {         \
    135     X##_.insert(X##_.end(), f.begin(), f.end()); \
    136   }
    137 
    138   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    139 
    140 #undef DefineSetter
    141 
    142   void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
    143     for (uint32_t i = 0; i < rows_; i++) {
    144       for (uint32_t j = 0; j < features_; j++) {
    145           Value_[i * features_ + j] = function(i, j);
    146       }
    147     }
    148   }
    149 
    150   const std::vector<float>& GetOutput() const { return Output_; }
    151   const std::vector<uint8_t>& GetHits() const { return Hits_; }
    152 
    153  private:
    154   Model model_;
    155   uint32_t rows_;
    156   uint32_t features_;
    157 
    158 #define DefineTensor(X, T) std::vector<T> X##_;
    159 
    160   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    161   FOR_ALL_OUTPUT_TENSORS(DefineTensor);
    162 
    163 #undef DefineTensor
    164 };
    165 
    166 TEST(HashtableLookupOpTest, BlackBoxTest) {
    167   HashtableLookupOpModel m({4}, {3}, {3, 2});
    168 
    169   m.SetLookup({1234, -292, -11, 0});
    170   m.SetKey({-11, 0, 1234});
    171   m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
    172 
    173   m.Invoke();
    174 
    175   EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
    176                                  2.0, 2.1,  // 2-rd item
    177                                  0, 0,      // Not found
    178                                  0.0, 0.1,  // 0-th item
    179                                  1.0, 1.1,  // 1-st item
    180                              })));
    181   EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
    182                                1, 0, 1, 1,
    183                            }));
    184 
    185 }
    186 
    187 }  // namespace wrapper
    188 }  // namespace nn
    189 }  // namespace android
    190