Home | History | Annotate | Download | only in graph_transforms
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
     17 #define TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
     18 
#include <algorithm>
#include <functional>
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 
     31 namespace tensorflow {
     32 namespace graph_transforms {
     33 
     34 // Used to quickly look up nodes in the graph def from a name.
     35 void MapNamesToNodes(const GraphDef& graph_def,
     36                      std::map<string, const NodeDef*>* result);
     37 
     38 // For every node in the graph create a list of the nodes that use it as an
     39 // input.
     40 void MapNodesToOutputs(const GraphDef& graph_def,
     41                        std::map<string, std::vector<const NodeDef*>>* result);
     42 
     43 // NodeDef input strings can contain other information besides the name of an
     44 // input node. These include:
     45 //  - Optional '^' prefix, indicating this is a control edge.
     46 //  - The required name of the input node.
     47 //  - Optional ':<number>' suffix, showing which output of the node to use.
     48 // This function takes a raw string, and breaks it into those component parts.
     49 // The rules for inputs in function libraries are a bit more complex, and
     50 // aren't handled by this routine.
     51 void NodeNamePartsFromInput(const string& input_name, string* prefix,
     52                             string* node_name, string* suffix);
     53 
     54 // Adds a ':0' port to any inputs with no suffix, to make comparisons easier.
     55 string CanonicalInputName(const string& input_name);
     56 
     57 // Convenience function to strip the optional prefix and suffix components from
     58 // a string pulled from a NodeDef input, and return the plain node name.
     59 string NodeNameFromInput(const string& input_name);
     60 
     61 // Returns a stable hash for the contents of the NodeDef, so that equivalent
     62 // nodes should have equal hashes.
     63 uint64 HashNodeDef(const NodeDef& node);
     64 
     65 // Adds the given node name to the end of the node's inputs.
     66 void AddNodeInput(const string& input_name, NodeDef* node);
     67 
     68 // Copies an attribute from one NodeDef to another.
     69 void CopyNodeAttr(const NodeDef& source, const string& source_key,
     70                   const string& dest_key, NodeDef* dest);
     71 
     72 // Inserts a value into a NodeDef's map of attributes.
     73 // This is a bit different than AddNodeAttr in node_def_util.h because it
     74 // overwrites any existing attributes with the same key.
     75 template <class T>
     76 inline void SetNodeAttr(const string& key, const T& value, NodeDef* node) {
     77   AttrValue attr_value;
     78   SetAttrValue(value, &attr_value);
     79   auto* attr_map = node->mutable_attr();
     80   (*attr_map)[key] = attr_value;
     81 }
     82 
     83 template <class T>
     84 inline void SetNodeTensorAttr(const string& key, const Tensor& tensor,
     85                               NodeDef* node) {
     86   TensorProto tensor_proto;
     87   tensor.AsProtoTensorContent(&tensor_proto);
     88   SetNodeAttr(key, tensor_proto, node);
     89 }
     90 
     91 // Inserts a Tensor into the specified attribute of a NodeDef.
     92 template <class T>
     93 inline void SetNodeTensorAttr(const string& key, const TensorShape& shape,
     94                               const std::vector<T>& values, NodeDef* node) {
     95   const DataType dtype = DataTypeToEnum<T>::v();
     96   CHECK_EQ(shape.num_elements(), values.size());
     97   Tensor tensor(dtype, shape);
     98   T* dest_data = tensor.flat<T>().data();
     99   std::copy_n(values.data(), values.size(), dest_data);
    100   SetNodeTensorAttr<T>(key, tensor, node);
    101 }
    102 
    103 // Retrieves a tensor value from a NodeDef attribute.
    104 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key);
    105 
    106 // Creates a copy of the input GraphDef, but only containing the nodes where the
    107 // supplied selector function returned true.
    108 void FilterGraphDef(const GraphDef& input_graph_def,
    109                     std::function<bool(const NodeDef&)> selector,
    110                     GraphDef* output_graph_def);
    111 
    112 // Creates a copy of the input graph, with all occurrences of the attributes
    113 // with the names in the argument removed from the node defs.
    114 void RemoveAttributes(const GraphDef& input_graph_def,
    115                       const std::vector<string>& attributes,
    116                       GraphDef* output_graph_def);
    117 
    118 // For a lot of replacement and matching operations it's useful to have the
    119 // nodes processed in a controlled order, so this does a topological sort to
    120 // ensure that nodes always appear in the GraphDef.node list after their inputs.
    121 Status SortByExecutionOrder(const GraphDef& input_graph_def,
    122                             GraphDef* output_graph_def);
    123 
    124 // Finds inputs that refer to nodes that are not in the graph.
    125 void FindInvalidInputs(const GraphDef& graph_def,
    126                        std::vector<std::pair<string, string>>* invalid_inputs);
    127 
    128 // Returns a descriptive error status if there are problems spotted with the
    129 // graph.
    130 Status IsGraphValid(const GraphDef& graph_def);
    131 
    132 // Returns input and output types for a particular NodeDef.
    133 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
    134                      DataTypeVector* outputs);
    135 
    136 // Takes a comma-separated string of numbers and parses them into a shape.
    137 Status TensorShapeFromString(const string& shape_string, TensorShape* result);
    138 
    139 // This is used to spot particular subgraphs in a larger model. To use it,
    140 // create a pattern like:
    141 // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
    142 // This defines a subgraph where a Conv2D has a ResizeBilinear input, which
    143 // pulls from a MirrorPad op.
    144 // Regular expressions aren't supported for the op names, but you can use "*" to
    145 // match any op. You can also use | as a separator to match multiple op names,
    146 // like "Reshape|Concat|Conv2D".
    147 struct OpTypePattern {
    148   string op;
    149   std::vector<OpTypePattern> inputs;
    150   string DebugString() const;
    151 };
    152 
    153 // Returns a sub-graph of nodes that match a pattern.
    154 struct NodeMatch {
    155   NodeMatch() : node() {}
    156   NodeDef node;
    157   std::vector<NodeMatch> inputs;
    158   string DebugString() const;
    159 };
    160 
    161 // Utility class to spot subgraphs matching particular patterns.
    162 class GraphMatcher {
    163  public:
    164   GraphMatcher(const GraphDef& graph_def);
    165 
    166   // Sorts the input nodes into execution order, and then skips any previously
    167   // matches so that no node appears in more than one match. The NodeDef
    168   // pointers contained in the results are owned by the GraphMatcher object, and
    169   // so will be invalid after its lifetime.
    170   Status GetOpTypeMatches(const OpTypePattern& pattern,
    171                           std::vector<NodeMatch>* matches);
    172 
    173  private:
    174   bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern,
    175                        const std::set<string>& previously_matched_nodes,
    176                        NodeMatch* match);
    177 
    178   GraphDef graph_def_;
    179   std::map<string, const NodeDef*> node_map_;
    180 };
    181 
    182 struct ReplaceMatchingOpTypesOptions {
    183   // Whether to raise an error if the graph is left with dangling inputs. If you
    184   // enable this option, you must fix inconsistencies in a later pass.
    185   bool allow_inconsistencies;
    186 };
    187 
    188 // Replaces all of the matching sub-graphs with new ops. This calls into the
    189 // given function, and expects to receive a set of new nodes to replace each
    190 // matched sub-graph. It has some logic to protect the integrity of the
    191 // resulting graph, for example making sure that nodes needed by other nodes
    192 // outside the sub-graph aren't removed. These are passed in as the set of
    193 // outputs, and nodes with the same names must be added to the new nodes
    194 // produced by the replacement function. Many of these checks can be disabled
    195 // by setting allow_inconsistencies to true in the options, but then it's the
    196 // caller's responsibility to patch up any problems before passing on the graph
    197 // to others. There's more comprehensive usage documentation in the README.
    198 Status ReplaceMatchingOpTypes(
    199     const GraphDef& input_graph_def, const OpTypePattern& pattern,
    200     const std::function<Status(const NodeMatch&, const std::set<string>&,
    201                                const std::set<string>&, std::vector<NodeDef>*)>&
    202         node_generator,
    203     const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def);
    204 
    205 // Returns a list of the unique nodes found in this match.
    206 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result);
    207 
    208 // Changes all input references to a particular node name. Any nodes with names
    209 // listed in nodes_to_ignore will not have their inputs rewritten.
    210 Status RenameNodeInputs(const GraphDef& input_graph_def,
    211                         const std::map<string, string>& inputs_to_rename,
    212                         const std::unordered_set<string>& nodes_to_ignore,
    213                         GraphDef* output_graph_def);
    214 
    215 // Utility function that copies all the nodes found in a match into the
    216 // new_nodes list. This is useful in replacement functions when you decide to
    217 // leave the original matched subgraph untouched and make no changes.
    218 void CopyOriginalMatch(const NodeMatch& match, std::vector<NodeDef>* new_nodes);
    219 
    220 // Holds information that's needed for transform functions.
    221 typedef std::map<string, std::vector<string>> TransformFuncParameters;
    222 struct TransformFuncContext {
    223   std::vector<string> input_names;
    224   std::vector<string> output_names;
    225   TransformFuncParameters params;
    226 
    227   // Returns how many occurrences of the given parameter are present.
    228   int CountParameters(const string& name) const;
    229 
    230   // Gets a single instance of a parameter, using a default if it's not present.
    231   Status GetOneStringParameter(const string& name, const string& default_value,
    232                                string* result) const;
    233 
    234   // Gets a single occurrence of a parameter as a 32-bit integer, falling back
    235   // to a default if it isn't present and returning an error if it isn't
    236   // convertible to a number.
    237   Status GetOneInt32Parameter(const string& name, int32 default_value,
    238                               int32* result) const;
    239 
    240   // Gets a single occurrence of a parameter as a 64-bit integer, falling back
    241   // to a default if it isn't present and returning an error if it isn't
    242   // convertible to a number.
    243   Status GetOneInt64Parameter(const string& name, int64 default_value,
    244                               int64* result) const;
    245 
    246   // Gets a single occurrence of a parameter as a floating point number, falling
    247   // back to a default if it isn't present and returning an error if it isn't
    248   // convertible to a number.
    249   Status GetOneFloatParameter(const string& name, float default_value,
    250                               float* result) const;
    251 
    252   // Gets a single occurrence of a parameter as a boolean, falling back to a
    253   // default if it isn't present and returning an error if it's not one of
    254   // "true", "1", "false", or "0".
    255   Status GetOneBoolParameter(const string& name, bool default_value,
    256                              bool* result) const;
    257 };
    258 
    259 // This is the function API for all graph transformations, taking an input
    260 // GraphDef and other arguments, and returning a transformed GraphDef.
    261 typedef std::function<Status(const GraphDef&,
    262                              const TransformFuncContext& context, GraphDef*)>
    263     TransformFunc;
    264 
    265 // To add a new graph transform function, call the macro:
    266 // REGISTER_GRAPH_TRANSFORM("fold_constants", FoldConstants);
    267 // Under the hood this adds the function to the list of known transforms, so you
    268 // just need to link in the .cc file with your registration call to have access
    269 // to it through the command line tool.
    270 // The rest of the machinery below is to enable that automagical registration.
    271 typedef std::map<string, TransformFunc> TransformRegistry;
    272 TransformRegistry* GetTransformRegistry();
    273 class TransformRegistrar {
    274  public:
    275   TransformRegistrar(const string& name, TransformFunc transform_func) {
    276     TransformRegistry* transform_registry = GetTransformRegistry();
    277     (*transform_registry)[name] = transform_func;
    278   }
    279 };
    280 #define REGISTER_GRAPH_TRANSFORM(name, func) \
    281   REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(__COUNTER__, name, func)
    282 #define REGISTER_GRAPH_TRANSFORM_UNIQ_HELPER(ctr, name, func) \
    283   REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)
    284 #define REGISTER_GRAPH_TRANSFORM_UNIQ(ctr, name, func)    \
    285   static tensorflow::graph_transforms::TransformRegistrar \
    286       registrar__body__##ctr##__object(name, func);
    287 
    288 }  // namespace graph_transforms
    289 }  // namespace tensorflow
    290 
    291 #endif  // TENSORFLOW_TOOLS_GRAPH_TRANSFORMS_TRANSFORM_UTILS_H_
    292