Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/graph_def_builder.h"
     17 
     18 #include <utility>
     19 
     20 #include "tensorflow/core/graph/tensor_id.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 
     23 namespace tensorflow {
     24 
     25 GraphDefBuilder::Options::Options(Graph* graph, Status* status)
     26     : graph_(graph), status_(status) {}
     27 GraphDefBuilder::Options::~Options() {}
     28 
     29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
     30     StringPiece name) const {
     31   return Options(*this).WithNameImpl(name);
     32 }
     33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
     34     StringPiece device) const {
     35   return Options(*this).WithDeviceImpl(device);
     36 }
     37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
     38     Node* control_input) const {
     39   return Options(*this).WithControlInputImpl(control_input);
     40 }
     41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
     42     gtl::ArraySlice<Node*> control_inputs) const {
     43   return Options(*this).WithControlInputsImpl(control_inputs);
     44 }
     45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
     46     StringPiece name) {
     47   name_ = name.ToString();
     48   return *this;
     49 }
     50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
     51     StringPiece device) {
     52   device_ = device.ToString();
     53   return *this;
     54 }
     55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
     56     Node* control_input) {
     57   control_inputs_.push_back(control_input);
     58   return *this;
     59 }
     60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
     61     gtl::ArraySlice<Node*> control_inputs) {
     62   control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
     63                          control_inputs.end());
     64   return *this;
     65 }
     66 
     67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
     68   if (status_.ok()) {
     69     graph_.ToGraphDef(graph_def);
     70   }
     71   return status_;
     72 }
     73 
     74 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
     75   if (name_.empty()) return graph_->NewName(op);
     76   return name_;
     77 }
     78 
     79 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
     80   builder->ControlInputs(control_inputs_);
     81   if (!device_.empty()) builder->Device(device_);
     82   for (const auto& attr : attrs_) {
     83     builder->Attr(attr.first, attr.second);
     84   }
     85 
     86   Node* returned_node;
     87   UpdateStatus(builder->Finalize(graph_, &returned_node));
     88   return returned_node;
     89 }
     90 
     91 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
     92   if (status_ == nullptr) {
     93     TF_CHECK_OK(status);
     94   } else {
     95     status_->Update(status);
     96   }
     97 }
     98 
     99 namespace ops {
    100 
    101 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
    102   if (opts.HaveError()) return nullptr;
    103   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
    104                            opts.op_registry());
    105   return opts.FinalizeBuilder(&node_builder);
    106 }
    107 
    108 Node* UnaryOp(const string& op_name, NodeOut input,
    109               const GraphDefBuilder::Options& opts) {
    110   if (opts.HaveError()) return nullptr;
    111   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
    112                            opts.op_registry());
    113   node_builder.Input(std::move(input));
    114   return opts.FinalizeBuilder(&node_builder);
    115 }
    116 
    117 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
    118                const GraphDefBuilder::Options& opts) {
    119   if (opts.HaveError()) return nullptr;
    120   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
    121                            opts.op_registry());
    122   node_builder.Input(std::move(a)).Input(std::move(b));
    123   return opts.FinalizeBuilder(&node_builder);
    124 }
    125 
    126 }  // end namespace ops
    127 }  // end namespace tensorflow
    128