1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/lite/kernels/subgraph_test_util.h" 17 #include <gtest/gtest.h> 18 #include "tensorflow/lite/interpreter.h" 19 #include "tensorflow/lite/kernels/test_util.h" 20 21 namespace tflite { 22 23 namespace subgraph_test_util { 24 25 namespace { 26 27 class SubgraphBuilderTest : public ::testing::Test { 28 public: 29 SubgraphBuilderTest() 30 : interpreter_(new Interpreter), builder_(new SubgraphBuilder) {} 31 32 ~SubgraphBuilderTest() override { 33 interpreter_.reset(); 34 builder_.reset(); 35 } 36 37 protected: 38 void TestAccumelateLoopBody(int input1, int input2, int output1, 39 int output2) { 40 interpreter_.reset(new Interpreter); 41 builder_->BuildAccumulateLoopBodySubgraph( 42 &interpreter_->primary_subgraph()); 43 44 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); 45 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1}); 46 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 47 48 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {input1}); 49 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {input2}); 50 51 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 52 TfLiteTensor* output_tensor1 = 53 interpreter_->tensor(interpreter_->outputs()[0]); 54 CheckIntTensor(output_tensor1, {1}, {output1}); 55 TfLiteTensor* output_tensor2 = 56 interpreter_->tensor(interpreter_->outputs()[1]); 57 CheckIntTensor(output_tensor2, {1}, {output2}); 58 } 59 60 std::unique_ptr<Interpreter> interpreter_; 61 std::unique_ptr<SubgraphBuilder> builder_; 62 }; 63 64 TEST_F(SubgraphBuilderTest, TestBuildAddSubgraph) { 65 builder_->BuildAddSubgraph(&interpreter_->primary_subgraph()); 66 67 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); 68 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); 69 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 70 71 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); 72 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); 73 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 74 75 TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); 76 CheckIntTensor(output, {1, 2}, {6, 9}); 77 } 78 79 TEST_F(SubgraphBuilderTest, TestBuildMulSubgraph) { 80 builder_->BuildMulSubgraph(&interpreter_->primary_subgraph()); 81 82 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); 83 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); 84 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 85 86 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); 87 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); 88 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 89 90 TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); 91 CheckIntTensor(output, {1, 2}, {5, 14}); 92 } 93 94 TEST_F(SubgraphBuilderTest, TestBuildPadSubgraph) { 95 builder_->BuildPadSubgraph(&interpreter_->primary_subgraph()); 96 97 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {2}); 98 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1, 2}); 99 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 100 101 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {5, 7}); 102 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {1, 2}); 103 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 104 105 TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); 106 CheckIntTensor(output, {5}, {0, 5, 7, 0, 0}); 107 } 108 109 TEST_F(SubgraphBuilderTest, TestBuildLessEqualCondSubgraph) { 110 builder_->BuildLessEqualCondSubgraph(&interpreter_->primary_subgraph(), 3); 111 112 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {5}); 113 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {10, 10}); 114 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 115 116 // Test [1, 2, 3, 4, 5] <= 3 == [true, true, true, false, false] 117 // (with broadcasting). 118 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), 119 {1, 2, 3, 4, 5}); 120 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 121 TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]); 122 CheckBoolTensor(output, {5}, {true, true, true, false, false}); 123 } 124 125 TEST_F(SubgraphBuilderTest, TestBuildAccumulateLoopBodySubgraph) { 126 TestAccumelateLoopBody(1, 1, 2, 3); 127 TestAccumelateLoopBody(2, 3, 3, 6); 128 TestAccumelateLoopBody(3, 6, 4, 10); 129 } 130 131 TEST_F(SubgraphBuilderTest, TestBuildPadLoopBodySubgraph) { 132 builder_->BuildPadLoopBodySubgraph(&interpreter_->primary_subgraph(), {1, 2}); 133 134 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1}); 135 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {5}); 136 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); 137 138 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1}); 139 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), 140 {0, 5, 7, 0, 0}); 141 142 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk); 143 TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]); 144 CheckIntTensor(output1, {1}, {2}); 145 TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]); 146 CheckIntTensor(output2, {8}, {0, 0, 5, 7, 0, 0, 0, 0}); 147 } 148 149 } // namespace 150 } // namespace subgraph_test_util 151 } // namespace tflite 152 153 int main(int argc, char** argv) { 154 ::tflite::LogToStderr(); 155 ::testing::InitGoogleTest(&argc, argv); 156 return RUN_ALL_TESTS(); 157 } 158