1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_H_ 18 19 #include <functional> 20 #include <iterator> 21 #include <set> 22 #include <unordered_set> 23 #include <utility> 24 #include <vector> 25 #include "absl/types/span.h" 26 #include "tensorflow/core/framework/graph.pb.h" 27 #include "tensorflow/core/framework/node_def.pb.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/framework/types.h" 30 #include "tensorflow/core/graph/tensor_id.h" 31 #include "tensorflow/core/lib/core/status.h" 32 #include "tensorflow/core/lib/core/stringpiece.h" 33 #include "tensorflow/core/lib/core/threadpool.h" 34 #include "tensorflow/core/lib/gtl/flatmap.h" 35 #include "tensorflow/core/lib/gtl/flatset.h" 36 #include "tensorflow/core/lib/gtl/inlined_vector.h" 37 #include "tensorflow/core/platform/types.h" 38 39 namespace tensorflow { 40 namespace grappler { 41 42 // A utility class to lookup a node and its outputs by node name. 43 class NodeMap { 44 public: 45 // Note: The NodeMap will store pointers to nodes in graph, which may become 46 // invalid if graph is changed. 47 explicit NodeMap(GraphDef* graph); 48 NodeDef* GetNode(const string& name) const; 49 bool NodeExists(const string& name) const; 50 const std::set<NodeDef*>& GetOutputs(const string& node_name) const; 51 // This method doesn't record the outputs of the added node; the outputs need 52 // to be explicitly added by the AddOutput method. 53 void AddNode(const string& name, NodeDef* node); 54 void RemoveNode(const string& name); 55 void UpdateInput(const string& node_name, const string& old_input_name, 56 const string& new_input_name); 57 void AddOutput(const string& node_name, const string& output_name); 58 void RemoveInputs(const string& node_name); 59 void RemoveOutput(const string& node_name, const string& output_name); 60 void RemoveOutputs(const string& node_name); 61 void UpdateOutput(const string& node_name, const string& old_output_name, 62 const string& new_output_name); 63 64 private: 65 const std::set<NodeDef*> empty_set_; 66 gtl::FlatMap<string, NodeDef*> nodes_; 67 gtl::FlatMap<string, std::set<NodeDef*>> outputs_; 68 }; 69 70 // A vector with a set. The set stores the same elements as the vector, and 71 // quickly answers whether a value is in the vector. Duplicated elements are not 72 // allowed for now. 73 template <class T, class Hash = std::hash<T>> 74 class SetVector { 75 public: 76 // Returns false if value already existed in the set, true otherwise. 77 bool PushBack(const T& value) { 78 if (!set_.insert(value).second) { 79 return false; 80 } 81 vector_.push_back(value); 82 return true; 83 } 84 85 T PopBack() { 86 T back = vector_.back(); 87 set_.erase(back); 88 vector_.pop_back(); 89 return back; 90 } 91 92 bool Exists(const T& value) const { return set_.find(value) != set_.end(); } 93 94 bool Empty() const { return vector_.empty(); } 95 96 void Reserve(int64 size) { vector_.reserve(size); } 97 98 private: 99 gtl::FlatSet<T, Hash> set_; 100 std::vector<T> vector_; 101 }; 102 103 // Returns formatted string from TensorId specific to grappler. Specifically, 104 // for the 0 port (first output), only the node name is returned. 105 string TensorIdToString(const TensorId& tensor_id); 106 107 // True iff 'name' refers to a control inputs, i.e. a node name prefixed with 108 // the ^ character. 109 bool IsControlInput(const string& name); 110 111 // True iff tensor index refers to a control input. 112 bool IsControlInput(const TensorId& tensor_id); 113 114 // True iff 'name1' and 'name2' refer to the same input. 115 bool IsSameInput(const string& name1, const string& name2); 116 117 // Returns the trailing position number (or zero if no number is present) if 118 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs. 119 // Returns -2 if NodeName(input_name) is not equal to node_name. 120 // Note: This function is used very heavily, and this hand-optimized 121 // version is 3-4x faster than the version using Scanner, which it replaced. 122 // This is worth the reduction in readability. 123 inline int NodePositionIfSameNode(const string& input_name, 124 const string& node_name) { 125 if (input_name.empty()) return -2; 126 const bool is_ctrl = input_name[0] == '^'; 127 auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); 128 auto node_it = node_name.begin(); 129 if (node_name.empty() || 130 std::distance(input_it, input_name.end()) < node_name.size()) { 131 return -2; 132 } 133 while (node_it != node_name.end()) { 134 if (*input_it++ != *node_it++) { 135 return -2; 136 } 137 } 138 if (input_it == input_name.end()) { 139 return is_ctrl ? -1 : 0; 140 } else if (*input_it++ == ':') { 141 StringPiece remaining(&(*input_it), 142 std::distance(input_it, input_name.end())); 143 int position; 144 if (!strings::safe_strto32(remaining, &position)) { 145 return -2; 146 } 147 return is_ctrl ? -1 : position; 148 } else { 149 return -2; 150 } 151 } 152 153 // Return the node name corresponding to 'name' if name is valid, or the empty 154 // string otherwise. 155 inline StringPiece NodeNameAsStringPiece(const string& name) { 156 static const string empty; 157 if (name.empty()) return StringPiece(empty); 158 const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin(); 159 auto end_it = begin_it; 160 while (end_it != name.end() && *end_it != ':') { 161 ++end_it; 162 } 163 if (end_it != name.end() && *end_it != ':') { 164 return StringPiece(empty); 165 } 166 return StringPiece(&(*begin_it), std::distance(begin_it, end_it)); 167 } 168 169 // Return the node name corresponding to 'name' if name is valid, or the empty 170 // string otherwise. 171 inline string NodeName(const string& name) { 172 return string(NodeNameAsStringPiece(name)); 173 } 174 175 // Returns the node name and position in a single call. 176 // DEPRECATED(ezhulenev): Use TensorId and ParseTensorName. 177 inline StringPiece ParseNodeNameAsStringPiece(const string& name, 178 int* position) { 179 static const string empty; 180 if (name.empty()) { 181 *position = 0; 182 return StringPiece(empty); 183 } 184 const bool is_ctrl = name[0] == '^'; 185 const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin(); 186 *position = is_ctrl ? -1 : 0; 187 auto end_it = begin_it; 188 while (end_it != name.end() && *end_it != ':') { 189 ++end_it; 190 } 191 const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it)); 192 if (end_it != name.end()) { 193 if (*end_it != ':') { 194 return StringPiece(empty); 195 } else if (!is_ctrl) { 196 ++end_it; 197 StringPiece remaining(&(*end_it), std::distance(end_it, name.end())); 198 if (!strings::safe_strto32(remaining, position)) { 199 return StringPiece(empty); 200 } 201 } 202 } 203 return node_name; 204 } 205 206 // Returns the node name and position in a single call. 207 // DEPRECATED(ezhulenev): Use SafeTensorId and ParseTensorName. 208 inline string ParseNodeName(const string& name, int* position) { 209 return string(ParseNodeNameAsStringPiece(name, position)); 210 } 211 212 inline int NodePosition(const string& name) { 213 int position; 214 ParseNodeNameAsStringPiece(name, &position); 215 return position; 216 } 217 218 // Add a prefix to a node name with a custom delimiter. 219 string AddPrefixToNodeName(const string& name, const string& prefix, 220 const string& delimiter); 221 222 // Add a prefix to a node name. 223 string AddPrefixToNodeName(const string& name, const string& prefix); 224 225 // Executes a 'fn' in the 'thread_pool'. The method waits for the configured 226 // timeout (in milliseconds) for 'fn' to complete, before returning false. 227 // 228 // If returning false, the 'fn' may still continue to execute in the 229 // thread-pool. It is the responsibility of the caller to reset the thread-pool 230 // as appropriate. 231 bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms, 232 thread::ThreadPool* thread_pool); 233 234 // Returns the node name prefixed with conventional symbol '^' 235 // for control dependency, given a NodeDef. 236 string AsControlDependency(const NodeDef& node); 237 238 // Returns the node name prefixed with conventional symbol '^' 239 // for control dependency, given a node name 240 string AsControlDependency(const string& node); 241 242 // Returns true if the node is assigned to run on CPU device. 243 bool NodeIsOnCpu(const NodeDef* node); 244 245 // Returns true if the node is assigned to run on GPU device. 246 bool NodeIsOnGpu(const NodeDef* node); 247 248 // Returns the number of outputs of a node according to its OpDef. Note that 249 // some of the outputs may be unconnected. 250 int NumOutputs(const NodeDef& node, GraphDef* graph); 251 252 // Returns true iff the node has at least one control input. 253 bool HasControlInputs(const NodeDef& node); 254 255 // Number of connected non-control inputs. 256 int NumNonControlInputs(const NodeDef& node); 257 258 // Number of connected non-control outputs. 259 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); 260 261 // Number of connected non-control data outputs (Ops that consume output tensor 262 // data, not just it's shape). 263 int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map); 264 265 // Removes redundant control inputs from node. 266 void DedupControlInputs(NodeDef* node); 267 268 // Returns an error if an attribute with the given key does not exist in node. 269 Status CheckAttrExists(const NodeDef& node, const string& key); 270 271 // Returns an error if attributes with the given keys do not exist in node. 272 Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys); 273 274 // Returns the data type in attribute `attr_name` of `node`. If that attribute 275 // doesn't exist, returns DT_INVALID. 276 DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr); 277 278 // Returns the last node in the simple chain starting at source and traversing 279 // through the input(0) edge from each node as long as the next node satisfies 280 // the predicate given in pred_fn. If no nodes satisfy the predicate, &source 281 // will be returned. Example: For the chain 282 // source <- a <- b <- ... <- y <- z 283 // where 284 // pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true, 285 // pred_fn(z) = false, 286 // the return value will be a pointer to y. 287 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, 288 bool follow_control_input, 289 const std::function<bool(const NodeDef&)>& pred_fn); 290 291 // Permute the nodes of graph in place according to the permutation. 292 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, 293 bool invert_permutation); 294 295 // Returns Status::OK() if a kernel is registered for node.op() on the device 296 // type corresponding to node.device(). 297 Status IsKernelRegisteredForNode(const NodeDef& node); 298 299 Status SetTensorValue(DataType dtype, int value, Tensor* tensor); 300 301 void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph); 302 303 void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph); 304 305 void EraseNodesFromGraph(const std::set<string>& nodes_to_delete, 306 GraphDef* graph); 307 308 } // end namespace grappler 309 } // end namespace tensorflow 310 311 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_ 312