Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPPLER_UTILS_H_
     17 #define TENSORFLOW_GRAPPLER_UTILS_H_
     18 
#include <functional>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
     30 
     31 namespace tensorflow {
     32 namespace grappler {
     33 
     34 // A utility class to lookup a node and its outputs by node name.
     35 class NodeMap {
     36  public:
     37   // Note: The NodeMap will store pointers to nodes in graph, which may become
     38   // invalid if graph is changed.
     39   explicit NodeMap(GraphDef* graph);
     40   NodeDef* GetNode(const string& name) const;
     41   bool NodeExists(const string& name) const;
     42   const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
     43   // This method doesn't record the outputs of the added node; the outputs need
     44   // to be explicitly added by the AddOutput method.
     45   void AddNode(const string& name, NodeDef* node);
     46   void RemoveNode(const string& name);
     47   void UpdateInput(const string& node_name, const string& old_input_name,
     48                    const string& new_input_name);
     49   void AddOutput(const string& node_name, const string& output_name);
     50   void RemoveInputs(const string& node_name);
     51   void RemoveOutput(const string& node_name, const string& output_name);
     52   void RemoveOutputs(const string& node_name);
     53   void UpdateOutput(const string& node_name, const string& old_output_name,
     54                     const string& new_output_name);
     55 
     56  private:
     57   const std::set<NodeDef*> empty_set_;
     58   std::unordered_map<string, NodeDef*> nodes_;
     59   std::unordered_map<string, std::set<NodeDef*>> outputs_;
     60 };
     61 
     62 // A vector with a set. The set stores the same elements as the vector, and
     63 // quickly answers whether a value is in the vector. Duplicated elements are not
     64 // allowed for now.
     65 template <class T>
     66 class SetVector {
     67  public:
     68   // Returns false if value already existed in the set, true otherwise.
     69   bool PushBack(const T& value) {
     70     if (!set_.insert(value).second) {
     71       return false;
     72     }
     73     vector_.push_back(value);
     74     return true;
     75   }
     76 
     77   T PopBack() {
     78     T back = vector_.back();
     79     set_.erase(back);
     80     vector_.pop_back();
     81     return back;
     82   }
     83 
     84   bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
     85 
     86   bool Empty() const { return vector_.empty(); }
     87 
     88   void Reserve(int64 size) { vector_.reserve(size); }
     89 
     90  private:
     91   std::unordered_set<T> set_;
     92   std::vector<T> vector_;
     93 };
     94 
     95 // True iff 'name' refers to a control inputs, i.e. a node name prefixed with
     96 // the ^ character.
     97 bool IsControlInput(const string& name);
     98 
     99 // True iff 'name1' and 'name2' refer to the same input.
    100 bool IsSameInput(const string& name1, const string& name2);
    101 
    102 // Return the node name corresponding to 'name' if name is valid, or the empty
    103 // string otherwise.
    104 string NodeName(const string& name);
    105 
    106 // Get the trailing position number ":{digits}" (if any) of a node name.
    107 int NodePosition(const string& name);
    108 
    109 // Returns the node name and position in a single call.
    110 string ParseNodeName(const string& name, int* position);
    111 
    112 // Add a prefix to a node name with a custom delimiter.
    113 string AddPrefixToNodeName(const string& name, const string& prefix,
    114                            const string& delimiter);
    115 
    116 // Add a prefix to a node name.
    117 string AddPrefixToNodeName(const string& name, const string& prefix);
    118 
    119 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured
    120 // timeout (in milliseconds) for 'fn' to complete, before returning false.
    121 //
    122 // If returning false, the 'fn' may still continue to execute in the
    123 // thread-pool. It is the responsibility of the caller to reset the thread-pool
    124 // as appropriate.
    125 bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
    126                         thread::ThreadPool* thread_pool);
    127 
    128 // Returns the node name prefixed with conventional symbol '^'
    129 // for control dependency, given a NodeDef.
    130 string AsControlDependency(const NodeDef& node);
    131 
    132 // Returns the node name prefixed with conventional symbol '^'
    133 // for control dependency, given a node name
    134 string AsControlDependency(const string& node);
    135 
    136 // Returns the number of outputs of a node according to its OpDef. Note that
    137 // some of the outputs may be unconnected.
    138 int NumOutputs(const NodeDef& node, GraphDef* graph);
    139 
    140 // Number of connected non-control inputs.
    141 int NumNonControlInputs(const NodeDef& node);
    142 
    143 // Number of connected non-control outputs.
    144 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
    145 
    146 // Removes redundant control inputs from node.
    147 void DedupControlInputs(NodeDef* node);
    148 
    149 // Returns the data type in attribute `attr_name` of `node`. If that attribute
    150 // doesn't exist, returns DT_INVALID.
    151 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name);
    152 
    153 // Returns the last node in the simple chain starting at source and traversing
    154 // through the input(0) edge from each node as long as the next node satisfies
    155 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source
    156 // will be returned. Example: For the chain
    157 //    source <- a <- b <- ... <- y <- z
    158 // where
    159 //    pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
    160 //    pred_fn(z) = false,
    161 // the return value will be a pointer to y.
    162 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
    163                         bool follow_control_input,
    164                         const std::function<bool(const NodeDef&)>& pred_fn);
    165 
    166 // Permute the nodes of graph in place according to the permutation.
    167 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
    168                          bool invert_permutation);
    169 
    170 class SimpleGraphView {
    171  public:
    172   Status Initialize(const GraphDef& graph) {
    173     return Initialize(graph, true, true);
    174   }
    175   Status Initialize(const GraphDef& graph, bool dedup_inputs,
    176                     bool dedup_outputs);
    177 
    178   inline int num_nodes() const { return index_to_name_.size(); }
    179   inline const int index(const string& node_name) const {
    180     const auto& it = name_to_index_.find(node_name);
    181     DCHECK(it != name_to_index_.end());
    182     return it == name_to_index_.end() ? -1 : it->second;
    183   }
    184   inline const string& node_name(int node_idx) const {
    185     return index_to_name_[node_idx];
    186   }
    187   inline const gtl::InlinedVector<int, 4>& inputs(int node_idx) const {
    188     return inputs_[node_idx];
    189   }
    190   inline const gtl::InlinedVector<int, 2>& outputs(int node_idx) const {
    191     return outputs_[node_idx];
    192   }
    193 
    194   string PrintToString() const;
    195 
    196  private:
    197   std::vector<string> index_to_name_;
    198   std::unordered_map<string, int> name_to_index_;
    199   std::vector<gtl::InlinedVector<int, 4>> inputs_;
    200   std::vector<gtl::InlinedVector<int, 2>> outputs_;
    201 };
    202 
    203 }  // end namespace grappler
    204 }  // end namespace tensorflow
    205 
    206 #endif  // TENSORFLOW_GRAPPLER_UTILS_H_
    207