Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
     17 #define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
     18 
     19 #include <vector>
     20 #include "tensorflow/core/framework/graph.pb.h"
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/graph/graph.h"
     23 #include "tensorflow/core/graph/node_builder.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/lib/core/stringpiece.h"
     26 #include "tensorflow/core/lib/gtl/array_slice.h"
     27 
     28 namespace tensorflow {
     29 
     30 // Given a function like:
     31 //   namespace ops {
     32 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
     33 //     if (opts.HaveError()) return nullptr;
     34 //     static const string kOpName = "Identity";
     35 //     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
     36 //                              opts.op_registry());
     37 //     node_builder.Input(input);
     38 //     return opts.FinalizeBuilder(&node_builder);
     39 //   }
     40 //   }  // namespace ops
     41 //
     42 //   // Or, alternatively:
     43 //   namespace ops {
     44 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
     45 //     static const string kOpName = "Identity";
     46 //     return UnaryOp(kOpName, input, opts);
     47 //   }
     48 //   }  // namespace ops
     49 //
     50 // You call it like:
     51 //   GraphDefBuilder b;
     52 //   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
     53 //   Node* na = Const(7, b.opts());
     54 //   // Note: WithName() returns a copy, opts is unchanged.
     55 //   Node* nb = Const(5, b.opts().WithName("control-input"));
     56 //   Node* nc = Identity(na, b.opts().WithControlInput(nb));
     57 //   GraphDef graph_def;
     58 //   Status status = b.ToGraphDef(&graph_def);
     59 //   if (!status.ok()) { /* Handle error */ }
     60 //
     61 // In tests you can skip the status handling via:
     62 //   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
     63 //   ...
     64 //   b.ToGraphDef(&graph_def);
     65 
     66 class GraphDefBuilder {
     67  public:
     68   // Options for adding a Node to a Graph.
     69   class Options {
     70    public:
     71     // Sets the Graph (that Nodes will be added to) and the status.  The
     72     // status may be set to nullptr, in which case errors cause CHECK
     73     // failures.  The graph and status must outlive *this.
     74     Options(Graph* graph, Status* status);
     75     ~Options();
     76 
     77     // Methods for setting options.  These are const methods: they
     78     // return a copy of *this with the option set.
     79     Options WithName(StringPiece name) const;
     80     Options WithDevice(StringPiece device) const;
     81     Options WithControlInput(Node* control_input) const;
     82     Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
     83 
     84     // Override the default value for an optional attr.
     85     template <class T>
     86     Options WithAttr(StringPiece attr_name, T&& value) const {
     87       return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
     88     }
     89     // Note: overload needed to allow {...} expressions for value.
     90     template <class T>
     91     Options WithAttr(StringPiece attr_name,
     92                      std::initializer_list<T> value) const {
     93       return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
     94     }
     95 
     96     // Methods for using options from a function that creates a Node.
     97 
     98     // Returns true if the status associated with *this has an error.
     99     // Use this to skip processing that may depend on prior results.
    100     bool HaveError() const { return status_ != nullptr && !status_->ok(); }
    101 
    102     // Returns a string representation of the status associated with *this.
    103     // Returns the string `"OK"` if the status doesn't have any error.
    104     string StatusToString() const { return status_->ToString(); }
    105 
    106     // Given the Op type name, return a name for a node of that type.
    107     // Uses the value set in WithName() if that has been called.  Otherwise,
    108     // returns a name built out of the Op type name.
    109     string GetNameForOp(StringPiece op) const;
    110 
    111     // Sets the device, adds control inputs, adds attrs, and calls Finalize().
    112     // If Finalize returns an error, it is saved and this function returns
    113     // nullptr.
    114     Node* FinalizeBuilder(NodeBuilder* builder) const;
    115 
    116     // Updates the associated status, if any, or calls TF_CHECK_OK if none.
    117     void UpdateStatus(const Status& status) const;
    118 
    119     // Accessor
    120     const OpRegistryInterface* op_registry() const {
    121       return graph_->op_registry();
    122     }
    123 
    124    private:
    125     Options WithNameImpl(StringPiece name);
    126     Options WithDeviceImpl(StringPiece device);
    127     Options WithControlInputImpl(Node* control_input);
    128     Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
    129     template <class T>
    130     Options WithAttrImpl(StringPiece name, T&& value) {
    131       attrs_.emplace_back(name.ToString(), AttrValue());
    132       SetAttrValue(std::forward<T>(value), &attrs_.back().second);
    133       return *this;
    134     }
    135 
    136     Graph* const graph_;
    137     Status* const status_;
    138     string name_;
    139     string device_;
    140     std::vector<Node*> control_inputs_;
    141     std::vector<std::pair<string, AttrValue>> attrs_;
    142   };
    143 
    144   // Start building a new graph.
    145   explicit GraphDefBuilder(
    146       const OpRegistryInterface* op_registry = OpRegistry::Global())
    147       : graph_(op_registry), opts_(&graph_, &status_) {}
    148 
    149   // For use in tests, where you want to fail immediately on error instead
    150   // of checking the status at the end.
    151   enum TestFailImmediatelyType { kFailImmediately };
    152   explicit GraphDefBuilder(
    153       TestFailImmediatelyType,
    154       const OpRegistryInterface* op_registry = OpRegistry::Global())
    155       : graph_(op_registry), opts_(&graph_, nullptr) {}
    156 
    157   // Gets the Options with the associated Graph and Status.
    158   const Options& opts() const { return opts_; }
    159 
    160   // Once all the nodes have been added, call this to get whether it was
    161   // successful, and if so fill *graph_def.
    162   Status ToGraphDef(GraphDef* graph_def) const;
    163 
    164   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
    165   // registry. Ignores duplicate functions, and returns a bad status if an
    166   // imported function differs from an existing function or op with the same
    167   // name.
    168   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
    169     return graph_.AddFunctionLibrary(fdef_lib);
    170   }
    171 
    172   // Returns whether a user-defined function with `name` already exists in the
    173   // graph.
    174   bool HasFunction(const string& name) {
    175     return graph_.flib_def().Find(name) != nullptr;
    176   }
    177 
    178  private:
    179   Graph graph_;
    180   Status status_;
    181   Options opts_;
    182 };
    183 
    184 namespace ops {
    185 
    186 // A NodeOut may either be a regular input or back input.  Regular
    187 // inputs are specified via either a Node* or a Node* and an output
    188 // index.  Back inputs are specified by a node name, output index, and
    189 // output type.
    190 typedef NodeBuilder::NodeOut NodeOut;
    191 
    192 // For adding an Op with no inputs to a GraphDefBuilder.
    193 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
    194 
    195 // For adding an Op with one input to a GraphDefBuilder.
    196 Node* UnaryOp(const string& op_name, NodeOut input,
    197               const GraphDefBuilder::Options& opts);
    198 
    199 // For adding an Op with two inputs to a GraphDefBuilder.
    200 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
    201                const GraphDefBuilder::Options& opts);
    202 
    203 }  // namespace ops
    204 }  // namespace tensorflow
    205 
    206 #endif  // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
    207