Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include <gtest/gtest.h>
     16 #include "tensorflow/lite/interpreter.h"
     17 #include "tensorflow/lite/kernels/register.h"
     18 #include "tensorflow/lite/kernels/test_util.h"
     19 #include "tensorflow/lite/model.h"
     20 
     21 namespace tflite {
     22 namespace {
     23 
     24 using ::testing::ElementsAreArray;
     25 
     26 class SelectOpModel : public SingleOpModel {
     27  public:
     28   SelectOpModel(std::initializer_list<int> input1_shape,
     29                 std::initializer_list<int> input2_shape,
     30                 std::initializer_list<int> input3_shape,
     31                 TensorType input_type) {
     32     input1_ = AddInput(TensorType_BOOL);
     33     input2_ = AddInput(input_type);
     34     input3_ = AddInput(input_type);
     35     output_ = AddOutput(input_type);
     36     SetBuiltinOp(BuiltinOperator_SELECT, BuiltinOptions_SelectOptions,
     37                  CreateSelectOptions(builder_).Union());
     38     BuildInterpreter({input1_shape, input2_shape, input3_shape});
     39   }
     40 
     41   int input1() { return input1_; }
     42   int input2() { return input2_; }
     43   int input3() { return input3_; }
     44 
     45   template <typename T>
     46   std::vector<T> GetOutput() {
     47     return ExtractVector<T>(output_);
     48   }
     49 
     50   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
     51 
     52  private:
     53   int input1_;
     54   int input2_;
     55   int input3_;
     56   int output_;
     57 };
     58 
     59 TEST(SelectOpTest, SelectBool) {
     60   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
     61                       TensorType_BOOL);
     62 
     63   model.PopulateTensor<bool>(model.input1(), {true, false, true, false});
     64   model.PopulateTensor<bool>(model.input2(), {false, false, false, false});
     65   model.PopulateTensor<bool>(model.input3(), {true, true, true, true});
     66   model.Invoke();
     67 
     68   EXPECT_THAT(model.GetOutput<bool>(),
     69               ElementsAreArray({false, true, false, true}));
     70   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
     71 }
     72 
     73 TEST(SelectOpTest, SelectFloat) {
     74   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
     75                       TensorType_FLOAT32);
     76 
     77   model.PopulateTensor<bool>(model.input1(), {true, false, true, false});
     78   model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.4});
     79   model.PopulateTensor<float>(model.input3(), {0.5, 0.6, 0.7, 0.8});
     80   model.Invoke();
     81 
     82   EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray({0.1, 0.6, 0.3, 0.8}));
     83   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
     84 }
     85 
     86 TEST(SelectOpTest, SelectUInt8) {
     87   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
     88                       TensorType_UINT8);
     89 
     90   model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
     91   model.PopulateTensor<uint8_t>(model.input2(), {1, 2, 3, 4});
     92   model.PopulateTensor<uint8_t>(model.input3(), {5, 6, 7, 8});
     93   model.Invoke();
     94 
     95   EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 2, 7, 8}));
     96   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
     97 }
     98 
     99 TEST(SelectOpTest, SelectInt8) {
    100   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
    101                       TensorType_INT8);
    102 
    103   model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
    104   model.PopulateTensor<int8_t>(model.input2(), {1, -2, 3, 4});
    105   model.PopulateTensor<int8_t>(model.input3(), {5, 6, 7, -8});
    106   model.Invoke();
    107 
    108   EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, -2, 7, -8}));
    109   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
    110 }
    111 
    112 TEST(SelectOpTest, SelectInt16) {
    113   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
    114                       TensorType_INT16);
    115 
    116   model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
    117   model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
    118   model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
    119   model.Invoke();
    120 
    121   EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
    122   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
    123 }
    124 
    125 TEST(SelectOpTest, SelectInt32) {
    126   SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
    127                       TensorType_INT32);
    128 
    129   model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
    130   model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
    131   model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
    132   model.Invoke();
    133 
    134   EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 2, 7, 8}));
    135   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
    136 }
    137 
    138 TEST(SelectOpTest, RankOneSelectInt32) {
    139   SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
    140 
    141   model.PopulateTensor<bool>(model.input1(), {false, true});
    142   model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
    143   model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
    144   model.Invoke();
    145 
    146   EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 3, 4}));
    147   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
    148 }
    149 
    150 TEST(SelectOpTest, RankZeroSelectInt32) {
    151   SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
    152 
    153   model.PopulateTensor<bool>(model.input1(), {false});
    154   model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
    155   model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
    156   model.Invoke();
    157 
    158   EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8}));
    159   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
    160 }
    161 
    162 }  // namespace
    163 }  // namespace tflite
    164 
    165 int main(int argc, char** argv) {
    166   ::tflite::LogToStderr();
    167   ::testing::InitGoogleTest(&argc, argv);
    168   return RUN_ALL_TESTS();
    169 }
    170