/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// TODO(andydavis) Remove some of the code duplicated between this module
// and that in 'common_runtime/function.cc'.
// A few string constants used throughout this module.
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";

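// Returns the string used to refer to this node output in a NodeDef's
// input list: just the node name for output 0, otherwise "<name>:<index>".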
string NodeOut::name() const {
  if (index == 0) {
    return node->name();
  } else {
    return strings::StrCat(node->name(), ":", index);
  }
}

DataType NodeOut::dtype() const { return node->output_type(index); }

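// Hash and equality functors so that a NodeOut (node pointer plus output
// index) can be used as the key of an unordered_map.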
struct NodeOutHash {
  uint64 operator()(const NodeOut& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct NodeOutEq {
  bool operator()(const NodeOut& x, const NodeOut& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

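// Adds to 'g' a ZerosLike node fed by 'input', producing a zero tensor
// with the same shape and element type as that endpoint.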
static Node* AddZerosLike(Graph* g, NodeOut input) {
  DCHECK_LT(0, input.dtype());
  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op("ZerosLike");
  ndef.add_input(input.name());
  AddNodeAttr("T", input.dtype(), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

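// Adds to 'g' a SymbolicGradient node for 'n'. Its inputs are n's data
// inputs followed by 'grads' (one gradient per output of 'n'), and its
// outputs are the gradients with respect to each of n's inputs.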
static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
  const int num_x = n->num_inputs();
  const int num_y = n->num_outputs();
  CHECK_EQ(num_y, grads.size());

  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kGradientOp);

  // The gradient node should have num_x + num_y inputs.
  std::vector<NodeOut> n_inputs(num_x);
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) continue;
    n_inputs[e->dst_input()] = {e->src(), e->src_output()};
  }
  DataTypeVector in_types;
  for (const NodeOut& nout : n_inputs) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  CHECK_EQ(ndef.input_size(), num_x + num_y);

  AddNodeAttr("Tin", in_types, &ndef);

  // The gradient node's outputs have the same types as the node 'n's
  // inputs.
  AddNodeAttr("Tout", n->input_types(), &ndef);
  NameAttrList func;
  func.set_name(n->type_string());
  for (const auto& attr : n->attrs()) {
    (*func.mutable_attr())[attr.first] = attr.second;
  }
  AddNodeAttr("f", func, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

class SymbolicGradientBuilder {
 public:
  SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs,
                          gtl::ArraySlice<NodeOut> x_node_outputs,
                          gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                          std::vector<NodeOut>* x_grad_node_outputs,
                          Graph* graph);

  Status Compute();

 private:
  gtl::ArraySlice<NodeOut> y_node_outputs_;
  gtl::ArraySlice<NodeOut> x_node_outputs_;
  gtl::ArraySlice<NodeOut> y_grad_node_outputs_;
  std::vector<NodeOut>* x_grad_node_outputs_;
  Graph* graph_;  // Not owned.

  // A vector of output endpoints which represents backpropagated
  // gradients.
  typedef std::vector<NodeOut> BackpropedGradients;

  // backprops_ is a map from a node output to its accumulated
  // gradients.  When a node output has accumulated all its
  // gradients, we add a node which sums them up.
  std::unordered_map<NodeOut, BackpropedGradients, NodeOutHash, NodeOutEq>
      backprops_;

  // pending_[i] is a count-down counter for the i-th node's expected
  // backprops.  When pending_[i] becomes zero, we have collected all
  // backprop gradients for all outputs of the i-th node.
  std::vector<int> pending_;

  // 'ready' keeps track of nodes that have been completely
  // backpropped. Initially, for every output y of the function f, we
  // add dy as an input of the gradient function.
  std::deque<Node*> ready_;

  // The set of node ids at which to stop backprop.
  std::unordered_set<int> stop_nodes_;

  // Initialize pending_ and ready_.
  void InitBackprop();

  // In the original function body, there is a forward edge from 'src'
  // to 'dst'. When the backprop algorithm constructs the node
  // 'dst_grad' which computes the gradient, we need to propagate it
  // to 'src'.
  void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src);
  void BackpropZerosAlongEdge(const NodeOut& src);

  NodeOut SumGradients(const NodeOut& src);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};

SymbolicGradientBuilder::SymbolicGradientBuilder(
    gtl::ArraySlice<NodeOut> y_node_outputs,
    gtl::ArraySlice<NodeOut> x_node_outputs,
    gtl::ArraySlice<NodeOut> y_grad_node_outputs,
    std::vector<NodeOut>* x_grad_node_outputs, Graph* graph)
    : y_node_outputs_(y_node_outputs),
      x_node_outputs_(x_node_outputs),
      y_grad_node_outputs_(y_grad_node_outputs),
      x_grad_node_outputs_(x_grad_node_outputs),
      graph_(graph) {
  CHECK_EQ(y_node_outputs_.size(), y_grad_node_outputs.size());
  x_grad_node_outputs_->clear();
  x_grad_node_outputs_->resize(x_node_outputs_.size());
  stop_nodes_.reserve(x_node_outputs_.size());
  for (int i = 0; i < x_node_outputs_.size(); ++i) {
    stop_nodes_.insert(x_node_outputs_[i].node->id());
  }
}

void SymbolicGradientBuilder::BackpropAlongEdge(const NodeOut& dst_grad,
                                                const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    auto* grads = &iter->second;
    grads->push_back(dst_grad);
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

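// Like BackpropAlongEdge, but records no gradient for 'src': only the
// pending count of src's node is decremented.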
void SymbolicGradientBuilder::BackpropZerosAlongEdge(const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

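// Walks forward from the x endpoints to find every node output reachable
// via data edges, sets each node's expected backprop count, and seeds the
// backprop with the provided output gradients dy.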
void SymbolicGradientBuilder::InitBackprop() {
  pending_.resize(graph_->num_node_ids(), 0);
  {
    backprops_.clear();
    std::unordered_set<Node*> visited;
    std::deque<Node*> queue;
    for (const NodeOut& nout : x_node_outputs_) {
      queue.push_back(nout.node);
      visited.insert(nout.node);
    }

    // Going forward to figure out which endpoints need to be
    // backpropped. A node's endpoints need to be backpropped only if
    // one of the arg nodes can reach the node via data edges.
    while (!queue.empty()) {
      Node* n = queue.front();
      queue.pop_front();
      for (int i = 0; i < n->num_outputs(); ++i) {
        backprops_[{n, i}].clear();
      }
      int num_expected_backprops = 0;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) continue;
        ++num_expected_backprops;
        if (visited.find(e->dst()) == visited.end()) {
          queue.push_back(e->dst());
          visited.insert(e->dst());
        }
      }
      pending_[n->id()] = num_expected_backprops;
    }
  }

  {
    const int num_y = y_grad_node_outputs_.size();
    for (int i = 0; i < num_y; ++i) {
      Node* y = y_node_outputs_[i].node;
      for (const Edge* e : y->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropAlongEdge(y_grad_node_outputs_[i], {e->src(), e->src_output()});
      }
    }
  }
  CHECK(!ready_.empty());
}

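// Returns a single endpoint holding the sum of all gradients backpropped
// to 'src': a ZerosLike node if none arrived, the lone gradient if there
// was exactly one, and an AddN node otherwise.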
NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) {
  const DataType dtype = src.dtype();
  auto iter = backprops_.find(src);
  CHECK(iter != backprops_.end());
  const auto& grads = iter->second;
  if (grads.empty()) {
    // Nothing propagated back. The best we can come up with is zeros.
    Node* zero_like = AddZerosLike(graph_, src);
    return {zero_like, 0};
  }
  if (grads.size() == 1) {
    // Just one backprop edge.
    return grads[0];
  }
  // Otherwise, add up the backpropped gradients.
  NodeDef ndef;
  ndef.set_name(graph_->NewName(kNodeLabel));
  ndef.set_op("AddN");  // N-way Add
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
  }
  AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* add = graph_->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  for (size_t i = 0; i < grads.size(); ++i) {
    const NodeOut& nout = grads[i];
    graph_->AddEdge(nout.node, nout.index, add, i);
  }
  return {add, 0};
}

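// Returns true if the gradient lookup for 'func' succeeds but yields a
// null creator, i.e. the op is registered as having no gradient function.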
static bool IsPrimitiveOpWithNoGrad(const string& func) {
  gradient::Creator creator;
  Status s = gradient::GetOpGradientCreator(func, &creator);
  return s.ok() && (creator == nullptr);
}

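// Runs the backprop pass: seeds the output gradients, repeatedly pops a
// node whose gradients have all been collected, attaches a SymbolicGradient
// node for it (or propagates zeros if the op has no gradient), and finally
// sums the gradients that reach each requested x endpoint.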
Status SymbolicGradientBuilder::Compute() {
  // Initialize backprops.
  InitBackprop();

  // Backward propagation.
  gtl::InlinedVector<NodeOut, 8> dy;
  while (!ready_.empty()) {
    // n has collected all gradients.
    Node* n = ready_.front();
    ready_.pop_front();

    // "n" has num_x inputs and num_y outputs.
    const int num_x = n->num_inputs();
    const int num_y = n->num_outputs();

    auto iter = stop_nodes_.find(n->id());
    if (iter != stop_nodes_.end()) {
      // Stop backprop.
      // TODO(andydavis) Support stop nodes with more than one output.
      CHECK_EQ(1, num_y);
      continue;
    }

    // dy[i] is the sum of the i-th output's backpropped gradients.
    dy.clear();
    dy.resize(num_y, {nullptr, 0});
    for (int i = 0; i < num_y; ++i) {
      dy[i] = SumGradients({n, i});
    }

    if (IsPrimitiveOpWithNoGrad(n->type_string())) {
      // No grad defined for this op: Backprop zeros along the in edges.
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropZerosAlongEdge({e->src(), e->src_output()});
      }
      continue;
    }

    // Adds a gradient node with num_x + num_y inputs and num_x
    // outputs.
    // TODO(andydavis) Support primitive gradient ops.
    Node* grad = AddSymGrad(graph_, n, dy);
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      graph_->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
    }
    for (int i = 0; i < num_y; ++i) {
      graph_->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
    }

    // Backprops along the in edges.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
    }
  }

  for (int i = 0; i < x_node_outputs_.size(); ++i) {
    (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
  }

  return Status::OK();
}

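// Builds the symbolic gradient subgraph for the given y and x endpoints by
// delegating to SymbolicGradientBuilder.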
Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
                            gtl::ArraySlice<NodeOut> x_node_outputs,
                            gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                            std::vector<NodeOut>* x_grad_node_outputs,
                            Graph* graph) {
  SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs,
                                  y_grad_node_outputs, x_grad_node_outputs,
                                  graph);
  return builder.Compute();
}

}  // end namespace tensorflow