/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Unit test for TFLite Sequential RNN op.

#include <cstring>
#include <iomanip>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"

namespace tflite {
namespace {

using ::testing::ElementsAreArray;

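// Input data shared by the tests below: one batch of 16 time steps with
// 8 features each (both batches are fed this same sequence).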
static float rnn_input[] = {
    0.23689353,   0.285385,     0.037029743, -0.19858193,  -0.27569133,
    0.43773448,   0.60379338,   0.35562468,  -0.69424844,  -0.93421471,
    -0.87287879,  0.37144363,   -0.62476718, 0.23791671,   0.40060222,
    0.1356622,    -0.99774903,  -0.98858172, -0.38952237,  -0.47685933,
    0.31073618,   0.71511042,   -0.63767755, -0.31729108,  0.33468103,
    0.75801885,   0.30660987,   -0.37354088, 0.77002847,   -0.62747043,
    -0.68572164,  0.0069220066, 0.65791464,  0.35130811,   0.80834007,
    -0.61777675,  -0.21095741,  0.41213346,  0.73784804,   0.094794154,
    0.47791874,   0.86496925,   -0.53376222, 0.85315156,   0.10288584,
    0.86684,      -0.011186242, 0.10513687,  0.87825835,   0.59929144,
    0.62827742,   0.18899453,   0.31440187,  0.99059987,   0.87170351,
    -0.35091716,  0.74861872,   0.17831337,  0.2755419,    0.51864719,
    0.55084288,   0.58982027,   -0.47443086, 0.20875752,   -0.058871567,
    -0.66609079,  0.59098077,   0.73017097,  0.74604273,   0.32882881,
    -0.17503482,  0.22396147,   0.19379807,  0.29120302,   0.077113032,
    -0.70331609,  0.15804303,   -0.93407321, 0.40182066,   0.036301374,
    0.66521823,   0.0300982,    -0.7747041,  -0.02038002,  0.020698071,
    -0.90300065,  0.62870288,   -0.23068321, 0.27531278,   -0.095755219,
    -0.712036,    -0.17384434,  -0.50593495, -0.18646687,  -0.96508682,
    0.43519354,   0.14744234,   0.62589407,  0.1653645,    -0.10651493,
    -0.045277178, 0.99032974,   -0.88255352, -0.85147917,  0.28153265,
    0.19455957,   -0.55479527,  -0.56042433, 0.26048636,   0.84702539,
    0.47587705,   -0.074295521, -0.12287641, 0.70117295,   0.90532446,
    0.89782166,   0.79817224,   0.53402734,  -0.33286154,  0.073485017,
    -0.56172788,  -0.044897556, 0.89964068,  -0.067662835, 0.76863563,
    0.93455386,   -0.6324693,   -0.083922029};

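// Expected output for one batch: 16 time steps with 16 units each. The zeros
// come from the RELU activation configured in the op options.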
static float rnn_golden_output[] = {
    0.496726,   0,          0.965996,  0,         0.0584254, 0,
    0,          0.12315,    0,         0,         0.612266,  0.456601,
    0,          0.52286,    1.16099,   0.0291232,

    0,          0,          0.524901,  0,         0,         0,
    0,          1.02116,    0,         1.35762,   0,         0.356909,
    0.436415,   0.0355727,  0,         0,

    0,          0,          0,         0.262335,  0,         0,
    0,          1.33992,    0,         2.9739,    0,         0,
    1.31914,    2.66147,    0,         0,

    0.942568,   0,          0,         0,         0.025507,  0,
    0,          0,          0.321429,  0.569141,  1.25274,   1.57719,
    0.8158,     1.21805,    0.586239,  0.25427,

    1.04436,    0,          0.630725,  0,         0.133801,  0.210693,
    0.363026,   0,          0.533426,  0,         1.25926,   0.722707,
    0,          1.22031,    1.30117,   0.495867,

    0.222187,   0,          0.72725,   0,         0.767003,  0,
    0,          0.147835,   0,         0,         0,         0.608758,
    0.469394,   0.00720298, 0.927537,  0,

    0.856974,   0.424257,   0,         0,         0.937329,  0,
    0,          0,          0.476425,  0,         0.566017,  0.418462,
    0.141911,   0.996214,   1.13063,   0,

    0.967899,   0,          0,         0,         0.0831304, 0,
    0,          1.00378,    0,         0,         0,         1.44818,
    1.01768,    0.943891,   0.502745,  0,

    0.940135,   0,          0,         0,         0,         0,
    0,          2.13243,    0,         0.71208,   0.123918,  1.53907,
    1.30225,    1.59644,    0.70222,   0,

    0.804329,   0,          0.430576,  0,         0.505872,  0.509603,
    0.343448,   0,          0.107756,  0.614544,  1.44549,   1.52311,
    0.0454298,  0.300267,   0.562784,  0.395095,

    0.228154,   0,          0.675323,  0,         1.70536,   0.766217,
    0,          0,          0,         0.735363,  0.0759267, 1.91017,
    0.941888,   0,          0,         0,

    0,          0,          1.5909,    0,         0,         0,
    0,          0.5755,     0,         0.184687,  0,         1.56296,
    0.625285,   0,          0,         0,

    0,          0,          0.0857888, 0,         0,         0,
    0,          0.488383,   0.252786,  0,         0,         0,
    1.02817,    1.85665,    0,         0,

    0.00981836, 0,          1.06371,   0,         0,         0,
    0,          0,          0,         0.290445,  0.316406,  0,
    0.304161,   1.25079,    0.0707152, 0,

    0.986264,   0.309201,   0,         0,         0,         0,
    0,          1.64896,    0.346248,  0,         0.918175,  0.78884,
    0.524981,   1.92076,    2.07013,   0.333244,

    0.415153,   0.210318,   0,         0,         0,         0,
    0,          2.02616,    0,         0.728256,  0.84183,   0.0907453,
    0.628881,   3.58099,    1.49974,   0};

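// Test harness wrapping a single UNIDIRECTIONAL_SEQUENCE_RNN op: it registers
// the input, weights, recurrent-weights and bias tensors plus the hidden-state
// and output tensors, and builds an interpreter whose input shape depends on
// whether the op runs in time-major or batch-major mode.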
class UnidirectionalRNNOpModel : public SingleOpModel {
 public:
  UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size,
                           bool time_major)
      : batches_(batches),
        sequence_len_(sequence_len),
        units_(units),
        input_size_(size) {
    input_ = AddInput(TensorType_FLOAT32);
    weights_ = AddInput(TensorType_FLOAT32);
    recurrent_weights_ = AddInput(TensorType_FLOAT32);
    bias_ = AddInput(TensorType_FLOAT32);
    hidden_state_ = AddOutput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
                 BuiltinOptions_SequenceRNNOptions,
                 CreateSequenceRNNOptions(builder_, time_major,
                                          ActivationFunctionType_RELU)
                     .Union());
    if (time_major) {
      BuildInterpreter({{sequence_len_, batches_, input_size_},
                        {units_, input_size_},
                        {units_, units_},
                        {units_}});
    } else {
      BuildInterpreter({{batches_, sequence_len_, input_size_},
                        {units_, input_size_},
                        {units_, units_},
                        {units_}});
    }
  }

  void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }

  void SetWeights(std::initializer_list<float> f) {
    PopulateTensor(weights_, f);
  }

  void SetRecurrentWeights(std::initializer_list<float> f) {
    PopulateTensor(recurrent_weights_, f);
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor(input_, data);
  }

  void SetInput(int offset, float* begin, float* end) {
    PopulateTensor(input_, offset, begin, end);
  }

  void ResetHiddenState() {
    const int zero_buffer_size = units_ * batches_;
    std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
    memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
    PopulateTensor(hidden_state_, 0, zero_buffer.get(),
                   zero_buffer.get() + zero_buffer_size);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }

  int input_size() { return input_size_; }
  int num_units() { return units_; }
  int num_batches() { return batches_; }
  int sequence_len() { return sequence_len_; }

 private:
  int input_;
  int weights_;
  int recurrent_weights_;
  int bias_;
  int hidden_state_;
  int output_;

  int batches_;
  int sequence_len_;
  int units_;
  int input_size_;
};

// TODO(mirkov): add another test which directly compares to TF once TOCO
// supports the conversion from dynamic_rnn with BasicRNNCell.
// Feeds two identical batches through the op in batch-major layout and checks
// the result against the golden output.
TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
  UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*units=*/16, /*size=*/8, /*time_major=*/false);
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1});

  rnn.ResetHiddenState();
  const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
  float* batch_start = rnn_input;
  float* batch_end = batch_start + input_sequence_size;
  // Feed the same sequence to both batches (batch-major layout).
  rnn.SetInput(0, batch_start, batch_end);
  rnn.SetInput(input_sequence_size, batch_start, batch_end);

  rnn.Invoke();

  float* golden_start = rnn_golden_output;
  float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len();
  std::vector<float> expected;
  // Both batches received the same input, so the golden output is repeated
  // once per batch.
  expected.insert(expected.end(), golden_start, golden_end);
  expected.insert(expected.end(), golden_start, golden_end);

  EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

// Same as BlackBoxTest, but with the input laid out in time-major order.
TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
  UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*units=*/16, /*size=*/8, /*time_major=*/true);
  rnn.SetWeights(
      {0.461459,    0.153381,   0.529743,    -0.00371218, 0.676267,   -0.211346,
       0.317493,    0.969689,   -0.343251,   0.186423,    0.398151,   0.152399,
       0.448504,    0.317662,   0.523556,    -0.323514,   0.480877,   0.333113,
       -0.757714,   -0.674487,  -0.643585,   0.217766,    -0.0251462, 0.79512,
       -0.595574,   -0.422444,  0.371572,    -0.452178,   -0.556069,  -0.482188,
       -0.685456,   -0.727851,  0.841829,    0.551535,    -0.232336,  0.729158,
       -0.00294906, -0.69754,   0.766073,    -0.178424,   0.369513,   -0.423241,
       0.548547,    -0.0152023, -0.757482,   -0.85491,    0.251331,   -0.989183,
       0.306261,    -0.340716,  0.886103,    -0.0726757,  -0.723523,  -0.784303,
       0.0354295,   0.566564,   -0.485469,   -0.620498,   0.832546,   0.697884,
       -0.279115,   0.294415,   -0.584313,   0.548772,    0.0648819,  0.968726,
       0.723834,    -0.0080452, -0.350386,   -0.272803,   0.115121,   -0.412644,
       -0.824713,   -0.992843,  -0.592904,   -0.417893,   0.863791,   -0.423461,
       -0.147601,   -0.770664,  -0.479006,   0.654782,    0.587314,   -0.639158,
       0.816969,    -0.337228,  0.659878,    0.73107,     0.754768,   -0.337042,
       0.0960841,   0.368357,   0.244191,    -0.817703,   -0.211223,  0.442012,
       0.37225,     -0.623598,  -0.405423,   0.455101,    0.673656,   -0.145345,
       -0.511346,   -0.901675,  -0.81252,    -0.127006,   0.809865,   -0.721884,
       0.636255,    0.868989,   -0.347973,   -0.10179,    -0.777449,  0.917274,
       0.819286,    0.206218,   -0.00785118, 0.167141,    0.45872,    0.972934,
       -0.276798,   0.837861,   0.747958,    -0.0151566,  -0.330057,  -0.469077,
       0.277308,    0.415818});

  rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
               -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
               0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
               -0.37609905});

  rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                           0.1});

  rnn.ResetHiddenState();
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* batch_start = rnn_input + i * rnn.input_size();
    float* batch_end = batch_start + rnn.input_size();
    // The two batches are identical.
    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
  }

  rnn.Invoke();

  std::vector<float> expected;
  for (int i = 0; i < rnn.sequence_len(); i++) {
    float* golden_batch_start = rnn_golden_output + i * rnn.num_units();
    float* golden_batch_end = golden_batch_start + rnn.num_units();
    expected.insert(expected.end(), golden_batch_start, golden_batch_end);
    expected.insert(expected.end(), golden_batch_start, golden_batch_end);
  }

  EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}

}  // namespace
}  // namespace tflite

int main(int argc, char** argv) {
  // On Linux, add: tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}