Home | History | Annotate | Download | only in flex
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/lite/delegates/flex/delegate.h"
     16 
     17 #include <gmock/gmock.h>
     18 #include <gtest/gtest.h>
     19 #include "tensorflow/lite/delegates/flex/test_util.h"
     20 
     21 namespace tflite {
     22 namespace flex {
     23 namespace {
     24 
     25 using ::testing::ElementsAre;
     26 
     27 class DelegateTest : public testing::FlexModelTest {
     28  public:
     29   DelegateTest() {
     30     delegate_ = FlexDelegate::Create();
     31     interpreter_.reset(new Interpreter(&error_reporter_));
     32   }
     33 
     34   ~DelegateTest() override {
     35     // The delegate needs to be destructed after the interpreter because the
     36     // interpreter references data contained in the delegate.
     37     interpreter_.reset();
     38     delegate_.reset();
     39   }
     40 
     41   void ConfigureDelegate() {
     42     ASSERT_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
     43               kTfLiteOk);
     44   }
     45 
     46  private:
     47   std::unique_ptr<FlexDelegate> delegate_;
     48 };
     49 
     50 TEST_F(DelegateTest, FullGraph) {
     51   // Define the graph.
     52   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
     53 
     54   AddTfOp(testing::kUnpack, {0}, {1, 2});
     55   AddTfOp(testing::kUnpack, {3}, {4, 5});
     56   AddTfOp(testing::kAdd, {1, 4}, {6});
     57   AddTfOp(testing::kAdd, {2, 5}, {7});
     58   AddTfOp(testing::kMul, {6, 7}, {8});
     59 
     60   // Apply the delegate.
     61   ConfigureDelegate();
     62 
     63   // Define inputs.
     64   SetShape(0, {2, 2, 1});
     65   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
     66   SetShape(3, {2, 2, 1});
     67   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
     68 
     69   ASSERT_TRUE(Invoke());
     70 
     71   ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
     72   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
     73   ASSERT_EQ(GetType(8), kTfLiteFloat32);
     74 }
     75 
     76 TEST_F(DelegateTest, NonFloatTypeInference) {
     77   AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
     78 
     79   AddTfOp(testing::kAdd, {0, 1}, {2});
     80 
     81   ConfigureDelegate();
     82 
     83   SetShape(0, {2, 2});
     84   SetTypedValues<int>(0, {1, 2, 3, 4});
     85   SetShape(1, {2, 2});
     86   SetTypedValues<int>(1, {4, 3, 2, 1});
     87 
     88   ASSERT_TRUE(Invoke());
     89 
     90   ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
     91   ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
     92   ASSERT_EQ(GetType(2), kTfLiteInt32);
     93 }
     94 
     95 TEST_F(DelegateTest, StringInference) {
     96   AddTensors(3, {0, 1}, {2}, kTfLiteString, {2});
     97 
     98   AddTfOp(testing::kAdd, {0, 1}, {2});
     99 
    100   ConfigureDelegate();
    101 
    102   SetShape(0, {2, 2});
    103   SetStringValues(0, {"1", "2", "3", "4"});
    104   SetShape(1, {2, 2});
    105   SetStringValues(1, {"4", "3", "2", "1"});
    106 
    107   ASSERT_TRUE(Invoke());
    108 
    109   ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
    110   ASSERT_THAT(GetStringValues(2), ElementsAre("14", "23", "32", "41"));
    111   ASSERT_EQ(GetType(2), kTfLiteString);
    112 }
    113 
    114 TEST_F(DelegateTest, MixedGraph) {
    115   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    116 
    117   AddTfOp(testing::kUnpack, {0}, {1, 2});
    118   AddTfOp(testing::kUnpack, {3}, {4, 5});
    119   AddTfOp(testing::kAdd, {1, 4}, {6});
    120   AddTfOp(testing::kAdd, {2, 5}, {7});
    121   AddTfLiteMulOp({6, 7}, {8});
    122 
    123   ConfigureDelegate();
    124 
    125   SetShape(0, {2, 2, 1});
    126   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    127   SetShape(3, {2, 2, 1});
    128   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
    129 
    130   ASSERT_TRUE(Invoke());
    131 
    132   ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
    133   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
    134 }
    135 
    136 TEST_F(DelegateTest, SplitGraph) {
    137   AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    138 
    139   AddTfOp(testing::kUnpack, {0}, {1, 2});
    140   AddTfOp(testing::kAdd, {1, 2}, {3});
    141   AddTfOp(testing::kUnpack, {3}, {4, 5});
    142 
    143   AddTfLiteMulOp({4, 5}, {6});
    144 
    145   AddTfOp(testing::kUnpack, {6}, {7, 8});
    146   AddTfOp(testing::kAdd, {7, 8}, {9});
    147 
    148   ConfigureDelegate();
    149 
    150   SetShape(0, {2, 2, 2, 1});
    151   SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
    152 
    153   ASSERT_TRUE(Invoke());
    154 
    155   ASSERT_THAT(GetShape(9), ElementsAre(1));
    156   ASSERT_THAT(GetValues(9), ElementsAre(10.0f));
    157 }
    158 
    159 TEST_F(DelegateTest, OnlyTFLite) {
    160   // Only TFLite single op model.
    161   AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
    162   AddTfLiteMulOp({0, 1}, {2});
    163 
    164   ConfigureDelegate();
    165 
    166   SetShape(0, {2, 2, 1});
    167   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    168   SetShape(1, {2, 2, 1});
    169   SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
    170 
    171   ASSERT_TRUE(Invoke());
    172 
    173   ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
    174   ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
    175 }
    176 
    177 TEST_F(DelegateTest, MultipleInvokeCalls) {
    178   // Call Invoke() multiple times on the same model.
    179   AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
    180   AddTfLiteMulOp({0, 1}, {2});
    181 
    182   ConfigureDelegate();
    183 
    184   SetShape(0, {2, 2, 1});
    185   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    186   SetShape(1, {2, 2, 1});
    187   SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
    188 
    189   ASSERT_TRUE(Invoke());
    190 
    191   ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
    192   ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
    193 
    194   SetShape(0, {2, 2, 1});
    195   SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});
    196   SetShape(1, {2, 2, 1});
    197   SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
    198 
    199   ASSERT_TRUE(Invoke());
    200 
    201   ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
    202   ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
    203 }
    204 
    205 TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
    206   // Build a graph, configure the delegate and set inputs.
    207   {
    208     AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    209     AddTfOp(testing::kUnpack, {0}, {1, 2});
    210     AddTfOp(testing::kUnpack, {3}, {4, 5});
    211     AddTfOp(testing::kAdd, {1, 4}, {6});
    212     AddTfOp(testing::kAdd, {2, 5}, {7});
    213     AddTfOp(testing::kMul, {6, 7}, {8});
    214     ConfigureDelegate();
    215     SetShape(0, {2, 2, 1});
    216     SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    217     SetShape(3, {2, 2, 1});
    218     SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
    219   }
    220 
    221   // Create a new interpreter, inject into the test framework and build
    222   // a different graph using the *same* delegate.
    223   std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_));
    224   interpreter_.swap(interpreter);
    225   {
    226     AddTensors(10, {0}, {9}, kTfLiteFloat32, {3});
    227     AddTfOp(testing::kUnpack, {0}, {1, 2});
    228     AddTfOp(testing::kAdd, {1, 2}, {3});
    229     AddTfOp(testing::kUnpack, {3}, {4, 5});
    230     AddTfLiteMulOp({4, 5}, {6});
    231     AddTfOp(testing::kUnpack, {6}, {7, 8});
    232     AddTfOp(testing::kAdd, {7, 8}, {9});
    233     ConfigureDelegate();
    234     SetShape(0, {2, 2, 2, 1});
    235     SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f});
    236   }
    237 
    238   // Swap back in the first interpreter and validate inference.
    239   interpreter_.swap(interpreter);
    240   {
    241     ASSERT_TRUE(Invoke());
    242     EXPECT_THAT(GetShape(8), ElementsAre(2, 1));
    243     EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
    244   }
    245 
    246   // Swap in the second interpreter and validate inference.
    247   interpreter_.swap(interpreter);
    248   {
    249     ASSERT_TRUE(Invoke());
    250     EXPECT_THAT(GetShape(9), ElementsAre(1));
    251     EXPECT_THAT(GetValues(9), ElementsAre(10.0f));
    252   }
    253 }
    254 
    255 TEST_F(DelegateTest, SingleThreaded) {
    256   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    257   AddTfOp(testing::kUnpack, {0}, {1, 2});
    258   AddTfOp(testing::kUnpack, {3}, {4, 5});
    259   AddTfOp(testing::kAdd, {1, 4}, {6});
    260   AddTfOp(testing::kAdd, {2, 5}, {7});
    261   AddTfOp(testing::kMul, {6, 7}, {8});
    262 
    263   // Explicitly disable multi-threading before installing the delegate.
    264   interpreter_->SetNumThreads(1);
    265   ConfigureDelegate();
    266 
    267   SetShape(0, {2, 2, 1});
    268   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    269   SetShape(3, {2, 2, 1});
    270   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
    271 
    272   // Invocation should behave as expected.
    273   ASSERT_TRUE(Invoke());
    274 
    275   ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
    276   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
    277   ASSERT_EQ(GetType(8), kTfLiteFloat32);
    278 }
    279 
    280 TEST_F(DelegateTest, MultiThreaded) {
    281   AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3});
    282   AddTfOp(testing::kUnpack, {0}, {1, 2});
    283   AddTfOp(testing::kUnpack, {3}, {4, 5});
    284   AddTfOp(testing::kAdd, {1, 4}, {6});
    285   AddTfOp(testing::kAdd, {2, 5}, {7});
    286   AddTfOp(testing::kMul, {6, 7}, {8});
    287 
    288   // Explicitly enable multi-threading before installing the delegate.
    289   interpreter_->SetNumThreads(4);
    290   ConfigureDelegate();
    291 
    292   SetShape(0, {2, 2, 1});
    293   SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
    294   SetShape(3, {2, 2, 1});
    295   SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f});
    296 
    297   // Invocation should behave as expected.
    298   ASSERT_TRUE(Invoke());
    299 
    300   ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
    301   ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
    302   ASSERT_EQ(GetType(8), kTfLiteFloat32);
    303 }
    304 
    305 }  // namespace
    306 }  // namespace flex
    307 }  // namespace tflite
    308 
    309 int main(int argc, char** argv) {
    310   ::tflite::LogToStderr();
    311   ::testing::InitGoogleTest(&argc, argv);
    312   return RUN_ALL_TESTS();
    313 }
    314