Home | History | Annotate | Download | only in operations
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "EmbeddingLookup.h"
     18 
     19 #include "NeuralNetworksWrapper.h"
     20 #include "gmock/gmock-matchers.h"
     21 #include "gtest/gtest.h"
     22 
     23 using ::testing::FloatNear;
     24 using ::testing::Matcher;
     25 
     26 namespace android {
     27 namespace nn {
     28 namespace wrapper {
     29 
     30 namespace {
     31 
     32 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
     33                                            float max_abs_error=1.e-6) {
     34   std::vector<Matcher<float>> matchers;
     35   matchers.reserve(values.size());
     36   for (const float& v : values) {
     37     matchers.emplace_back(FloatNear(v, max_abs_error));
     38   }
     39   return matchers;
     40 }
     41 
     42 }  // namespace
     43 
     44 using ::testing::ElementsAreArray;
     45 
     46 #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
     47   ACTION(Value, float)                           \
     48   ACTION(Lookup, int)
     49 
     50 // For all output and intermediate states
     51 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
     52   ACTION(Output, float)
     53 
     54 class EmbeddingLookupOpModel {
     55  public:
     56   EmbeddingLookupOpModel(std::initializer_list<uint32_t> index_shape,
     57                          std::initializer_list<uint32_t> weight_shape) {
     58     auto it = weight_shape.begin();
     59     rows_ = *it++;
     60     columns_ = *it++;
     61     features_ = *it;
     62 
     63     std::vector<uint32_t> inputs;
     64 
     65     OperandType LookupTy(Type::TENSOR_INT32, index_shape);
     66     inputs.push_back(model_.addOperand(&LookupTy));
     67 
     68     OperandType ValueTy(Type::TENSOR_FLOAT32, weight_shape);
     69     inputs.push_back(model_.addOperand(&ValueTy));
     70 
     71     std::vector<uint32_t> outputs;
     72 
     73     OperandType OutputOpndTy(Type::TENSOR_FLOAT32, weight_shape);
     74     outputs.push_back(model_.addOperand(&OutputOpndTy));
     75 
     76     auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
     77         uint32_t sz = 1;
     78         for (uint32_t d : dims) { sz *= d; }
     79         return sz;
     80     };
     81 
     82     Value_.insert(Value_.end(), multiAll(weight_shape), 0.f);
     83     Output_.insert(Output_.end(), multiAll(weight_shape), 0.f);
     84 
     85     model_.addOperation(ANEURALNETWORKS_EMBEDDING_LOOKUP, inputs, outputs);
     86     model_.identifyInputsAndOutputs(inputs, outputs);
     87 
     88     model_.finish();
     89   }
     90 
     91   void Invoke() {
     92     ASSERT_TRUE(model_.isValid());
     93 
     94     Compilation compilation(&model_);
     95     compilation.finish();
     96     Execution execution(&compilation);
     97 
     98 #define SetInputOrWeight(X, T)                                               \
     99   ASSERT_EQ(execution.setInput(EmbeddingLookup::k##X##Tensor, X##_.data(),   \
    100                                sizeof(T) * X##_.size()),                     \
    101             Result::NO_ERROR);
    102 
    103     FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
    104 
    105 #undef SetInputOrWeight
    106 
    107 #define SetOutput(X, T)                                                       \
    108   ASSERT_EQ(execution.setOutput(EmbeddingLookup::k##X##Tensor, X##_.data(),   \
    109                                 sizeof(T) * X##_.size()),                     \
    110             Result::NO_ERROR);
    111 
    112     FOR_ALL_OUTPUT_TENSORS(SetOutput);
    113 
    114 #undef SetOutput
    115 
    116     ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    117   }
    118 
    119 #define DefineSetter(X, T)                       \
    120   void Set##X(const std::vector<T>& f) {         \
    121     X##_.insert(X##_.end(), f.begin(), f.end()); \
    122   }
    123 
    124   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    125 
    126 #undef DefineSetter
    127 
    128   void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
    129     for (uint32_t i = 0; i < rows_; i++) {
    130       for (uint32_t j = 0; j < columns_; j++) {
    131         for (uint32_t k = 0; k < features_; k++) {
    132           Value_[(i * columns_ + j) * features_ + k] = function(i, j, k);
    133         }
    134       }
    135     }
    136   }
    137 
    138   const std::vector<float> &GetOutput() const { return Output_; }
    139 
    140  private:
    141   Model model_;
    142   uint32_t rows_;
    143   uint32_t columns_;
    144   uint32_t features_;
    145 
    146 #define DefineTensor(X, T) std::vector<T> X##_;
    147 
    148   FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    149   FOR_ALL_OUTPUT_TENSORS(DefineTensor);
    150 
    151 #undef DefineTensor
    152 };
    153 
    154 // TODO: write more tests that exercise the details of the op, such as
    155 // lookup errors and variable input shapes.
    156 TEST(EmbeddingLookupOpTest, SimpleTest) {
    157   EmbeddingLookupOpModel m({3}, {3, 2, 4});
    158   m.SetLookup({1, 0, 2});
    159   m.Set3DWeightMatrix(
    160       [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
    161 
    162   m.Invoke();
    163 
    164   EXPECT_THAT(m.GetOutput(),
    165               ElementsAreArray(ArrayFloatNear({
    166                   1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
    167                   0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
    168                   2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
    169               })));
    170 }
    171 
    172 }  // namespace wrapper
    173 }  // namespace nn
    174 }  // namespace android
    175