Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Unit test for TFLite FULLY_CONNECTED op.
     16 
     17 #include <iomanip>
     18 #include <vector>
     19 
     20 #include <gmock/gmock.h>
     21 #include <gtest/gtest.h>
     22 #include "tensorflow/contrib/lite/interpreter.h"
     23 #include "tensorflow/contrib/lite/kernels/register.h"
     24 #include "tensorflow/contrib/lite/kernels/test_util.h"
     25 #include "tensorflow/contrib/lite/model.h"
     26 
     27 namespace tflite {
     28 namespace {
     29 
     30 using ::testing::ElementsAre;
     31 using ::testing::ElementsAreArray;
     32 
// Input activations consumed by BlackBoxTest below. The test's model has
// input_size() == 8 and num_batches() == 2; each loop iteration takes the
// next 8 consecutive values from this array and feeds the same chunk to both
// batches of the interpreter input.
static float fully_connected_input[] = {
    0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653,
    0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390,
    0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314,
    0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550,
    0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112,
    0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999,
    0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142,
    0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494,
    0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081,
    0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552,
    0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320,
    0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458,
    0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115,
    0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771,
    0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582,
    0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962,
    0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202,
    0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691,
    0.921330, 0.139902};
     53 
// Expected outputs for BlackBoxTest below, one 16-value group (num_units())
// per invocation. Each group is the post-RELU output for one input chunk
// from fully_connected_input; since BlackBoxTest feeds the same chunk to
// both batches, the test duplicates a group to form the full expectation.
static float fully_connected_golden_output[] = {
    0,        0.0732134,   0,        0,          0,         0.280859,
    0,        0.128927,    0,        0.0777251,  0,         0.270268,
    0.271435, 0.0173503,   0.335465, 0.235562,

    0,        0.0745866,   0,        0.051611,   0,         0.253876,
    0,        0.0814873,   0,        0.104104,   0,         0.248529,
    0.264194, 0,           0.302973, 0.166252,

    0,        0.0170409,   0,        0.0509851,  0,         0.212834,
    0,        0.0208326,   0,        0.129932,   0.203978,  0.103428,
    0.298051, 0,           0.332233, 0.00445903,

    0,        0.125246,    0,        0.0735336,  0,         0.0910256,
    0,        0,           0,        0.18933,    0.378111,  0.0712443,
    0.277298, 0.0123414,   0.267454, 0,

    0,        0.14687,     0,        0.155495,   0.0300215, 0.147256,
    0,        0,           0,        0.156412,   0.434914,  0.0461529,
    0.246508, 0,           0.363138, 0,

    0,        0,           0,        0.0212949,  0,         0.301708,
    0,        0.35497,     0,        0.406223,   0.0260211, 0.049195,
    0.197161, 0,           0.37316,  0,

    0,        0.221783,    0,        0,          0.0116515, 0.281945,
    0,        0,           0,        0,          0.285626,  0.181773,
    0.296401, 0.170452,    0.367135, 0.142597,

    0,        0,           0,        0,          0,         0.418886,
    0,        0.291063,    0,        0.227541,   0.0424759, 0.27589,
    0.398286, 0.177146,    0.40359,  0.121452,

    0,        0.0834884,   0,        0,          0,         0.287441,
    0,        0.0046838,   0,        0.0122087,  0,         0.217376,
    0.140183, 0.0948412,   0.436677, 0.0589876,

    0,        0.0289969,   0,        0.0921397,  0,         0.396802,
    0,        0.0126157,   0,        0.0968433,  0,         0.172271,
    0.173295, 0.0664741,   0.53645,  0.00915603,

    0,        0,           0,        0,          0,         0.147942,
    0,        0.263795,    0,        0.39782,    0,         0.382435,
    0.561072, 0.0579847,   0.145712, 0.13508,

    0,        0,           0,        0.16382,    0,         0.322294,
    0,        0.163798,    0,        0.405211,   0.367953,  0.076852,
    0.342473, 0.0834118,   0.377537, 0,

    0,        0.206,       0,        0,          0,         0.375769,
    0,        0,           0,        0,          0,         0.125165,
    0,        0.105591,    0.52055,  0.0536445,

    0,        0.259261,    0,        0,          0,         0.247707,
    0,        0,           0,        0,          0,         0.215862,
    0.149153, 0.224678,    0.359519, 0.129419,

    0,        0.17611,     0,        0.280895,   0,         0.576484,
    0,        0.000418848, 0,        0,          0,         0.151112,
    0.211902, 0,           0.566341, 0.106305,

    0,        0.0246284,   0,        0,          0,         0.196267,
    0,        0.0248624,   0,        0.265635,   0,         0.436199,
    0.408079, 0.134514,    0.328489, 0.411368};
    118 
    119 class BaseFullyConnectedOpModel : public SingleOpModel {
    120  public:
    121   // TODO(ahentz): test different activation types too.
    122   BaseFullyConnectedOpModel(int units, int batches, const TensorData& input,
    123                             const TensorData& output = {TensorType_FLOAT32})
    124       : batches_(batches), units_(units) {
    125     int total_input_size = 1;
    126     for (int i = 0; i < input.shape.size(); ++i) {
    127       total_input_size *= input.shape[i];
    128     }
    129     input_size_ = total_input_size / batches_;
    130 
    131     input_ = AddInput(input);
    132     weights_ =
    133         AddInput({input.type, {units_, input_size_}, input.min, input.max});
    134 
    135     if (input.type == TensorType_FLOAT32) {
    136       bias_ = AddInput({TensorType_FLOAT32, {units_}});
    137     } else {
    138       // This is a quantized version. The scale of 'bias' depends on the scales
    139       // of input and filter. Supposedly this is correctly set during quantized
    140       // training.
    141       auto bias_scale = GetScale(input_) * GetScale(weights_);
    142       TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
    143       bias_ = AddInput(bias);
    144     }
    145 
    146     output_ = AddOutput(output);
    147 
    148     SetBuiltinOp(
    149         BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
    150         CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
    151             .Union());
    152     BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
    153   }
    154 
    155   int input_size() { return input_size_; }
    156   int num_units() { return units_; }
    157   int num_batches() { return batches_; }
    158 
    159  protected:
    160   int input_;
    161   int weights_;
    162   int bias_;
    163   int output_;
    164 
    165   int batches_;
    166   int units_;
    167   int input_size_;
    168 };
    169 
    170 class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
    171  public:
    172   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
    173 
    174   void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
    175 
    176   void SetWeights(std::initializer_list<float> f) {
    177     PopulateTensor(weights_, f);
    178   }
    179 
    180   void SetInput(std::initializer_list<float> data) {
    181     PopulateTensor(input_, data);
    182   }
    183   void SetInput(int offset, float* begin, float* end) {
    184     PopulateTensor(input_, offset, begin, end);
    185   }
    186 
    187   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
    188 };
    189 
    190 class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
    191  public:
    192   using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
    193 
    194   void SetBias(std::initializer_list<float> data) {
    195     QuantizeAndPopulate<int32_t>(bias_, data);
    196   }
    197   void SetWeights(std::initializer_list<float> data) {
    198     QuantizeAndPopulate<uint8_t>(weights_, data);
    199   }
    200   void SetInput(std::initializer_list<float> data) {
    201     QuantizeAndPopulate<uint8_t>(input_, data);
    202   }
    203 
    204   std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
    205   std::vector<float> GetDequantizedOutput() {
    206     return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
    207                                GetScale(output_), GetZeroPoint(output_));
    208   }
    209 };
    210 
// TODO(ahentz): add more small tests like this one, focused on making sure the
// calculations are correct.
TEST(FullyConnectedOpTest, SimpleTest) {
  // 3 output units, 2 batches, 10 input values per batch.
  FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Each unit computes dot(input, weights) + bias; e.g. for batch 0, unit 0:
  // (1+4+9+16+25+36+49+64-81-100) + 1 = 24. RELU leaves positives unchanged.
  EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
    231 
TEST(FullyConnectedOpTest, SimpleTestQuantized) {
  QuantizedFullyConnectedOpModel m(
      3, 2,
      /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // NOTE: these ranges were chosen so that input_product_scale <
  // output_scale does NOT hold, exercising that rescaling path.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Same expected values as the float SimpleTest, within quantization error.
  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
                                            24, 25, 26,  //
                                            58, 59, 60,  //
                                        })));
  // The raw uint8 outputs: dequantized value plus the output zero point (127).
  EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
    259 
TEST(FullyConnectedOpTest, SimpleTest4DInput) {
  // Note that it is not required that the first dimension be the number of
  // batches. All we care is that the input can be evenly distributed in
  // batches. In this case, we need the input to have multiples of '2'.
  FloatFullyConnectedOpModel m(/*units=*/3,
                               /*batches=*/2,
                               /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // first batch
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // second batch
  });

  m.Invoke();

  // Same expected values as SimpleTest: the 4D input is flattened to 2x10.
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
                                 24, 25, 26,  // first batch
                                 58, 59, 60,  // second batch
                             }));
}
    286 
TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) {
  // Quantized counterpart of SimpleTest4DInput: the 4x1x5x1 input is
  // flattened into 2 batches of 10.
  QuantizedFullyConnectedOpModel m(
      3, 2,
      /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64},
      /*output=*/{TensorType_UINT8, {}, -127, 128});

  // NOTE: these ranges were chosen so that input_product_scale <
  // output_scale does NOT hold, exercising that rescaling path.
  m.SetWeights({
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 0
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 1
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10,  // u = 2
  });
  m.SetBias({1, 2, 3});

  m.SetInput({
      1, 2, 3, 4, 5, 6, 7, 8,  -9, -10,  // b = 0
      1, 2, 3, 4, 5, 6, 7, -8, 9,  -10,  // b = 1
  });

  m.Invoke();

  // Same expected values as the float tests, within quantization error.
  EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
                                            24, 25, 26,  //
                                            58, 59, 60,  //
                                        })));
  // The raw uint8 outputs: dequantized value plus the output zero point (127).
  EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
}
    314 
// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard
// to debug errors and doesn't necessarily test all the important details.
TEST(FullyConnectedOpTest, BlackBoxTest) {
  // 16 output units, 2 batches, 8 inputs per batch (16x8 weight matrix).
  FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}});
  m.SetWeights(
      {0.091327,  0.103366,  -0.316505, -0.083120, 0.149366,  -0.196636,
       -0.123672, 0.062800,  0.063031,  0.191670,  -0.062001, -0.061504,
       -0.275581, 0.059388,  -0.118497, -0.079224, 0.109758,  0.008307,
       -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650,
       0.266455,  0.051517,  -0.123448, 0.322464,  0.043282,  -0.173782,
       -0.190381, 0.002013,  0.096086,  0.131157,  0.031164,  0.100638,
       -0.312191, -0.080923, -0.101318, -0.116614, 0.142238,  0.086540,
       -0.139154, 0.174268,  -0.073161, 0.080072,  0.006874,  0.229382,
       -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824,
       -0.025021, 0.074460,  -0.252595, -0.161750, -0.136403, 0.008308,
       0.005710,  0.096600,  0.289839,  0.218816,  -0.304651, -0.070958,
       0.054598,  0.147113,  -0.139112, -0.072798, -0.163335, -0.167863,
       -0.128762, -0.035780, 0.117262,  0.017177,  0.263335,  -0.176612,
       0.262961,  -0.093654, -0.339283, 0.333071,  0.180827,  0.287583,
       0.066350,  -0.197947, -0.114449, -0.236035, 0.103532,  -0.034284,
       0.093299,  -0.145361, 0.054001,  0.250570,  0.157010,  -0.143480,
       -0.139061, -0.048873, 0.067557,  0.139038,  0.324106,  0.227041,
       0.037793,  -0.225747, -0.241619, 0.357835,  0.135762,  -0.306764,
       -0.125982, 0.091916,  0.266587,  0.030135,  0.265148,  0.141627,
       0.020120,  0.083815,  -0.124556, -0.100124, -0.048159, 0.181172,
       0.302309,  -0.041084, 0.146334,  -0.061511, -0.232605, 0.281324,
       0.145408,  -0.221897});
  m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860,
             0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804,
             0.048478, -0.032270, 0.175688, -0.085662});

  // Number of per-batch input chunks available in fully_connected_input
  // (each invocation consumes input_size() values per batch).
  const int input_sequence_size = sizeof(fully_connected_input) /
                                  sizeof(float) /
                                  (m.input_size() * m.num_batches());
  for (int i = 0; i < input_sequence_size; i++) {
    // TODO(ahentz): This is what the original test was doing: two equal
    // batches per invocation. We could instead use two different batches.
    float* batch_start = fully_connected_input + i * m.input_size();
    float* batch_end = batch_start + m.input_size();
    // Feed the same chunk into both batch slots of the input tensor.
    m.SetInput(0, batch_start, batch_end);
    m.SetInput(m.input_size(), batch_start, batch_end);

    m.Invoke();

    // The golden output holds num_units() values per invocation; duplicate
    // the group since both batches received identical input.
    float* golden_start = fully_connected_golden_output + i * m.num_units();
    float* golden_end = golden_start + m.num_units();
    std::vector<float> expected;
    expected.insert(expected.end(), golden_start, golden_end);
    expected.insert(expected.end(), golden_start, golden_end);

    EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  }
}
    368 
    369 }  // namespace
    370 }  // namespace tflite
    371 
// Test entry point: route TFLite logging to stderr, then run all gtest cases.
int main(int argc, char** argv) {
  ::tflite::LogToStderr();
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
    377