Home | History | Annotate | Download | only in lite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
     16 #define TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/contrib/lite/context.h"
     21 
     22 namespace tflite {
     23 
     24 // Basic information about an inference graph, where execution nodes
     25 // are connected via tensors.
     26 class GraphInfo {
     27  public:
     28   virtual ~GraphInfo() {}
     29 
     30   // Total number of tensors in the graph.
     31   virtual size_t num_tensors() const = 0;
     32 
     33   // Returns a tensor given its index which is expected to be between 0 and
     34   // num_tensors().
     35   virtual TfLiteTensor* tensor(size_t index) = 0;
     36 
     37   // Total number of nodes in the graph.
     38   virtual size_t num_nodes() const = 0;
     39 
     40   // Returns a node given its index which is expected to be between 0 and
     41   // num_nodes().
     42   virtual const TfLiteNode& node(size_t index) const = 0;
     43 
     44   // Returns the indices of the input tensors.
     45   virtual const std::vector<int>& inputs() const = 0;
     46 
     47   // Returns the indices of the output tensors.
     48   virtual const std::vector<int>& outputs() const = 0;
     49 };
     50 
     51 // Represents a subgraph of a TensorFlow Lite graph.
     52 struct Subgraph {
     53   enum Type {
     54     kTfUnexplored = 0,  // temporarily used during creation
     55     kTfPartition,
     56     kTfNonPartition
     57   };
     58   Type type = kTfUnexplored;
     59   // Nodes within the subgraph
     60   std::vector<int> nodes;
     61   // Tensors that stride output from another subgraph that this depends on,
     62   // or global inputs to the TensorFlow Lite full graph.
     63   std::vector<int> input_tensors;
     64   // Outputs that are consumed by other subgraphs or are global output tensors.
     65   // All output tensors of the nodes in the subgraph that do not appear in this
     66   // list are intermediate results that can be potentially elided.
     67   std::vector<int> output_tensors;
     68 };
     69 
     70 // Partitions a list of node indices `nodes_to_partition` into subgraphs.
     71 // Each subgraph is in dependency order (i.e. all members of the subgraph).
     72 // `subgraphs` is assumed to be empty.
     73 TfLiteStatus PartitionGraphIntoIndependentSubgraphs(
     74     const GraphInfo* info, const TfLiteIntArray* nodes_to_partition,
     75     std::vector<Subgraph>* subgraphs);
     76 
     77 }  // namespace tflite
     78 
     79 #endif  // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_
     80