1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph_def_builder.h" 17 18 #include <utility> 19 20 #include "tensorflow/core/graph/tensor_id.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 23 namespace tensorflow { 24 25 GraphDefBuilder::Options::Options(Graph* graph, Status* status) 26 : graph_(graph), status_(status) {} 27 GraphDefBuilder::Options::~Options() {} 28 29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName( 30 StringPiece name) const { 31 return Options(*this).WithNameImpl(name); 32 } 33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice( 34 StringPiece device) const { 35 return Options(*this).WithDeviceImpl(device); 36 } 37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput( 38 Node* control_input) const { 39 return Options(*this).WithControlInputImpl(control_input); 40 } 41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs( 42 gtl::ArraySlice<Node*> control_inputs) const { 43 return Options(*this).WithControlInputsImpl(control_inputs); 44 } 45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl( 46 StringPiece name) { 47 name_ = name.ToString(); 48 return *this; 49 } 50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl( 51 StringPiece device) { 52 device_ = device.ToString(); 53 return *this; 54 } 55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl( 56 Node* control_input) { 57 control_inputs_.push_back(control_input); 58 return *this; 59 } 60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl( 61 gtl::ArraySlice<Node*> control_inputs) { 62 control_inputs_.insert(control_inputs_.end(), control_inputs.begin(), 63 control_inputs.end()); 64 return *this; 65 } 66 67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { 68 if (status_.ok()) { 69 graph_.ToGraphDef(graph_def); 70 } 71 return status_; 72 } 73 74 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const { 75 if (name_.empty()) return graph_->NewName(op); 76 return name_; 77 } 78 79 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const { 80 builder->ControlInputs(control_inputs_); 81 if (!device_.empty()) builder->Device(device_); 82 for (const auto& attr : attrs_) { 83 builder->Attr(attr.first, attr.second); 84 } 85 86 Node* returned_node; 87 UpdateStatus(builder->Finalize(graph_, &returned_node)); 88 return returned_node; 89 } 90 91 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const { 92 if (status_ == nullptr) { 93 TF_CHECK_OK(status); 94 } else { 95 status_->Update(status); 96 } 97 } 98 99 namespace ops { 100 101 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) { 102 if (opts.HaveError()) return nullptr; 103 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, 104 opts.op_registry()); 105 return opts.FinalizeBuilder(&node_builder); 106 } 107 108 Node* UnaryOp(const string& op_name, NodeOut input, 109 const GraphDefBuilder::Options& opts) { 110 if (opts.HaveError()) return nullptr; 111 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, 112 opts.op_registry()); 113 node_builder.Input(std::move(input)); 114 return opts.FinalizeBuilder(&node_builder); 115 } 116 117 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, 118 const GraphDefBuilder::Options& opts) { 119 if (opts.HaveError()) return nullptr; 120 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name, 121 opts.op_registry()); 122 node_builder.Input(std::move(a)).Input(std::move(b)); 123 return opts.FinalizeBuilder(&node_builder); 124 } 125 126 } // end namespace ops 127 } // end namespace tensorflow 128