Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <gtest/gtest.h>
     16 #include "tensorflow/contrib/lite/interpreter.h"
     17 #include "tensorflow/contrib/lite/kernels/register.h"
     18 #include "tensorflow/contrib/lite/kernels/test_util.h"
     19 #include "tensorflow/contrib/lite/model.h"
     20 
     21 namespace tflite {
     22 namespace {
     23 
     24 using ::testing::ElementsAreArray;
     25 
     26 class PadOpModel : public SingleOpModel {
     27  public:
     28   void SetInput(std::initializer_list<float> data) {
     29     PopulateTensor<float>(input_, data);
     30   }
     31 
     32   void SetPaddings(std::initializer_list<int> paddings) {
     33     PopulateTensor<int>(paddings_, paddings);
     34   }
     35 
     36   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
     37   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
     38 
     39  protected:
     40   int input_;
     41   int output_;
     42   int paddings_;
     43 };
     44 
     45 // Tests case where paddings is a const tensor.
     46 //
     47 // Example usage is as follows:
     48 //    PadOpDynamicModel m(input_shape, paddings_shape, paddings_data);
     49 //    m.SetInput(input_data);
     50 //    m.Invoke();
     51 class PadOpConstModel : public PadOpModel {
     52  public:
     53   PadOpConstModel(std::initializer_list<int> input_shape,
     54                   std::initializer_list<int> paddings_shape,
     55                   std::initializer_list<int> paddings) {
     56     input_ = AddInput(TensorType_FLOAT32);
     57     paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
     58     output_ = AddOutput(TensorType_FLOAT32);
     59 
     60     SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
     61                  CreatePadOptions(builder_).Union());
     62     BuildInterpreter({input_shape});
     63   }
     64 };
     65 
     66 // Test case where paddings is a non-const tensor.
     67 //
     68 // Example usage is as follows:
     69 //    PadOpDynamicModel m(input_shape, paddings_shape);
     70 //    m.SetInput(input_data);
     71 //    m.SetPaddings(paddings_data);
     72 //    m.Invoke();
     73 class PadOpDynamicModel : public PadOpModel {
     74  public:
     75   PadOpDynamicModel(std::initializer_list<int> input_shape,
     76                     std::initializer_list<int> paddings_shape) {
     77     input_ = AddInput(TensorType_FLOAT32);
     78     paddings_ = AddInput(TensorType_INT32);
     79     output_ = AddOutput(TensorType_FLOAT32);
     80 
     81     SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
     82                  CreatePadOptions(builder_).Union());
     83     BuildInterpreter({input_shape, paddings_shape});
     84   }
     85 };
     86 
     87 TEST(PadOpTest, TooManyDimensions) {
     88   EXPECT_DEATH(
     89       PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2},
     90                       {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}),
     91       "dims != 4");
     92 }
     93 
     94 TEST(PadOpTest, UnequalDimensions) {
     95   EXPECT_DEATH(PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3}),
     96                "3 != 4");
     97 }
     98 
     99 TEST(PadOpTest, InvalidPadValue) {
    100   EXPECT_DEATH(
    101       PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}),
    102       "Pad value has to be greater than equal to 0.");
    103 }
    104 
    105 TEST(PadOpTest, SimpleConstTest) {
    106   // Padding is represented as four 2-D lists representing above padding and
    107   // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}).
    108   PadOpConstModel m({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0});
    109   m.SetInput({1, 2, 3, 4});
    110   m.Invoke();
    111   EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
    112                                                0, 0, 0, 0, 0}));
    113   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
    114 }
    115 
    116 TEST(PadOpTest, SimpleDynamicTest) {
    117   PadOpDynamicModel m({1, 2, 2, 1}, {4, 2});
    118   m.SetInput({1, 2, 3, 4});
    119   m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
    120   m.Invoke();
    121   EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4,
    122                                                0, 0, 0, 0, 0}));
    123   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
    124 }
    125 
    126 TEST(PadOpTest, AdvancedConstTest) {
    127   PadOpConstModel m({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0});
    128   m.SetInput({1, 2, 3, 4, 5, 6});
    129   m.Invoke();
    130   EXPECT_THAT(m.GetOutput(),
    131               ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
    132                                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
    133   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
    134 }
    135 
    136 TEST(PadOpTest, AdvancedDynamicTest) {
    137   PadOpDynamicModel m({1, 2, 3, 1}, {4, 2});
    138   m.SetInput({1, 2, 3, 4, 5, 6});
    139   m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
    140   m.Invoke();
    141   EXPECT_THAT(m.GetOutput(),
    142               ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0,
    143                                 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
    144   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
    145 }
    146 
    147 }  // namespace
    148 }  // namespace tflite
    149 
    150 int main(int argc, char** argv) {
    151   ::tflite::LogToStderr();
    152   ::testing::InitGoogleTest(&argc, argv);
    153   return RUN_ALL_TESTS();
    154 }
    155