1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_GRAPPLER_UTILS_H_ 17 #define TENSORFLOW_GRAPPLER_UTILS_H_ 18 19 #include <functional> 20 #include <unordered_map> 21 #include <unordered_set> 22 #include <vector> 23 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/core/threadpool.h" 29 #include "tensorflow/core/lib/gtl/inlined_vector.h" 30 31 namespace tensorflow { 32 namespace grappler { 33 34 // A utility class to lookup a node and its outputs by node name. 35 class NodeMap { 36 public: 37 // Note: The NodeMap will store pointers to nodes in graph, which may become 38 // invalid if graph is changed. 39 explicit NodeMap(GraphDef* graph); 40 NodeDef* GetNode(const string& name) const; 41 bool NodeExists(const string& name) const; 42 const std::set<NodeDef*>& GetOutputs(const string& node_name) const; 43 // This method doesn't record the outputs of the added node; the outputs need 44 // to be explicitly added by the AddOutput method. 
45 void AddNode(const string& name, NodeDef* node); 46 void RemoveNode(const string& name); 47 void UpdateInput(const string& node_name, const string& old_input_name, 48 const string& new_input_name); 49 void AddOutput(const string& node_name, const string& output_name); 50 void RemoveInputs(const string& node_name); 51 void RemoveOutput(const string& node_name, const string& output_name); 52 void RemoveOutputs(const string& node_name); 53 void UpdateOutput(const string& node_name, const string& old_output_name, 54 const string& new_output_name); 55 56 private: 57 const std::set<NodeDef*> empty_set_; 58 std::unordered_map<string, NodeDef*> nodes_; 59 std::unordered_map<string, std::set<NodeDef*>> outputs_; 60 }; 61 62 // A vector with a set. The set stores the same elements as the vector, and 63 // quickly answers whether a value is in the vector. Duplicated elements are not 64 // allowed for now. 65 template <class T> 66 class SetVector { 67 public: 68 // Returns false if value already existed in the set, true otherwise. 69 bool PushBack(const T& value) { 70 if (!set_.insert(value).second) { 71 return false; 72 } 73 vector_.push_back(value); 74 return true; 75 } 76 77 T PopBack() { 78 T back = vector_.back(); 79 set_.erase(back); 80 vector_.pop_back(); 81 return back; 82 } 83 84 bool Exists(const T& value) const { return set_.find(value) != set_.end(); } 85 86 bool Empty() const { return vector_.empty(); } 87 88 void Reserve(int64 size) { vector_.reserve(size); } 89 90 private: 91 std::unordered_set<T> set_; 92 std::vector<T> vector_; 93 }; 94 95 // True iff 'name' refers to a control inputs, i.e. a node name prefixed with 96 // the ^ character. 97 bool IsControlInput(const string& name); 98 99 // True iff 'name1' and 'name2' refer to the same input. 100 bool IsSameInput(const string& name1, const string& name2); 101 102 // Return the node name corresponding to 'name' if name is valid, or the empty 103 // string otherwise. 
string NodeName(const string& name);

// Get the trailing position number ":{digits}" (if any) of a node name.
// NOTE(review): the value returned when `name` carries no position suffix is
// not visible from this header -- confirm against the implementation.
int NodePosition(const string& name);

// Returns the node name and position in a single call. The node name is the
// return value; the position is presumably written through `position`
// (confirm against the implementation).
string ParseNodeName(const string& name, int* position);

// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
                           const string& delimiter);

// Add a prefix to a node name.
// NOTE(review): the delimiter used by this overload is not visible from this
// header.
string AddPrefixToNodeName(const string& name, const string& prefix);

// Executes a 'fn' in the 'thread_pool'. The method waits for the configured
// timeout (in milliseconds) for 'fn' to complete, before returning false.
//
// If returning false, the 'fn' may still continue to execute in the
// thread-pool. It is the responsibility of the caller to reset the thread-pool
// as appropriate.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
                        thread::ThreadPool* thread_pool);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a NodeDef.
string AsControlDependency(const NodeDef& node);

// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a node name.
string AsControlDependency(const string& node);

// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);

// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);

// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);

// Returns the data type in attribute `attr_name` of `node`.
If that attribute 150 // doesn't exist, returns DT_INVALID. 151 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name); 152 153 // Returns the last node in the simple chain starting at source and traversing 154 // through the input(0) edge from each node as long as the next node satisfies 155 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source 156 // will be returned. Example: For the chain 157 // source <- a <- b <- ... <- y <- z 158 // where 159 // pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true, 160 // pred_fn(z) = false, 161 // the return value will be a pointer to y. 162 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, 163 bool follow_control_input, 164 const std::function<bool(const NodeDef&)>& pred_fn); 165 166 // Permute the nodes of graph in place according to the permutation. 167 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, 168 bool invert_permutation); 169 170 class SimpleGraphView { 171 public: 172 Status Initialize(const GraphDef& graph) { 173 return Initialize(graph, true, true); 174 } 175 Status Initialize(const GraphDef& graph, bool dedup_inputs, 176 bool dedup_outputs); 177 178 inline int num_nodes() const { return index_to_name_.size(); } 179 inline const int index(const string& node_name) const { 180 const auto& it = name_to_index_.find(node_name); 181 DCHECK(it != name_to_index_.end()); 182 return it == name_to_index_.end() ? 
-1 : it->second; 183 } 184 inline const string& node_name(int node_idx) const { 185 return index_to_name_[node_idx]; 186 } 187 inline const gtl::InlinedVector<int, 4>& inputs(int node_idx) const { 188 return inputs_[node_idx]; 189 } 190 inline const gtl::InlinedVector<int, 2>& outputs(int node_idx) const { 191 return outputs_[node_idx]; 192 } 193 194 string PrintToString() const; 195 196 private: 197 std::vector<string> index_to_name_; 198 std::unordered_map<string, int> name_to_index_; 199 std::vector<gtl::InlinedVector<int, 4>> inputs_; 200 std::vector<gtl::InlinedVector<int, 2>> outputs_; 201 }; 202 203 } // end namespace grappler 204 } // end namespace tensorflow 205 206 #endif // TENSORFLOW_GRAPPLER_UTILS_H_ 207