Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_
     17 #define TENSORFLOW_GRAPH_NODE_BUILDER_H_
     18 
     19 #include <vector>
     20 #include "tensorflow/core/framework/node_def_builder.h"
     21 #include "tensorflow/core/framework/op.h"
     22 #include "tensorflow/core/framework/op_def.pb.h"
     23 #include "tensorflow/core/graph/graph.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 #include "tensorflow/core/lib/core/stringpiece.h"
     26 #include "tensorflow/core/lib/gtl/array_slice.h"
     27 
     28 namespace tensorflow {
     29 
     30 // This is a helper for creating a Node and adding it to a Graph.
     31 // Internally, it uses a NodeDefBuilder to automatically set attrs
     32 // that can be inferred from the inputs, and uses default values
     33 // (where they exist) for unspecified attrs.  Example usage:
     34 //
     35 //  Node* node;
     36 //  Status status = NodeBuilder(node_name, op_name)
     37 //                           .Input(...)
     38 //                           .Attr(...)
     39 //                           .Finalize(&graph, &node);
     40 //  if (!status.ok()) return status;
     41 //  // Use node here.
     42 class NodeBuilder {
     43  public:
     44   // For specifying the output of a Node to provide to one of the Input()
     45   // functions below.  It supports both regular inputs (where you are
     46   // connecting to an existing Node*), and inputs from outside the graph
     47   // (or haven't been added to the graph yet, like back edges, where
     48   // you don't have a Node*). Both types can be mixed, e.g. in an
     49   // ArraySlice.
     50   struct NodeOut {
     51     // For referencing an existing Node.
     52     NodeOut(Node* n, int32 i = 0);
     53 
     54     // For referencing Nodes not in the graph being built. It is
     55     // useful when preparing a graph for ExtendSession or creating a
     56     // back edge to a node that hasn't been added to the graph yet,
     57     // but will be.
     58     NodeOut(StringPiece name, int32 i, DataType t);
     59 
     60     // Default constructor for std::vector<NodeOut>.
     61     NodeOut();
     62 
     63     Node* node;
     64     // error is set to true if:
     65     // * the NodeOut was default constructed and never overwritten,
     66     // * a nullptr Node* was passed to the NodeOut constructor, or
     67     // * an out-of-range index was passed to the NodeOut constructor.
     68     bool error;
     69     string name;
     70     int32 index;
     71     DataType dt;
     72   };
     73 
     74   // Specify the name and the Op (either via an OpDef or the name of
     75   // the Op plus a registry) for the Node.  Other fields are
     76   // specified by calling the methods below.
     77   // REQUIRES: The OpDef must satisfy ValidateOpDef().
     78   NodeBuilder(StringPiece name, StringPiece op_name,
     79               const OpRegistryInterface* op_registry = OpRegistry::Global());
     80   NodeBuilder(StringPiece name, const OpDef* op_def);
     81 
     82   // Create a NodeBuilder from an existing NodeDefBuilder.
     83   NodeBuilder(const NodeDefBuilder& def_builder);
     84 
     85   // You must call one Input() function per input_arg in the Op,
     86   // *and in the same order as the input_args appear in the OpDef.*
     87 
     88   // For inputs that take a single tensor.
     89   NodeBuilder& Input(Node* src_node, int src_index = 0);
     90   NodeBuilder& Input(NodeOut src);
     91 
     92   // For inputs that take a list of tensors.
     93   NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
     94 
     95   // Require that this node run after src_node(s).
     96   NodeBuilder& ControlInput(Node* src_node);
     97   NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);
     98 
     99   // Sets the "requested device spec" in the NodeDef (not the
    100   // "assigned device" in the Node).
    101   NodeBuilder& Device(StringPiece device_spec);
    102 
    103   // Set the value of an attr.  attr_name must match the name of one of
    104   // attrs defined by the Op, and value must have the corresponding type
    105   // (see SetAttrValue() in ../framework/attr_value_util.h for legal
    106   // types for value).  Note that attrs will be set automatically if
    107   // they can be determined by the inputs.
    108   template <class T>
    109   NodeBuilder& Attr(StringPiece attr_name, T&& value);
    110   template <class T>
    111   NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
    112 
    113   // Validates the described node and adds it to *graph, adding edges
    114   // for all (non-back) inputs.  If created_node is not nullptr,
    115   // *created_node will be set to the new node (or nullptr on error).
    116   Status Finalize(Graph* graph, Node** created_node) const;
    117 
    118   // Accessors for the values set in the constructor.
    119   const string& node_name() const { return def_builder_.node_name(); }
    120   const OpDef& op_def() const { return def_builder_.op_def(); }
    121 
    122  private:
    123   static DataType SafeGetOutput(Node* node, int i, bool* error) {
    124     if (node != nullptr && i >= 0 && i < node->num_outputs()) {
    125       *error = false;
    126       return node->output_type(i);
    127     } else {
    128       *error = true;
    129       return DT_FLOAT;
    130     }
    131   }
    132 
    133   // If SafeGetOutput indicates a range error, add it to errors_.
    134   void AddIndexError(Node* node, int i);
    135 
    136   // Set *dt and returns true if i is in range. Combines
    137   // SafeGetOutput() and AddIndexError().
    138   bool GetOutputType(Node* node, int i, DataType* dt);
    139 
    140   NodeDefBuilder def_builder_;
    141   std::vector<NodeOut> inputs_;
    142   std::vector<Node*> control_inputs_;
    143   std::vector<string> errors_;
    144 };
    145 
    146 // IMPLEMENTATION -------------------------------------------------------------
    147 
    148 template <class T>
    149 NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) {
    150   def_builder_.Attr(attr_name, std::forward<T>(value));
    151   return *this;
    152 }
    153 
    154 template <class T>
    155 NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
    156                                std::initializer_list<T> value) {
    157   def_builder_.Attr(attr_name, value);
    158   return *this;
    159 }
    160 
    161 }  // namespace tensorflow
    162 
    163 #endif  // TENSORFLOW_GRAPH_NODE_BUILDER_H_
    164