1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ 17 #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ 18 19 #include <set> 20 #include <unordered_set> 21 #include <vector> 22 23 #include "tensorflow/core/framework/attr_value.pb.h" 24 #include "tensorflow/core/framework/attr_value_util.h" 25 #include "tensorflow/core/framework/graph.pb.h" 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor.pb.h" 29 #include "tensorflow/core/lib/core/status.h" 30 31 namespace tensorflow { 32 namespace graph_transforms { 33 34 // Used to quickly look up nodes in the graph def from a name. 35 void MapNamesToNodes(const GraphDef& graph_def, 36 std::map<string, const NodeDef*>* result); 37 38 // For every node in the graph create a list of the nodes that use it as an 39 // input. 40 void MapNodesToOutputs(const GraphDef& graph_def, 41 std::map<string, std::vector<const NodeDef*>>* result); 42 43 // NodeDef input strings can contain other information besides the name of an 44 // input node. These include: 45 // - Optional '^' prefix, indicating this is a control edge. 46 // - The required name of the input node. 47 // - Optional ':<number>' suffix, showing which output of the node to use. 48 // This function takes a raw string, and breaks it into those component parts. 49 // The rules for inputs in function libraries are a bit more complex, and 50 // aren't handled by this routine. 51 void NodeNamePartsFromInput(const string& input_name, string* prefix, 52 string* node_name, string* suffix); 53 54 // Adds a ':0' port to any inputs with no suffix, to make comparisons easier. 55 string CanonicalInputName(const string& input_name); 56 57 // Convenience function to strip the optional prefix and suffix components from 58 // a string pulled from a NodeDef input, and return the plain node name. 59 string NodeNameFromInput(const string& input_name); 60 61 // Returns a stable hash for the contents of the NodeDef, so that equivalent 62 // nodes should have equal hashes. 63 uint64 HashNodeDef(const NodeDef& node); 64 65 // Adds the given node name to the end of the node's inputs. 66 void AddNodeInput(const string& input_name, NodeDef* node); 67 68 // Copies an attribute from one NodeDef to another. 69 void CopyNodeAttr(const NodeDef& source, const string& source_key, 70 const string& dest_key, NodeDef* dest); 71 72 // Inserts a value into a NodeDef's map of attributes. 73 // This is a bit different than AddNodeAttr in node_def_util.h because it 74 // overwrites any existing attributes with the same key. 75 template <class T> 76 inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) { 77 AttrValue attr_value; 78 SetAttrValue(value, &attr_value); 79 auto* attr_map = node->mutable_attr(); 80 (*attr_map)[key] = attr_value; 81 } 82 83 template <class T> 84 inline void SetNodeTensorAttr(const string& key, const Tensor& tensor, 85 NodeDef* node) { 86 TensorProto tensor_proto; 87 tensor.AsProtoTensorContent(&tensor_proto); 88 SetNodeAttr(key, tensor_proto, node); 89 } 90 91 // Inserts a Tensor into the specified attribute of a NodeDef. 92 template <class T> 93 inline void SetNodeTensorAttr(const string& key, const TensorShape& shape, 94 const std::vector<T>& values, NodeDef* node) { 95 const DataType dtype = DataTypeToEnum<T>::v(); 96 CHECK_EQ(shape.num_elements(), values.size()); 97 Tensor tensor(dtype, shape); 98 T* dest_data = tensor.flat<T>().data(); 99 std::copy_n(values.data(), values.size(), dest_data); 100 SetNodeTensorAttr<T>(key, tensor, node); 101 } 102 103 // Retrieves a tensor value from a NodeDef attribute. 104 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key); 105 106 // Creates a copy of the input GraphDef, but only containing the nodes where the 107 // supplied selector function returned true. 108 void FilterGraphDef(const GraphDef& input_graph_def, 109 std::function<bool(const NodeDef&)> selector, 110 GraphDef* output_graph_def); 111 112 // Creates a copy of the input graph, with all occurrences of the attributes 113 // with the names in the argument removed from the node defs. 114 void RemoveAttributes(const GraphDef& input_graph_def, 115 const std::vector<string>& attributes, 116 GraphDef* output_graph_def); 117 118 // For a lot of replacement and matching operations it's useful to have the 119 // nodes processed in a controlled order, so this does a topological sort to 120 // ensure that nodes always appear in the GraphDef.node list after their inputs. 121 Status SortByExecutionOrder(const GraphDef& input_graph_def, 122 GraphDef* output_graph_def); 123 124 // Finds inputs that refer to nodes that are not in the graph. 125 void FindInvalidInputs(const GraphDef& graph_def, 126 std::vector<std::pair<string, string>>* invalid_inputs); 127 128 // Returns a descriptive error status if there are problems spotted with the 129 // graph. 130 Status IsGraphValid(const GraphDef& graph_def); 131 132 // Returns input and output types for a particular NodeDef. 133 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, 134 DataTypeVector* outputs); 135 136 // Takes a comma-separated string of numbers and parses them into a shape. 137 Status TensorShapeFromString(const string& shape_string, TensorShape* result); 138 139 // This is used to spot particular subgraphs in a larger model. To use it, 140 // create a pattern like: 141 // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); 142 // This defines a subgraph where a Conv2D has a ResizeBilinear input, which 143 // pulls from a MirrorPad op. 144 // Regular expressions aren't supported for the op names, but you can use "*" to 145 // match any op. You can also use | as a separator to match multiple op names, 146 // like "Reshape|Concat|Conv2D". 147 struct OpTypePattern { 148 string op; 149 std::vector<OpTypePattern> inputs; 150 string DebugString() const; 151 }; 152 153 // Returns a sub-graph of nodes that match a pattern. 154 struct NodeMatch { 155 NodeMatch() : node() {} 156 NodeDef node; 157 std::vector<NodeMatch> inputs; 158 string DebugString() const; 159 }; 160 161 // Utility class to spot subgraphs matching particular patterns. 162 class GraphMatcher { 163 public: 164 GraphMatcher(const GraphDef& graph_def); 165 166 // Sorts the input nodes into execution order, and then skips any previously 167 // matches so that no node appears in more than one match. The NodeDef 168 // pointers contained in the results are owned by the GraphMatcher object, and 169 // so will be invalid after its lifetime. 170 Status GetOpTypeMatches(const OpTypePattern& pattern, 171 std::vector<NodeMatch>* matches); 172 173 private: 174 bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern, 175 const std::set<string>& previously_matched_nodes, 176 NodeMatch* match); 177 178 GraphDef graph_def_; 179 std::map<string, const NodeDef*> node_map_; 180 }; 181 182 struct ReplaceMatchingOpTypesOptions { 183 // Whether to raise an error if the graph is left with dangling inputs. If you 184 // enable this option, you must fix inconsistencies in a later pass. 185 bool allow_inconsistencies; 186 }; 187 188 // Replaces all of the matching sub-graphs with new ops. This calls into the 189 // given function, and expects to receive a set of new nodes to replace each 190 // matched sub-graph. It has some logic to protect the integrity of the 191 // resulting graph, for example making sure that nodes needed by other nodes 192 // outside the sub-graph aren't removed. These are passed in as the set of 193 // outputs, and nodes with the same names must be added to the new nodes 194 // produced by the replacement function. Many of these checks can be disabled 195 // by setting allow_inconsistencies to true in the options, but then it's the 196 // caller's responsibility to patch up any problems before passing on the graph 197 // to others. There's more comprehensive usage documentation in the README. 198 Status ReplaceMatchingOpTypes( 199 const GraphDef& input_graph_def, const OpTypePattern& pattern, 200 const std::function<Status(const NodeMatch&, const std::set<string>&, 201 const std::set<string>&, std::vector<NodeDef>*)>& 202 node_generator, 203 const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def); 204 205 // Returns a list of the unique nodes found in this match. 206 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result); 207 208 // Changes all input references to a particular node name. Any nodes with names 209 // listed in nodes_to_ignore will not have their inputs rewritten. 210 Status RenameNodeInputs(const GraphDef& input_graph_def, 211 const std::map<string, string>& inputs_to_rename, 212 const std::unordered_set<string>& nodes_to_ignore, 213 GraphDef* output_graph_def); 214 215 // Utility function that copies all the nodes found in a match into the 216 // new_nodes list. This is useful in replacement functions when you decide to 217 // leave the original matched subgraph untouched and make no changes. 218 void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes); 219 220 // Holds information that's needed for transform functions. 221 typedef std::map<string, std::vector<string>> TransformFuncParameters; 222 struct TransformFuncContext { 223 std::vector<string> input_names; 224 std::vector<string> output_names; 225 TransformFuncParameters params; 226 227 // Returns how many occurrences of the given parameter are present. 228 int CountParameters(const string& name) const; 229 230 // Gets a single instance of a parameter, using a default if it's not present. 231 Status GetOneStringParameter(const string& name, const string& default_value, 232 string* result) const; 233 234 // Gets a single occurrence of a parameter as a 32-bit integer, falling back 235 // to a default if it isn't present and returning an error if it isn't 236 // convertible to a number. 237 Status GetOneInt32Parameter(const string& name, int32 default_value, 238 int32* result) const; 239 240 // Gets a single occurrence of a parameter as a 64-bit integer, falling back 241 // to a default if it isn't present and returning an error if it isn't 242 // convertible to a number. 243 Status GetOneInt64Parameter(const string& name, int64 default_value, 244 int64* result) const; 245 246 // Gets a single occurrence of a parameter as a floating point number, falling 247 // back to a default if it isn't present and returning an error if it isn't 248 // convertible to a number. 249 Status GetOneFloatParameter(const string& name, float default_value, 250 float* result) const; 251 252 // Gets a single occurrence of a parameter as a boolean, falling back to a 253 // default if it isn't present and returning an error if it's not one of 254 // "true", "1", "false", or "0". 255 Status GetOneBoolParameter(const string& name, bool default_value, 256 bool* result) const; 257 }; 258 259 // This is the function API for all graph transformations, taking an input 260 // GraphDef and other arguments, and returning a transformed GraphDef. 261 typedef std::function<Status(const GraphDef&, 262 const TransformFuncContext& context, GraphDef*)> 263 TransformFunc; 264 265 // To add a new graph transform function, call the macro: 266 // REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants); 267 // Under the hood this adds the function to the list of known transforms, so you 268 // just need to link in the .cc file with your registration call to have access 269 // to it through the command line tool. 270 // The rest of the machinery below is to enable that automagical registration. 271 typedef std::map<string, TransformFunc> TransformRegistry; 272 TransformRegistry* GetTransformRegistry(); 273 class TransformRegistrar { 274 public: 275 TransformRegistrar(const string& name, TransformFunc transform_func) { 276 TransformRegistry* transform_registry = GetTransformRegistry(); 277 (*transform_registry)[name] = transform_func; 278 } 279 }; 280 #define REGISTER_GRAPH_TRANSFORM(name, func) \ 281 REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func) 282 #define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \ 283 REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) 284 #define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func) \ 285 static tensorflow::graph_transforms::TransformRegistrar \ 286 registrar__body__##ctr##__object(name, func); 287 288 } // namespace graph_transforms 289 } // namespace tensorflow 290 291 #endif // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_ 292