Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
     17 #define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
     18 
     19 #include "tensorflow/core/framework/graph.pb.h"
     20 #include "tensorflow/core/graph/graph.h"
     21 #include "tensorflow/core/graph/tensor_id.h"
     22 #include "tensorflow/core/lib/core/status.h"
     23 
     24 namespace tensorflow {
     25 class ShapeRefiner;
     26 
     27 // Construct a Graph *g out of a GraphDef gdef. Returns non-OK on
     28 // error, in which case *g is left in an incomplete state.
     29 //
     30 // *g is expected to be an empty graph (with no more than a source and sink
     31 // nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph,
     32 // see ImportGraphDef.
     33 struct GraphConstructorOptions {
     34   GraphConstructorOptions() {}
     35 
     36   // If true, allows internal ops in the GraphDef.
     37   bool allow_internal_ops = false;
     38 
     39   // If true, the graph def is expected to have fully specified
     40   // devices for all nodes. A node in the resulting graph "g" has the
     41   // device name set accordingly.
     42   //
     43   // TODO(zhifengc): if possible, consider removing this option.
     44   bool expect_device_spec = false;
     45 };
     46 extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
     47                                      const GraphDef& gdef, Graph* g);
     48 
     49 // Same as ConvertGraphDefToGraph, but takes just nodes.  Used by function
     50 // instantiation.
     51 // TODO(irving): This will turn into std::vector<NodeInfoPtr> soon.
     52 extern Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
     53                                      gtl::ArraySlice<NodeDef> nodes, Graph* g);
     54 
     55 // Options for calling ImportGraphDef().
     56 struct ImportGraphDefOptions {
     57   ImportGraphDefOptions()
     58       : uniquify_names(false),
     59         uniquify_prefix(false),
     60         skip_mapped_nodes(false),
     61         validate_shape(true) {}
     62 
     63   // Name prefix to use for nodes imported from the GraphDef.  For example, if
     64   // prefix="animals" and GraphDef contains a node "bunny" then the node will be
     65   // named "animals/bunny" in *g. Must not be already used as a node name or
     66   // prefix in the graph.
     67   string prefix;
     68 
     69   // If true, imported node names will be modified if their name already exists
     70   // in the graph. If false, conflicting names will be treated as an error. Note
     71   // that this option has no effect if `prefix` is specified, since `prefix`
     72   // will guarantee all node names are unique.
     73   bool uniquify_names;
     74 
     75   // If true, `prefix` will be modified if it already exists as a node name or
     76   // prefix in the graph. If false, a conflicting prefix will be treated as an
     77   // error. This option has no effect if `prefix` isn't specified.
     78   bool uniquify_prefix;
     79 
     80   // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef`
     81   // corresponding to `input_map` keys will be remapped to the nodes in `g`
     82   // corresponding to the values.
     83   //
     84   // Keys should not include `prefix`, i.e., a key TensorId's name should be the
     85   // name as it originally appears in `gdef`.
     86   //
     87   // If this is non-empty, ImportGraphDef must be called with the shape refiner
     88   // used to create the existing nodes referenced in `input_map`.
     89   // TODO(skyewm): can we remove this requirement? How do we access the original
     90   // shape refiner?
     91   std::map<TensorId, TensorId> input_map;
     92 
     93   // If true, nodes that will have all output edges removed because of
     94   // overrides in `input_map` will not be imported.
     95   bool skip_mapped_nodes;
     96 
     97   // The names of existing nodes in `g` that the imported graph should have
     98   // control dependencies on.
     99   //
    100   // Note that to avoid creating many redundant control edges, ImportGraphDef()
    101   // won't add control edges to nodes that will inherit the dependencies from
    102   // other nodes in `gdef`.
    103   std::vector<string> control_dependencies;
    104 
    105   // Tensors in `gdef` that will be returned via the ImportGraphDefResults
    106   // output parameter of `ImportGraphDef()`. If this list is non-empty, the
    107   // caller must pass a results object to `ImportGraphDef()`. The
    108   // `return_tensors` field will be populated with the imported nodes in `g`.
    109   //
    110   // Entries should not include `prefix`, i.e., each TensorId's name should be
    111   // the name as it originally appears in `gdef`.
    112   //
    113   // If this contains a tensor that's also being remapped via `input_map`, the
    114   // corresponding existing tensor in `g` will be returned.
    115   std::vector<TensorId> return_tensors;
    116 
    117   // The names of nodes in `gdef` that will be returned via the
    118   // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
    119   // is non-empty, the caller must pass a results object to
    120   // `ImportGraphDef()`. The `return_nodes` field will be populated with the
    121   // imported nodes in `g`.
    122   //
    123   // Entries should not include `prefix`, i.e., each node's name should be the
    124   // name as it originally appears in `gdef`.
    125   //
    126   // Unlike `return_tensors`, `input_map` has no effect on the nodes
    127   // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true.
    128   // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
    129   std::vector<string> return_nodes;
    130 
    131   // If true, checks that all colocation constraints are nodes in the GraphDef.
    132   bool validate_colocation_constraints = true;
    133 
    134   // If false skips shape validation.
    135   bool validate_shape;
    136 
    137   // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
    138   // with ops that are not defined in the binary calling ImportGraphDef.
    139   // Similar to the producer_op_list argument to import_graph_def in the
    140   // python API.
    141 };
    142 
    143 // Optional results that may be returned by ImportGraphDef.
    144 struct ImportGraphDefResults {
    145   // The requested tensors associated with
    146   // ImportGraphDefOptions::return_tensors. Note that the index may be different
    147   // than the requested index if the returned tensor has been remapped according
    148   // to `input_map`.
    149   typedef int Index;
    150   std::vector<std::pair<Node*, Index>> return_tensors;
    151 
    152   // The requested nodes associated with ImportGraphDefOptions::return_nodes.
    153   std::vector<Node*> return_nodes;
    154 
    155   // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
    156   // weren't used as an input to any node in `gdef`. These keys are likely due
    157   // to typos, and callers may wish to treat their existence as an error.
    158   std::vector<TensorId> missing_unused_input_map_keys;
    159 };
    160 
    161 // Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
    162 //
    163 // On error, returns non-OK and leaves `*g` unmodified.
    164 //
    165 // `refiner` can be null. It should be non-null if the caller
    166 // intends to add additional nodes to the graph after the import. This
    167 // allows the caller to validate shapes of those nodes (since
    168 // ShapeRefiner::AddNode must be called in topological order).
    169 //
    170 // `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is
    171 // non-empty. It can also be set to fetch the unused input map keys. If it's
    172 // non-null, all the vector fields must be empty.
    173 //
    174 // TODO(ashankar): Push this mechanism and get rid of Session::Extend()
    175 // as a means of enhancing an existing Graph.
    176 extern Status ImportGraphDef(const ImportGraphDefOptions& opts,
    177                              const GraphDef& gdef, Graph* g,
    178                              ShapeRefiner* refiner,
    179                              ImportGraphDefResults* results = nullptr);
    180 
    181 // Make a copy of "src" into "*dest".
    182 //
    183 // REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges
    184 // other than the implicit Source/Sink nodes.
    185 extern void CopyGraph(const Graph& src, Graph* dest);
    186 
    187 }  // namespace tensorflow
    188 
    189 #endif  // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
    190