1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/node_builder.h" 17 18 #include <vector> 19 #include "tensorflow/core/framework/node_def_util.h" 20 #include "tensorflow/core/framework/versions.pb.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 23 namespace tensorflow { 24 25 NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit) 26 : node(n), 27 error(false), 28 name(node != nullptr ? node->name() : (error = true, "")), 29 index(i), 30 dt(SafeGetOutput(node, i, &error)) {} 31 32 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t) 33 : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {} 34 35 NodeBuilder::NodeOut::NodeOut() 36 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {} 37 38 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name, 39 const OpRegistryInterface* op_registry) 40 : def_builder_(name, op_name, op_registry) {} 41 42 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def) 43 : def_builder_(name, op_def) {} 44 45 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder) 46 : def_builder_(def_builder) {} 47 48 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { 49 inputs_.emplace_back(src_node, src_index); 50 DataType dt; 51 if (GetOutputType(src_node, src_index, &dt)) { 52 def_builder_.Input(src_node->name(), src_index, dt); 53 } 54 return *this; 55 } 56 57 NodeBuilder& NodeBuilder::Input(NodeOut src) { 58 if (src.error) { 59 AddIndexError(src.node, src.index); 60 } else { 61 inputs_.emplace_back(src.node, src.index); 62 def_builder_.Input(src.name, src.index, src.dt); 63 } 64 return *this; 65 } 66 67 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) { 68 std::vector<NodeDefBuilder::NodeOut> srcs; 69 srcs.reserve(src_list.size()); 70 for (const auto& node_out : src_list) { 71 if (node_out.error) { 72 AddIndexError(node_out.node, node_out.index); 73 } else { 74 srcs.emplace_back(node_out.name, node_out.index, node_out.dt); 75 inputs_.emplace_back(node_out.node, node_out.index); 76 } 77 } 78 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs)); 79 return *this; 80 } 81 82 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { 83 control_inputs_.emplace_back(src_node); 84 def_builder_.ControlInput(src_node->name()); 85 return *this; 86 } 87 88 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) { 89 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), 90 src_nodes.end()); 91 for (Node* src_node : src_nodes) { 92 def_builder_.ControlInput(src_node->name()); 93 } 94 return *this; 95 } 96 97 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { 98 def_builder_.Device(device_spec); 99 return *this; 100 } 101 102 Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { 103 // In case of error, set *created_node to nullptr. 104 if (created_node != nullptr) *created_node = nullptr; 105 if (!errors_.empty()) { 106 return errors::InvalidArgument(str_util::Join(errors_, "\n")); 107 } 108 109 NodeDef node_def; 110 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def)); 111 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); 112 TF_RETURN_IF_ERROR( 113 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); 114 Status status; 115 Node* node = graph->AddNode(node_def, &status); 116 if (!status.ok()) return status; 117 118 for (size_t i = 0; i < inputs_.size(); ++i) { 119 if (inputs_[i].node != nullptr) { // Skip back edges. 120 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); 121 } 122 } 123 for (Node* control_input : control_inputs_) { 124 graph->AddControlEdge(control_input, node); 125 } 126 if (created_node != nullptr) *created_node = node; 127 return Status::OK(); 128 } 129 130 void NodeBuilder::AddIndexError(Node* node, int i) { 131 if (node == nullptr) { 132 errors_.emplace_back( 133 strings::StrCat("Attempt to add nullptr Node to node with type ", 134 def_builder_.op_def().name())); 135 } else { 136 errors_.emplace_back( 137 strings::StrCat("Attempt to add output ", i, " of ", node->name(), 138 " not in range [0, ", node->num_outputs(), 139 ") to node with type ", def_builder_.op_def().name())); 140 } 141 } 142 143 bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) { 144 bool error; 145 *dt = SafeGetOutput(node, i, &error); 146 if (error) AddIndexError(node, i); 147 return !error; 148 } 149 150 } // namespace tensorflow 151