Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Unit test for TFLite SVDF op.
     16 
     17 #include <iomanip>
     18 #include <vector>
     19 
     20 #include <gmock/gmock.h>
     21 #include <gtest/gtest.h>
     22 #include "tensorflow/lite/interpreter.h"
     23 #include "tensorflow/lite/kernels/register.h"
     24 #include "tensorflow/lite/kernels/test_util.h"
     25 #include "tensorflow/lite/model.h"
     26 
     27 namespace tflite {
     28 namespace {
     29 
     30 using ::testing::ElementsAreArray;
     31 
     32 static float svdf_input[] = {
     33     0.12609188,  -0.46347019, -0.89598465,
     34     0.35867718,  0.36897406,  0.73463392,
     35 
     36     0.14278367,  -1.64410412, -0.75222826,
     37     -0.57290924, 0.12729003,  0.7567004,
     38 
     39     0.49837467,  0.19278903,  0.26584083,
     40     0.17660543,  0.52949083,  -0.77931279,
     41 
     42     -0.11186574, 0.13164264,  -0.05349274,
     43     -0.72674477, -0.5683046,  0.55900657,
     44 
     45     -0.68892461, 0.37783599,  0.18263303,
     46     -0.63690937, 0.44483393,  -0.71817774,
     47 
     48     -0.81299269, -0.86831826, 1.43940818,
     49     -0.95760226, 1.82078898,  0.71135032,
     50 
     51     -1.45006323, -0.82251364, -1.69082689,
     52     -1.65087092, -1.89238167, 1.54172635,
     53 
     54     0.03966608,  -0.24936394, -0.77526885,
     55     2.06740379,  -1.51439476, 1.43768692,
     56 
     57     0.11771342,  -0.23761693, -0.65898693,
     58     0.31088525,  -1.55601168, -0.87661445,
     59 
     60     -0.89477462, 1.67204106,  -0.53235275,
     61     -0.6230064,  0.29819036,  1.06939757,
     62 };
     63 
     64 static float svdf_golden_output_rank_1[] = {
     65     0.014899,    -0.0517661,  -0.143725,   -0.00271883,
     66     -0.03004015, 0.09565311,  0.1587342,   0.00784263,
     67 
     68     0.068281,    -0.162217,   -0.152268,   0.00323521,
     69     0.01582633,  0.03858774,  -0.03001583, -0.02671271,
     70 
     71     -0.0317821,  -0.0333089,  0.0609602,   0.0333759,
     72     -0.01432795, 0.05524484,  0.1101355,   -0.02382665,
     73 
     74     -0.00623099, -0.077701,   -0.391193,   -0.0136691,
     75     -0.02333033, 0.02293761,  0.12338032,  0.04326871,
     76 
     77     0.201551,    -0.164607,   -0.179462,   -0.0592739,
     78     0.01064911,  -0.17503069, 0.07821996,  -0.00224009,
     79 
     80     0.0886511,   -0.0875401,  -0.269283,   0.0281379,
     81     -0.02282338, 0.09741908,  0.32973239,  0.12281385,
     82 
     83     -0.201174,   -0.586145,   -0.628624,   -0.0330412,
     84     0.24780814,  -0.39304617, -0.22473189, 0.02589256,
     85 
     86     -0.0839096,  -0.299329,   0.108746,    0.109808,
     87     0.10084175,  -0.06416984, 0.28936723,  0.0026358,
     88 
     89     0.419114,    -0.237824,   -0.422627,   0.175115,
     90     -0.2314795,  -0.18584411, -0.4228974,  -0.12928449,
     91 
     92     0.36726,     -0.522303,   -0.456502,   -0.175475,
     93     0.17012937,  -0.34447709, 0.38505614,  -0.28158101,
     94 };
     95 
     96 static float svdf_golden_output_rank_2[] = {
     97     -0.09623547, -0.10193135, 0.11083051,  -0.0347917,
     98     0.1141196,   0.12965347,  -0.12652366, 0.01007236,
     99 
    100     -0.16396809, -0.21247184, 0.11259045,  -0.04156673,
    101     0.10132131,  -0.06143532, -0.00924693, 0.10084561,
    102 
    103     0.01257364,  0.0506071,   -0.19287863, -0.07162561,
    104     -0.02033747, 0.22673416,  0.15487903,  0.02525555,
    105 
    106     -0.1411963,  -0.37054959, 0.01774767,  0.05867489,
    107     0.09607603,  -0.0141301,  -0.08995658, 0.12867066,
    108 
    109     -0.27142537, -0.16955489, 0.18521598,  -0.12528358,
    110     0.00331409,  0.11167502,  0.02218599,  -0.07309391,
    111 
    112     0.09593632,  -0.28361851, -0.0773851,  0.17199151,
    113     -0.00075242, 0.33691186,  -0.1536046,  0.16572715,
    114 
    115     -0.27916506, -0.27626723, 0.42615682,  0.3225764,
    116     -0.37472126, -0.55655634, -0.05013514, 0.289112,
    117 
    118     -0.24418658, 0.07540751,  -0.1940318,  -0.08911639,
    119     0.00732617,  0.46737891,  0.26449674,  0.24888524,
    120 
    121     -0.17225097, -0.54660404, -0.38795233, 0.08389944,
    122     0.07736043,  -0.28260678, 0.15666828,  1.14949894,
    123 
    124     -0.57454878, -0.64704704, 0.73235172,  -0.34616736,
    125     0.21120001,  -0.22927976, 0.02455296,  -0.35906726,
    126 };
    127 
    128 // Derived class of SingleOpModel, which is used to test SVDF TFLite op.
    129 class BaseSVDFOpModel : public SingleOpModel {
    130  public:
    131   BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
    132                   int rank,
    133                   TensorType weights_feature_type = TensorType_FLOAT32,
    134                   TensorType weights_time_type = TensorType_FLOAT32)
    135       : batches_(batches),
    136         units_(units),
    137         input_size_(input_size),
    138         memory_size_(memory_size),
    139         rank_(rank) {
    140     input_ = AddInput(TensorType_FLOAT32);
    141     weights_feature_ = AddInput(weights_feature_type);
    142     weights_time_ = AddInput(weights_time_type);
    143     bias_ = AddNullInput();
    144     const int num_filters = units * rank;
    145     activation_state_ = AddInput(
    146         TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
    147         /*is_variable=*/true);
    148     output_ = AddOutput(TensorType_FLOAT32);
    149     SetBuiltinOp(
    150         BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
    151         CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
    152     BuildInterpreter({
    153         {batches_, input_size_},              // input tensor
    154         {units_ * rank, input_size_},         // weights_feature tensor
    155         {units_ * rank, memory_size_},        // weights_time tensor
    156         {units_},                             // bias tensor
    157         {batches, memory_size * num_filters}  // activation_state tensor
    158     });
    159   }
    160 
    161   // Populates the weights_feature tensor.
    162   void SetWeightsFeature(std::initializer_list<float> f) {
    163     PopulateTensor(weights_feature_, f);
    164   }
    165 
    166   // Populates the weights_time tensor.
    167   void SetWeightsTime(std::initializer_list<float> f) {
    168     PopulateTensor(weights_time_, f);
    169   }
    170 
    171   // Populates the input tensor.
    172   void SetInput(int offset, float* begin, float* end) {
    173     PopulateTensor(input_, offset, begin, end);
    174   }
    175 
    176   // Extracts the output tensor from the SVDF op.
    177   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
    178 
    179   int input_size() { return input_size_; }
    180   int num_units() { return units_; }
    181   int num_batches() { return batches_; }
    182 
    183  protected:
    184   int input_;
    185   int weights_feature_;
    186   int weights_time_;
    187   int bias_;
    188   int activation_state_;
    189   int output_;
    190 
    191   int batches_;
    192   int units_;
    193   int input_size_;
    194   int memory_size_;
    195   int rank_;
    196 };
    197 
    198 class SVDFOpModel : public BaseSVDFOpModel {
    199  public:
    200   using BaseSVDFOpModel::BaseSVDFOpModel;
    201 };
    202 
    203 class HybridSVDFOpModel : public BaseSVDFOpModel {
    204  public:
    205   HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
    206                     int rank, TensorType tensor_type)
    207       : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
    208                         tensor_type, tensor_type) {
    209     tensor_type_ = tensor_type;
    210   }
    211 
    212   void SetWeights(int weights_idx, const std::vector<float>& f) {
    213     if (tensor_type_ == TensorType_UINT8) {
    214       SymmetricQuantizeAndPopulate(weights_idx, f);
    215     } else {
    216       SignedSymmetricQuantizeAndPopulate(weights_idx, f);
    217     }
    218   }
    219 
    220   void SetWeightsFeature(std::initializer_list<float> f) {
    221     SetWeights(weights_feature_, f);
    222   }
    223 
    224   void SetWeightsTime(std::initializer_list<float> f) {
    225     SetWeights(weights_time_, f);
    226   }
    227 
    228  protected:
    229   TensorType tensor_type_;
    230 };
    231 
    232 class SVDFOpTest : public ::testing::Test {
    233  protected:
    234   void VerifyGoldens(float golden_input[], float golden_output[],
    235                      int golden_size, BaseSVDFOpModel* svdf,
    236                      float tolerance = 1e-5) {
    237     const int svdf_num_batches = svdf->num_batches();
    238     const int svdf_input_size = svdf->input_size();
    239     const int svdf_num_units = svdf->num_units();
    240     const int input_sequence_size =
    241         golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
    242     // Going over each input batch, setting the input tensor, invoking the SVDF
    243     // op and checking the output with the expected golden values.
    244     for (int i = 0; i < input_sequence_size; i++) {
    245       float* batch_start =
    246           golden_input + i * svdf_input_size * svdf_num_batches;
    247       float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
    248       svdf->SetInput(0, batch_start, batch_end);
    249 
    250       svdf->Invoke();
    251 
    252       const float* golden_start =
    253           golden_output + i * svdf_num_units * svdf_num_batches;
    254       const float* golden_end =
    255           golden_start + svdf_num_units * svdf_num_batches;
    256       std::vector<float> expected;
    257       expected.insert(expected.end(), golden_start, golden_end);
    258 
    259       EXPECT_THAT(svdf->GetOutput(),
    260                   ElementsAreArray(ArrayFloatNear(expected, tolerance)));
    261     }
    262   }
    263 };
    264 
    265 TEST_F(SVDFOpTest, BlackBoxTestRank1) {
    266   SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    267                    /*memory_size=*/10, /*rank=*/1);
    268   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
    269                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
    270                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
    271 
    272   svdf.SetWeightsTime(
    273       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    274        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    275 
    276        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    277        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    278 
    279        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    280        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    281 
    282        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    283        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});
    284 
    285   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
    286                 &svdf);
    287 }
    288 
    289 TEST_F(SVDFOpTest, BlackBoxTestRank2) {
    290   SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    291                    /*memory_size=*/10, /*rank=*/2);
    292   svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
    293                           0.12416199,  0.15785322,  0.27901134,  0.3905206,
    294                           0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    295                           -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
    296                           0.15294972,  0.38031587,  0.27557442,  0.39635518,
    297                           -0.21580373, -0.06634006, -0.02702999, 0.27072677});
    298 
    299   svdf.SetWeightsTime(
    300       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    301        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    302 
    303        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    304        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    305 
    306        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    307        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    308 
    309        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    310        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,
    311 
    312        -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    313        0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,
    314 
    315        -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    316        0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,
    317 
    318        -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    319        -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,
    320 
    321        0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    322        0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});
    323 
    324   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
    325                 &svdf);
    326 }
    327 
    328 TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Uint8) {
    329   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    330                          /*memory_size=*/10, /*rank=*/1, TensorType_UINT8);
    331   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
    332                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
    333                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
    334 
    335   svdf.SetWeightsTime(
    336       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    337        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    338 
    339        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    340        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    341 
    342        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    343        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    344 
    345        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    346        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});
    347 
    348   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
    349                 &svdf,
    350                 /*tolerance=*/0.002945);
    351 }
    352 
    353 TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Uint8) {
    354   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    355                          /*memory_size=*/10, /*rank=*/2, TensorType_UINT8);
    356   svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
    357                           0.12416199,  0.15785322,  0.27901134,  0.3905206,
    358                           0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    359                           -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
    360                           0.15294972,  0.38031587,  0.27557442,  0.39635518,
    361                           -0.21580373, -0.06634006, -0.02702999, 0.27072677});
    362 
    363   svdf.SetWeightsTime(
    364       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    365        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    366 
    367        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    368        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    369 
    370        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    371        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    372 
    373        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    374        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,
    375 
    376        -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    377        0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,
    378 
    379        -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    380        0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,
    381 
    382        -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    383        -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,
    384 
    385        0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    386        0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});
    387 
    388   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
    389                 &svdf,
    390                 /*tolerance=*/0.00625109);
    391 }
    392 
    393 TEST_F(SVDFOpTest, BlackBoxTestHybridRank1Int8) {
    394   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    395                          /*memory_size=*/10, /*rank=*/1, TensorType_INT8);
    396   svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
    397                           0.22197971, 0.12416199, 0.27901134, 0.27557442,
    398                           0.3905206, -0.36137494, -0.06634006, -0.10640851});
    399 
    400   svdf.SetWeightsTime(
    401       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    402        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    403 
    404        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    405        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    406 
    407        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    408        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    409 
    410        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    411        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657});
    412 
    413   VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
    414                 &svdf,
    415                 /*tolerance=*/0.002945);
    416 }
    417 
    418 TEST_F(SVDFOpTest, BlackBoxTestHybridRank2Int8) {
    419   HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
    420                          /*memory_size=*/10, /*rank=*/2, TensorType_INT8);
    421   svdf.SetWeightsFeature({-0.31930989, 0.0079667,   0.39296314,  0.37613347,
    422                           0.12416199,  0.15785322,  0.27901134,  0.3905206,
    423                           0.21931258,  -0.36137494, -0.10640851, 0.31053296,
    424                           -0.36118156, -0.0976817,  -0.36916667, 0.22197971,
    425                           0.15294972,  0.38031587,  0.27557442,  0.39635518,
    426                           -0.21580373, -0.06634006, -0.02702999, 0.27072677});
    427 
    428   svdf.SetWeightsTime(
    429       {-0.31930989, 0.37613347,  0.27901134,  -0.36137494, -0.36118156,
    430        0.22197971,  0.27557442,  -0.06634006, 0.0079667,   0.12416199,
    431 
    432        0.3905206,   -0.10640851, -0.0976817,  0.15294972,  0.39635518,
    433        -0.02702999, 0.39296314,  0.15785322,  0.21931258,  0.31053296,
    434 
    435        -0.36916667, 0.38031587,  -0.21580373, 0.27072677,  0.23622236,
    436        0.34936687,  0.18174365,  0.35907319,  -0.17493086, 0.324846,
    437 
    438        -0.10781813, 0.27201805,  0.14324132,  -0.23681851, -0.27115166,
    439        -0.01580888, -0.14943552, 0.15465137,  0.09784451,  -0.0337657,
    440 
    441        -0.14884081, 0.19931212,  -0.36002168, 0.34663299,  -0.11405486,
    442        0.12672701,  0.39463779,  -0.07886535, -0.06384811, 0.08249187,
    443 
    444        -0.26816407, -0.19905911, 0.29211238,  0.31264046,  -0.28664589,
    445        0.05698794,  0.11613581,  0.14078894,  0.02187902,  -0.21781836,
    446 
    447        -0.15567942, 0.08693647,  -0.38256618, 0.36580828,  -0.22922277,
    448        -0.0226903,  0.12878349,  -0.28122205, -0.10850525, -0.11955214,
    449 
    450        0.27179423,  -0.04710215, 0.31069002,  0.22672787,  0.09580326,
    451        0.08682203,  0.1258215,   0.1851041,   0.29228821,  0.12366763});
    452 
    453   VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
    454                 &svdf,
    455                 /*tolerance=*/0.00625109);
    456 }
    457 
    458 }  // namespace
    459 }  // namespace tflite
    460 
    461 int main(int argc, char** argv) {
    462   ::tflite::LogToStderr();
    463   ::testing::InitGoogleTest(&argc, argv);
    464   return RUN_ALL_TESTS();
    465 }
    466