1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/delegates/flex/delegate.h" 16 17 #include <gmock/gmock.h> 18 #include <gtest/gtest.h> 19 #include "tensorflow/lite/delegates/flex/test_util.h" 20 21 namespace tflite { 22 namespace flex { 23 namespace { 24 25 using ::testing::ElementsAre; 26 27 class DelegateTest : public testing::FlexModelTest { 28 public: 29 DelegateTest() { 30 delegate_ = FlexDelegate::Create(); 31 interpreter_.reset(new Interpreter(&error_reporter_)); 32 } 33 34 ~DelegateTest() override { 35 // The delegate needs to be destructed after the interpreter because the 36 // interpreter references data contained in the delegate. 37 interpreter_.reset(); 38 delegate_.reset(); 39 } 40 41 void ConfigureDelegate() { 42 ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), 43 kTfLiteOk); 44 } 45 46 private: 47 std::unique_ptr<FlexDelegate> delegate_; 48 }; 49 50 TEST_F(DelegateTest, FullGraph) { 51 // Define the graph. 52 AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); 53 54 AddTfOp(testing::kUnpack, {0}, {1, 2}); 55 AddTfOp(testing::kUnpack, {3}, {4, 5}); 56 AddTfOp(testing::kAdd, {1, 4}, {6}); 57 AddTfOp(testing::kAdd, {2, 5}, {7}); 58 AddTfOp(testing::kMul, {6, 7}, {8}); 59 60 // Apply the delegate. 61 ConfigureDelegate(); 62 63 // Define inputs. 64 SetShape(0, {2, 2, 1}); 65 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 66 SetShape(3, {2, 2, 1}); 67 SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); 68 69 ASSERT_TRUE(Invoke()); 70 71 ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); 72 ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); 73 ASSERT_EQ(GetType(8), kTfLiteFloat32); 74 } 75 76 TEST_F(DelegateTest, NonFloatTypeInference) { 77 AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2}); 78 79 AddTfOp(testing::kAdd, {0, 1}, {2}); 80 81 ConfigureDelegate(); 82 83 SetShape(0, {2, 2}); 84 SetTypedValues<int>(0, {1, 2, 3, 4}); 85 SetShape(1, {2, 2}); 86 SetTypedValues<int>(1, {4, 3, 2, 1}); 87 88 ASSERT_TRUE(Invoke()); 89 90 ASSERT_THAT(GetShape(2), ElementsAre(2, 2)); 91 ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5)); 92 ASSERT_EQ(GetType(2), kTfLiteInt32); 93 } 94 95 TEST_F(DelegateTest, StringInference) { 96 AddTensors(3, {0, 1}, {2}, kTfLiteString, {2}); 97 98 AddTfOp(testing::kAdd, {0, 1}, {2}); 99 100 ConfigureDelegate(); 101 102 SetShape(0, {2, 2}); 103 SetStringValues(0, {"1", "2", "3", "4"}); 104 SetShape(1, {2, 2}); 105 SetStringValues(1, {"4", "3", "2", "1"}); 106 107 ASSERT_TRUE(Invoke()); 108 109 ASSERT_THAT(GetShape(2), ElementsAre(2, 2)); 110 ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41")); 111 ASSERT_EQ(GetType(2), kTfLiteString); 112 } 113 114 TEST_F(DelegateTest, MixedGraph) { 115 AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); 116 117 AddTfOp(testing::kUnpack, {0}, {1, 2}); 118 AddTfOp(testing::kUnpack, {3}, {4, 5}); 119 AddTfOp(testing::kAdd, {1, 4}, {6}); 120 AddTfOp(testing::kAdd, {2, 5}, {7}); 121 AddTfLiteMulOp({6, 7}, {8}); 122 123 ConfigureDelegate(); 124 125 SetShape(0, {2, 2, 1}); 126 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 127 SetShape(3, {2, 2, 1}); 128 SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); 129 130 ASSERT_TRUE(Invoke()); 131 132 ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); 133 ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); 134 } 135 136 TEST_F(DelegateTest, SplitGraph) { 137 AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); 138 139 AddTfOp(testing::kUnpack, {0}, {1, 2}); 140 AddTfOp(testing::kAdd, {1, 2}, {3}); 141 AddTfOp(testing::kUnpack, {3}, {4, 5}); 142 143 AddTfLiteMulOp({4, 5}, {6}); 144 145 AddTfOp(testing::kUnpack, {6}, {7, 8}); 146 AddTfOp(testing::kAdd, {7, 8}, {9}); 147 148 ConfigureDelegate(); 149 150 SetShape(0, {2, 2, 2, 1}); 151 SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); 152 153 ASSERT_TRUE(Invoke()); 154 155 ASSERT_THAT(GetShape(9), ElementsAre(1)); 156 ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); 157 } 158 159 TEST_F(DelegateTest, OnlyTFLite) { 160 // Only TFLite single op model. 161 AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); 162 AddTfLiteMulOp({0, 1}, {2}); 163 164 ConfigureDelegate(); 165 166 SetShape(0, {2, 2, 1}); 167 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 168 SetShape(1, {2, 2, 1}); 169 SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); 170 171 ASSERT_TRUE(Invoke()); 172 173 ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); 174 ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); 175 } 176 177 TEST_F(DelegateTest, MultipleInvokeCalls) { 178 // Call Invoke() multiple times on the same model. 179 AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); 180 AddTfLiteMulOp({0, 1}, {2}); 181 182 ConfigureDelegate(); 183 184 SetShape(0, {2, 2, 1}); 185 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 186 SetShape(1, {2, 2, 1}); 187 SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); 188 189 ASSERT_TRUE(Invoke()); 190 191 ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); 192 ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); 193 194 SetShape(0, {2, 2, 1}); 195 SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f}); 196 SetShape(1, {2, 2, 1}); 197 SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f}); 198 199 ASSERT_TRUE(Invoke()); 200 201 ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); 202 ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f)); 203 } 204 205 TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { 206 // Build a graph, configure the delegate and set inputs. 207 { 208 AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); 209 AddTfOp(testing::kUnpack, {0}, {1, 2}); 210 AddTfOp(testing::kUnpack, {3}, {4, 5}); 211 AddTfOp(testing::kAdd, {1, 4}, {6}); 212 AddTfOp(testing::kAdd, {2, 5}, {7}); 213 AddTfOp(testing::kMul, {6, 7}, {8}); 214 ConfigureDelegate(); 215 SetShape(0, {2, 2, 1}); 216 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 217 SetShape(3, {2, 2, 1}); 218 SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); 219 } 220 221 // Create a new interpreter, inject into the test framework and build 222 // a different graph using the *same* delegate. 223 std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_)); 224 interpreter_.swap(interpreter); 225 { 226 AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); 227 AddTfOp(testing::kUnpack, {0}, {1, 2}); 228 AddTfOp(testing::kAdd, {1, 2}, {3}); 229 AddTfOp(testing::kUnpack, {3}, {4, 5}); 230 AddTfLiteMulOp({4, 5}, {6}); 231 AddTfOp(testing::kUnpack, {6}, {7, 8}); 232 AddTfOp(testing::kAdd, {7, 8}, {9}); 233 ConfigureDelegate(); 234 SetShape(0, {2, 2, 2, 1}); 235 SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); 236 } 237 238 // Swap back in the first interpreter and validate inference. 239 interpreter_.swap(interpreter); 240 { 241 ASSERT_TRUE(Invoke()); 242 EXPECT_THAT(GetShape(8), ElementsAre(2, 1)); 243 EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); 244 } 245 246 // Swap in the second interpreter and validate inference. 247 interpreter_.swap(interpreter); 248 { 249 ASSERT_TRUE(Invoke()); 250 EXPECT_THAT(GetShape(9), ElementsAre(1)); 251 EXPECT_THAT(GetValues(9), ElementsAre(10.0f)); 252 } 253 } 254 255 TEST_F(DelegateTest, SingleThreaded) { 256 AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); 257 AddTfOp(testing::kUnpack, {0}, {1, 2}); 258 AddTfOp(testing::kUnpack, {3}, {4, 5}); 259 AddTfOp(testing::kAdd, {1, 4}, {6}); 260 AddTfOp(testing::kAdd, {2, 5}, {7}); 261 AddTfOp(testing::kMul, {6, 7}, {8}); 262 263 // Explicitly disable multi-threading before installing the delegate. 264 interpreter_->SetNumThreads(1); 265 ConfigureDelegate(); 266 267 SetShape(0, {2, 2, 1}); 268 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 269 SetShape(3, {2, 2, 1}); 270 SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); 271 272 // Invocation should behave as expected. 273 ASSERT_TRUE(Invoke()); 274 275 ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); 276 ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); 277 ASSERT_EQ(GetType(8), kTfLiteFloat32); 278 } 279 280 TEST_F(DelegateTest, MultiThreaded) { 281 AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); 282 AddTfOp(testing::kUnpack, {0}, {1, 2}); 283 AddTfOp(testing::kUnpack, {3}, {4, 5}); 284 AddTfOp(testing::kAdd, {1, 4}, {6}); 285 AddTfOp(testing::kAdd, {2, 5}, {7}); 286 AddTfOp(testing::kMul, {6, 7}, {8}); 287 288 // Explicitly enable multi-threading before installing the delegate. 289 interpreter_->SetNumThreads(4); 290 ConfigureDelegate(); 291 292 SetShape(0, {2, 2, 1}); 293 SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); 294 SetShape(3, {2, 2, 1}); 295 SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); 296 297 // Invocation should behave as expected. 298 ASSERT_TRUE(Invoke()); 299 300 ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); 301 ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); 302 ASSERT_EQ(GetType(8), kTfLiteFloat32); 303 } 304 305 } // namespace 306 } // namespace flex 307 } // namespace tflite 308 309 int main(int argc, char** argv) { 310 ::tflite::LogToStderr(); 311 ::testing::InitGoogleTest(&argc, argv); 312 return RUN_ALL_TESTS(); 313 } 314