Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/node_builder.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/node_def_util.h"
     20 #include "tensorflow/core/framework/versions.pb.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 
     23 namespace tensorflow {
     24 
     25 NodeBuilder::NodeOut::NodeOut(Node* n, int32 i)  // NOLINT(runtime/explicit)
     26     : node(n),
     27       error(false),
     28       name(node != nullptr ? node->name() : (error = true, "")),
     29       index(i),
     30       dt(SafeGetOutput(node, i, &error)) {}
     31 
     32 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
     33     : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {}
     34 
     35 NodeBuilder::NodeOut::NodeOut()
     36     : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
     37 
     38 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
     39                          const OpRegistryInterface* op_registry)
     40     : def_builder_(name, op_name, op_registry) {}
     41 
     42 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
     43     : def_builder_(name, op_def) {}
     44 
     45 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
     46     : def_builder_(def_builder) {}
     47 
     48 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
     49   inputs_.emplace_back(src_node, src_index);
     50   DataType dt;
     51   if (GetOutputType(src_node, src_index, &dt)) {
     52     def_builder_.Input(src_node->name(), src_index, dt);
     53   }
     54   return *this;
     55 }
     56 
     57 NodeBuilder& NodeBuilder::Input(NodeOut src) {
     58   if (src.error) {
     59     AddIndexError(src.node, src.index);
     60   } else {
     61     inputs_.emplace_back(src.node, src.index);
     62     def_builder_.Input(src.name, src.index, src.dt);
     63   }
     64   return *this;
     65 }
     66 
     67 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
     68   std::vector<NodeDefBuilder::NodeOut> srcs;
     69   srcs.reserve(src_list.size());
     70   for (const auto& node_out : src_list) {
     71     if (node_out.error) {
     72       AddIndexError(node_out.node, node_out.index);
     73     } else {
     74       srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
     75       inputs_.emplace_back(node_out.node, node_out.index);
     76     }
     77   }
     78   def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
     79   return *this;
     80 }
     81 
     82 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
     83   control_inputs_.emplace_back(src_node);
     84   def_builder_.ControlInput(src_node->name());
     85   return *this;
     86 }
     87 
     88 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
     89   control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
     90                          src_nodes.end());
     91   for (Node* src_node : src_nodes) {
     92     def_builder_.ControlInput(src_node->name());
     93   }
     94   return *this;
     95 }
     96 
     97 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
     98   def_builder_.Device(device_spec);
     99   return *this;
    100 }
    101 
    102 Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
    103   // In case of error, set *created_node to nullptr.
    104   if (created_node != nullptr) *created_node = nullptr;
    105   if (!errors_.empty()) {
    106     return errors::InvalidArgument(str_util::Join(errors_, "\n"));
    107   }
    108 
    109   NodeDef node_def;
    110   TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
    111   TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
    112   TF_RETURN_IF_ERROR(
    113       CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
    114   Status status;
    115   Node* node = graph->AddNode(node_def, &status);
    116   if (!status.ok()) return status;
    117 
    118   for (size_t i = 0; i < inputs_.size(); ++i) {
    119     if (inputs_[i].node != nullptr) {  // Skip back edges.
    120       graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
    121     }
    122   }
    123   for (Node* control_input : control_inputs_) {
    124     graph->AddControlEdge(control_input, node);
    125   }
    126   if (created_node != nullptr) *created_node = node;
    127   return Status::OK();
    128 }
    129 
    130 void NodeBuilder::AddIndexError(Node* node, int i) {
    131   if (node == nullptr) {
    132     errors_.emplace_back(
    133         strings::StrCat("Attempt to add nullptr Node to node with type ",
    134                         def_builder_.op_def().name()));
    135   } else {
    136     errors_.emplace_back(
    137         strings::StrCat("Attempt to add output ", i, " of ", node->name(),
    138                         " not in range [0, ", node->num_outputs(),
    139                         ") to node with type ", def_builder_.op_def().name()));
    140   }
    141 }
    142 
    143 bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) {
    144   bool error;
    145   *dt = SafeGetOutput(node, i, &error);
    146   if (error) AddIndexError(node, i);
    147   return !error;
    148 }
    149 
    150 }  // namespace tensorflow
    151