Home | History | Annotate | Download | only in lite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <gmock/gmock.h>
     17 #include <gtest/gtest.h>
     18 
     19 #include "tensorflow/contrib/lite/graph_info.h"
     20 #include "tensorflow/contrib/lite/testing/util.h"
     21 
     22 namespace tflite {
     23 namespace {
     24 
     25 // Makes a TfLiteIntArray* from std::vector, must free with TfLiteIntFree().
     26 TfLiteIntArray* ConvertVector(const std::vector<int>& x) {
     27   TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
     28   for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
     29   return lite;
     30 }
     31 
     32 // A very simple test graph that supports setting in/out tensors on nodes.
     33 class SimpleTestGraph : public GraphInfo {
     34  public:
     35   ~SimpleTestGraph() override {
     36     for (auto& node : nodes_) {
     37       TfLiteIntArrayFree(node.inputs);
     38       TfLiteIntArrayFree(node.outputs);
     39     }
     40   }
     41 
     42   size_t num_tensors() const override { return tensors_.size(); }
     43   size_t num_nodes() const override { return nodes_.size(); }
     44   const TfLiteNode& node(size_t index) const override { return nodes_[index]; }
     45   TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
     46   const std::vector<int>& inputs() const override { return inputs_; }
     47   const std::vector<int>& outputs() const override { return outputs_; }
     48 
     49   void AddNode(const std::vector<int>& inputs,
     50                const std::vector<int>& outputs) {
     51     nodes_.push_back(TfLiteNode());
     52     TfLiteNode& node = nodes_.back();
     53     node.inputs = ConvertVector(inputs);
     54     node.outputs = ConvertVector(outputs);
     55   }
     56 
     57   void AddTensors(int count) { tensors_.resize(count + tensors_.size()); }
     58 
     59   void SetInputsAndOutputs(const std::vector<int>& inputs,
     60                            const std::vector<int>& outputs) {
     61     inputs_ = inputs;
     62     outputs_ = outputs;
     63   }
     64 
     65  private:
     66   std::vector<TfLiteNode> nodes_;
     67   std::vector<TfLiteTensor> tensors_;
     68   std::vector<int> inputs_;
     69   std::vector<int> outputs_;
     70 };
     71 
     72 // Partition a graph to generate a list of subgraphs. This wraps the API call
     73 // we are testing and handles memory management and conversion to
     74 // TfLiteIntArray. Populates `subgraphs` with resulting generated subgraphs.
     75 void PartitionGraph(const SimpleTestGraph& graph,
     76                     const std::vector<int>& nodes_to_partition,
     77                     std::vector<Subgraph>* subgraphs) {
     78   TfLiteIntArray* nodes_to_partition_int_array =
     79       ConvertVector(nodes_to_partition);
     80   PartitionGraphIntoIndependentSubgraphs(&graph, nodes_to_partition_int_array,
     81                                          subgraphs);
     82   TfLiteIntArrayFree(nodes_to_partition_int_array);
     83 }
     84 
     85 // Check a generated list of subgraphs against the expected list of subgraphs.
     86 void CheckPartitionSubgraphs(const std::vector<Subgraph>& generated_subgraphs,
     87                              const std::vector<Subgraph>& expected_subgraphs) {
     88   ASSERT_EQ(generated_subgraphs.size(), expected_subgraphs.size());
     89   for (int subgraph_index = 0; subgraph_index < generated_subgraphs.size();
     90        subgraph_index++) {
     91     EXPECT_EQ(generated_subgraphs[subgraph_index].nodes,
     92               expected_subgraphs[subgraph_index].nodes);
     93     EXPECT_EQ(generated_subgraphs[subgraph_index].input_tensors,
     94               expected_subgraphs[subgraph_index].input_tensors);
     95     EXPECT_EQ(generated_subgraphs[subgraph_index].output_tensors,
     96               expected_subgraphs[subgraph_index].output_tensors);
     97   }
     98 }
     99 
    100 // Test an empty trivial graph with no partitions.
    101 TEST(PartitionTest, Nodes0_PartitionNodes0) {
    102   SimpleTestGraph graph;
    103   std::vector<int> nodes_to_partition = {};
    104   std::vector<Subgraph> generated_subgraphs;
    105   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    106   CheckPartitionSubgraphs(generated_subgraphs, {});
    107 }
    108 
    109 // Test a 1 node graph with no partitions.
    110 // Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[]
    111 // Output: [kTfNoPartition, tensor(0) -> node(0) -> tensor(1)]
    112 TEST(PartitionTest, Nodes1PartitionNodes0) {
    113   SimpleTestGraph graph;
    114   graph.AddTensors(2);
    115   graph.AddNode({0}, {1});
    116   graph.SetInputsAndOutputs({0}, {1});
    117   std::vector<int> nodes_to_partition = {};
    118   std::vector<Subgraph> generated_subgraphs;
    119   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    120 
    121   Subgraph expected_subgraph;
    122   expected_subgraph.type = Subgraph::kTfNonPartition;
    123   expected_subgraph.nodes = {0};
    124   expected_subgraph.input_tensors = {0};
    125   expected_subgraph.output_tensors = {1};
    126   CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
    127 }
    128 
    129 // Test a 1 node graph with no inputs that is fully partitioned.
    130 // Input: node(0) -> tensor(1), nodes_to_partition=[node0]
    131 // Output: [kTfPartition, node(0) -> tensor(1)]
    132 TEST(PartitionTest, Nodes1PartitionNodes0Inputs0) {
    133   SimpleTestGraph graph;
    134   graph.AddTensors(1);
    135   graph.AddNode({}, {0});
    136   graph.SetInputsAndOutputs({}, {0});
    137   std::vector<Subgraph> generated_subgraphs;
    138   std::vector<int> nodes_to_partition = {0};
    139   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    140 
    141   Subgraph expected_subgraph;
    142   expected_subgraph.type = Subgraph::kTfPartition;
    143   expected_subgraph.nodes = {0};
    144   expected_subgraph.input_tensors = {};
    145   expected_subgraph.output_tensors = {0};
    146   CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
    147 }
    148 
    149 // Test a 1 node graph that is partitioned completely.
    150 // Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[node0]
    151 // Output: [kTfPartition, tensor(0) -> node(0) -> tensor(1)]
    152 TEST(PartitionTest, Nodes1PartitionNodes1) {
    153   SimpleTestGraph graph;
    154   graph.AddTensors(2);
    155   graph.AddNode({0}, {1});
    156   graph.SetInputsAndOutputs({0}, {1});
    157   std::vector<int> nodes_to_partition = {0};
    158   std::vector<Subgraph> generated_subgraphs;
    159   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    160 
    161   Subgraph expected_subgraph;
    162   expected_subgraph.type = Subgraph::kTfPartition;
    163   expected_subgraph.nodes = {0};
    164   expected_subgraph.input_tensors = {0};
    165   expected_subgraph.output_tensors = {1};
    166   CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph});
    167 }
    168 
    169 // Test a 2 node graph where 1 node is partitioned and the other is not.
    170 // Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2),
    171 //    nodes_to_partition = [1]
    172 // Output: [kTfNonPartition, tensor(0) -> node(0) -> tensor(1),
    173 //          kTfPartition, tensor(1) -> node(1), tensor(2)]
    174 TEST(PartitionTest, Nodes2PartitionNodes1) {
    175   SimpleTestGraph graph;
    176   graph.AddTensors(3);
    177   graph.AddNode({0}, {1});
    178   graph.AddNode({1}, {2});
    179   graph.SetInputsAndOutputs({0}, {2});
    180   std::vector<int> nodes_to_partition = {1};
    181   std::vector<Subgraph> generated_subgraphs;
    182   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    183 
    184   Subgraph expected_subgraph0;
    185   expected_subgraph0.type = Subgraph::kTfPartition;
    186   expected_subgraph0.nodes = {0};
    187   expected_subgraph0.input_tensors = {0};
    188   expected_subgraph0.output_tensors = {1};
    189   Subgraph expected_subgraph1;
    190   expected_subgraph1.type = Subgraph::kTfPartition;
    191   expected_subgraph1.nodes = {1};
    192   expected_subgraph1.input_tensors = {1};
    193   expected_subgraph1.output_tensors = {2};
    194   CheckPartitionSubgraphs(generated_subgraphs,
    195                           {expected_subgraph0, expected_subgraph1});
    196 }
    197 
    198 // Test a 2 node graph where both nodes are fully partitioned.
    199 // Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2),
    200 //    nodes_to_partition = [0, 1]
    201 // Output: [kTfPartition, tensor(0) -> node(0) -> node(1) -> tensor(1)]
    202 TEST(PartitionTest, Nodes2PartitionNodes2) {
    203   SimpleTestGraph graph;
    204   graph.AddTensors(3);
    205   graph.AddNode({0}, {1});
    206   graph.AddNode({1}, {2});
    207   graph.SetInputsAndOutputs({0}, {2});
    208   std::vector<int> nodes_to_partition = {0, 1};
    209   std::vector<Subgraph> generated_subgraphs;
    210   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    211 
    212   Subgraph expected_subgraph0;
    213   expected_subgraph0.type = Subgraph::kTfPartition;
    214   expected_subgraph0.nodes = {0, 1};
    215   expected_subgraph0.input_tensors = {0};
    216   expected_subgraph0.output_tensors = {2};
    217   CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph0});
    218 }
    219 
    220 // Test a three node model where we want to partition nodes 0 and nodes
    221 // 2, but nodes 0 and nodes 2 cannot be in the same subgraph since node 2
    222 // depends on node 1 which depends on node 0. Thus, we need to produce three
    223 // subgraphs.
    224 //
    225 // Input: tensor(0) -> node(0) -> tensor(1)
    226 //        tensor(1) -> node(1) -> tensor(2)
    227 //        [tensor(2), tensor(1)] -> node(2) -> tensor(3)
    228 //    nodes_to_partition = [0, 2]
    229 // Output: [[kTfPartition, tensor(0) -> node(0) -> tensor(1),
    230 //          [kTfNonPartition, tensor(1) -> node(1) -> tensor(2)],
    231 //          [kTfPartition, [tensor(2), tensor(1)] -> node(2) -> node(3)]
    232 TEST(PartitionTest, Nodes3PartitionNodes2) {
    233   SimpleTestGraph graph;
    234   graph.AddTensors(4);
    235   graph.AddNode({0}, {1});
    236   graph.AddNode({1}, {2});
    237   graph.AddNode({1, 2}, {3});
    238   graph.SetInputsAndOutputs({0}, {3});
    239   std::vector<int> nodes_to_partition = {0, 2};
    240   std::vector<Subgraph> generated_subgraphs;
    241   PartitionGraph(graph, nodes_to_partition, &generated_subgraphs);
    242 
    243   Subgraph expected_subgraph0;
    244   expected_subgraph0.type = Subgraph::kTfPartition;
    245   expected_subgraph0.nodes = {0};
    246   expected_subgraph0.input_tensors = {0};
    247   expected_subgraph0.output_tensors = {1};
    248   Subgraph expected_subgraph1;
    249   expected_subgraph1.type = Subgraph::kTfNonPartition;
    250   expected_subgraph1.nodes = {1};
    251   expected_subgraph1.input_tensors = {1};
    252   expected_subgraph1.output_tensors = {2};
    253   Subgraph expected_subgraph2;
    254   expected_subgraph2.type = Subgraph::kTfPartition;
    255   expected_subgraph2.nodes = {2};
    256   expected_subgraph2.input_tensors = {1, 2};
    257   expected_subgraph2.output_tensors = {3};
    258   CheckPartitionSubgraphs(
    259       generated_subgraphs,
    260       {expected_subgraph0, expected_subgraph1, expected_subgraph2});
    261 }
    262 
    263 }  // namespace
    264 }  // namespace tflite
    265 
    266 int main(int argc, char** argv) {
    267   ::tflite::LogToStderr();
    268   ::testing::InitGoogleTest(&argc, argv);
    269   return RUN_ALL_TESTS();
    270 }
    271