1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <gtest/gtest.h> 16 #include "tensorflow/lite/interpreter.h" 17 #include "tensorflow/lite/kernels/register.h" 18 #include "tensorflow/lite/kernels/test_util.h" 19 #include "tensorflow/lite/model.h" 20 21 namespace tflite { 22 namespace { 23 24 using ::testing::ElementsAreArray; 25 26 class SelectOpModel : public SingleOpModel { 27 public: 28 SelectOpModel(std::initializer_list<int> input1_shape, 29 std::initializer_list<int> input2_shape, 30 std::initializer_list<int> input3_shape, 31 TensorType input_type) { 32 input1_ = AddInput(TensorType_BOOL); 33 input2_ = AddInput(input_type); 34 input3_ = AddInput(input_type); 35 output_ = AddOutput(input_type); 36 SetBuiltinOp(BuiltinOperator_SELECT, BuiltinOptions_SelectOptions, 37 CreateSelectOptions(builder_).Union()); 38 BuildInterpreter({input1_shape, input2_shape, input3_shape}); 39 } 40 41 int input1() { return input1_; } 42 int input2() { return input2_; } 43 int input3() { return input3_; } 44 45 template <typename T> 46 std::vector<T> GetOutput() { 47 return ExtractVector<T>(output_); 48 } 49 50 std::vector<int> GetOutputShape() { return GetTensorShape(output_); } 51 52 private: 53 int input1_; 54 int input2_; 55 int input3_; 56 int output_; 57 }; 58 59 TEST(SelectOpTest, SelectBool) { 60 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 61 TensorType_BOOL); 62 63 model.PopulateTensor<bool>(model.input1(), {true, false, true, false}); 64 model.PopulateTensor<bool>(model.input2(), {false, false, false, false}); 65 model.PopulateTensor<bool>(model.input3(), {true, true, true, true}); 66 model.Invoke(); 67 68 EXPECT_THAT(model.GetOutput<bool>(), 69 ElementsAreArray({false, true, false, true})); 70 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 71 } 72 73 TEST(SelectOpTest, SelectFloat) { 74 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 75 TensorType_FLOAT32); 76 77 model.PopulateTensor<bool>(model.input1(), {true, false, true, false}); 78 model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.3, 0.4}); 79 model.PopulateTensor<float>(model.input3(), {0.5, 0.6, 0.7, 0.8}); 80 model.Invoke(); 81 82 EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray({0.1, 0.6, 0.3, 0.8})); 83 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 84 } 85 86 TEST(SelectOpTest, SelectUInt8) { 87 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 88 TensorType_UINT8); 89 90 model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); 91 model.PopulateTensor<uint8_t>(model.input2(), {1, 2, 3, 4}); 92 model.PopulateTensor<uint8_t>(model.input3(), {5, 6, 7, 8}); 93 model.Invoke(); 94 95 EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 2, 7, 8})); 96 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 97 } 98 99 TEST(SelectOpTest, SelectInt8) { 100 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 101 TensorType_INT8); 102 103 model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); 104 model.PopulateTensor<int8_t>(model.input2(), {1, -2, 3, 4}); 105 model.PopulateTensor<int8_t>(model.input3(), {5, 6, 7, -8}); 106 model.Invoke(); 107 108 EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray({5, -2, 7, -8})); 109 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 110 } 111 112 TEST(SelectOpTest, SelectInt16) { 113 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 114 TensorType_INT16); 115 116 model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); 117 model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4}); 118 model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8}); 119 model.Invoke(); 120 121 EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8})); 122 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 123 } 124 125 TEST(SelectOpTest, SelectInt32) { 126 SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4}, 127 TensorType_INT32); 128 129 model.PopulateTensor<bool>(model.input1(), {false, true, false, false}); 130 model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4}); 131 model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8}); 132 model.Invoke(); 133 134 EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 2, 7, 8})); 135 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); 136 } 137 138 TEST(SelectOpTest, RankOneSelectInt32) { 139 SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32); 140 141 model.PopulateTensor<bool>(model.input1(), {false, true}); 142 model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4}); 143 model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8}); 144 model.Invoke(); 145 146 EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 3, 4})); 147 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1})); 148 } 149 150 TEST(SelectOpTest, RankZeroSelectInt32) { 151 SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32); 152 153 model.PopulateTensor<bool>(model.input1(), {false}); 154 model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4}); 155 model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8}); 156 model.Invoke(); 157 158 EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8})); 159 EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1})); 160 } 161 162 } // namespace 163 } // namespace tflite 164 165 int main(int argc, char** argv) { 166 ::tflite::LogToStderr(); 167 ::testing::InitGoogleTest(&argc, argv); 168 return RUN_ALL_TESTS(); 169 } 170