/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RNN.h"

#include <algorithm>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

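// Builds a FloatNear matcher for every expected value so EXPECT_THAT can
// compare a whole output vector element-wise within |max_abs_error|.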
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-5) {
  std::vector<Matcher<float>> matchers;
  matchers.reserve(values.size());
  for (const float& v : values) {
    matchers.emplace_back(FloatNear(v, max_abs_error));
  }
  return matchers;
}

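// Input sequence for the black-box test below: 128 floats, read as consecutive
// slices of input_size = 8 values each.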
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};

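// Golden outputs for the test below, in groups of num_units = 16 post-ReLU
// activations per step.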
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};

}  // anonymous namespace

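// X-macros listing the RNN operation's tensors; they stamp out the setters,
// backing buffers, and setInput/setOutput calls below.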
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
  ACTION(Input)                                  \
  ACTION(Weights)                                \
  ACTION(RecurrentWeights)                       \
  ACTION(Bias)                                   \
  ACTION(HiddenStateIn)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
  ACTION(HiddenStateOut)               \
  ACTION(Output)

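// Test helper that assembles an NNAPI model containing a single
// ANEURALNETWORKS_RNN operation and drives it through the wrapper's
// Compilation/Execution API.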
class BasicRNNOpModel {
 public:
  BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
      : batches_(batches),
        units_(units),
        input_size_(size),
        activation_(kActivationRelu) {
    std::vector<uint32_t> inputs;

    OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
    inputs.push_back(model_.addOperand(&InputTy));
    OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
    inputs.push_back(model_.addOperand(&WeightTy));
    OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
    inputs.push_back(model_.addOperand(&RecurrentWeightTy));
    OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
    inputs.push_back(model_.addOperand(&BiasTy));
    OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
    inputs.push_back(model_.addOperand(&HiddenStateTy));
    OperandType ActionParamTy(Type::INT32, {1});
    inputs.push_back(model_.addOperand(&ActionParamTy));

    std::vector<uint32_t> outputs;

    outputs.push_back(model_.addOperand(&HiddenStateTy));
    OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
    outputs.push_back(model_.addOperand(&OutputTy));

    Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
    HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
    HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
    Output_.insert(Output_.end(), batches_ * units_, 0.f);

    model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
    model_.identifyInputsAndOutputs(inputs, outputs);

    model_.finish();
  }

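// Stamps out SetInput, SetWeights, SetRecurrentWeights, SetBias and
// SetHiddenStateIn; each appends the given values to the matching buffer.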
#define DefineSetter(X)                          \
  void Set##X(const std::vector<float>& f) {     \
    X##_.insert(X##_.end(), f.begin(), f.end()); \
  }

  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);

#undef DefineSetter

  void SetInput(int offset, float* begin, float* end) {
    for (; begin != end; begin++, offset++) {
      Input_[offset] = *begin;
    }
  }

  void ResetHiddenState() {
    std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
    std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
  }

  const std::vector<float>& GetOutput() const { return Output_; }

  uint32_t input_size() const { return input_size_; }
  uint32_t num_units() const { return units_; }
  uint32_t num_batches() const { return batches_; }

  void Invoke() {
    ASSERT_TRUE(model_.isValid());

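    // Carry over state: the hidden state produced by the previous Invoke()
    // becomes the input hidden state for this step.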
    HiddenStateIn_.swap(HiddenStateOut_);

    Compilation compilation(&model_);
    compilation.finish();
    Execution execution(&compilation);
#define SetInputOrWeight(X)                                                   \
  ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(),                \
                               sizeof(float) * X##_.size()),                  \
            Result::NO_ERROR);

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                           \
  ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(),                \
                                sizeof(float) * X##_.size()),                  \
            Result::NO_ERROR);

    FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

    ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
                                 sizeof(activation_)),
              Result::NO_ERROR);

    ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  }

 private:
  Model model_;

  const uint32_t batches_;
  const uint32_t units_;
  const uint32_t input_size_;

  const int activation_;

#define DefineTensor(X) std::vector<float> X##_;

  FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
  FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

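// End-to-end check: build a 2-batch, 16-unit RNN with input size 8, feed it a
// known input sequence, and compare each step's output against the golden
// values above.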
TEST(RNNOpTest, BlackBoxTest) {
  BasicRNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1});

  rnn.ResetHiddenState();
  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
                                  (rnn.input_size() * rnn.num_batches());

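  // Each step feeds the same input_size-long slice to both batches, so the
  // expected output is the golden step duplicated once per batch.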
  for (int i = 0; i < input_sequence_size; i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    float* golden_start = rnn_golden_output + i * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android