// (source-viewer header residue, not code): Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Unit test for TFLite RNN op.
     16 
     17 #include <iomanip>
     18 #include <vector>
     19 
     20 #include <gmock/gmock.h>
     21 #include <gtest/gtest.h>
     22 #include "tensorflow/contrib/lite/interpreter.h"
     23 #include "tensorflow/contrib/lite/kernels/register.h"
     24 #include "tensorflow/contrib/lite/kernels/test_util.h"
     25 #include "tensorflow/contrib/lite/model.h"
     26 
     27 namespace tflite {
     28 namespace {
     29 
     30 using ::testing::ElementsAreArray;
     31 
// Input activation sequence for the black-box RNN test below. The test
// consumes this buffer input_size() floats at a time (one feature vector
// per time step) and feeds the same vector to both batch entries.
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};
     59 
// Expected (golden) outputs for the test below: one group of 16 values
// (one per unit) per time step, separated by blank lines. The many exact
// zeros are consistent with the RELU activation configured on the model.
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};
    124 
    125 class RNNOpModel : public SingleOpModel {
    126  public:
    127   RNNOpModel(int batches, int units, int size)
    128       : batches_(batches), units_(units), input_size_(size) {
    129     input_ = AddInput(TensorType_FLOAT32);
    130     weights_ = AddInput(TensorType_FLOAT32);
    131     recurrent_weights_ = AddInput(TensorType_FLOAT32);
    132     bias_ = AddInput(TensorType_FLOAT32);
    133     hidden_state_ = AddOutput(TensorType_FLOAT32);
    134     output_ = AddOutput(TensorType_FLOAT32);
    135     SetBuiltinOp(
    136         BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
    137         CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
    138     BuildInterpreter({{batches_, input_size_},
    139                       {units_, input_size_},
    140                       {units_, units_},
    141                       {units_}});
    142   }
    143 
    144   void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
    145 
    146   void SetWeights(std::initializer_list<float> f) {
    147     PopulateTensor(weights_, f);
    148   }
    149 
    150   void SetRecurrentWeights(std::initializer_list<float> f) {
    151     PopulateTensor(recurrent_weights_, f);
    152   }
    153 
    154   void SetInput(std::initializer_list<float> data) {
    155     PopulateTensor(input_, data);
    156   }
    157 
    158   void SetInput(int offset, float* begin, float* end) {
    159     PopulateTensor(input_, offset, begin, end);
    160   }
    161 
    162   void ResetHiddenState() {
    163     const int zero_buffer_size = units_ * batches_;
    164     std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
    165     memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
    166     PopulateTensor(hidden_state_, 0, zero_buffer.get(),
    167                    zero_buffer.get() + zero_buffer_size);
    168   }
    169 
    170   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
    171 
    172   int input_size() { return input_size_; }
    173   int num_units() { return units_; }
    174   int num_batches() { return batches_; }
    175 
    176  private:
    177   int input_;
    178   int weights_;
    179   int recurrent_weights_;
    180   int bias_;
    181   int hidden_state_;
    182   int output_;
    183 
    184   int batches_;
    185   int units_;
    186   int input_size_;
    187 };
    188 
// Black-box test: builds a 2-batch, 16-unit, 8-input RNN, loads fixed
// weights and bias, then steps through rnn_input and compares each step's
// output against rnn_golden_output.
// NOTE(review): the suite name says FullyConnectedOpTest but this
// exercises the RNN op (presumably copied from the fully-connected test);
// renaming would change the --gtest_filter-visible name, so it is left
// as-is here.
TEST(FullyConnectedOpTest, BlackBoxTest) {
  RNNOpModel rnn(2, 16, 8);
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  // 16x16 recurrent weights: 0.1 on the diagonal (each 0.1 is separated by
  // sixteen zeros), so each unit feeds back only a tenth of its own state.
  rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1});

  rnn.ResetHiddenState();
  // Number of time steps derivable from the fixture size: total floats
  // divided by (features per step * batches).
  const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
                                  (rnn.input_size() * rnn.num_batches());

  for (int i = 0; i < input_sequence_size; i++) {
    // Feed the same feature vector to both batch entries.
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    rnn.SetInput(0, batch_start, batch_end);
    rnn.SetInput(rnn.input_size(), batch_start, batch_end);

    rnn.Invoke();

    // Since both batches saw identical input, both must match the same
    // golden slice for this time step.
    float* golden_start = rnn_golden_output + i * rnn.num_units();
    float* golden_end = golden_start + rnn.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
    258 
    259 }  // namespace
    260 }  // namespace tflite
    261 
    262 int main(int argc, char** argv) {
    263   ::tflite::LogToStderr();
    264   ::testing::InitGoogleTest(&argc, argv);
    265   return RUN_ALL_TESTS();
    266 }
    267