/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// TODO(andydavis) Remove some of the code duplicated between this module
// and that in 'common_runtime/function.cc'.
// A few string constants used throughout this module.
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";

string NodeOut::name() const {
  if (index == 0) {
    return node->name();
  } else {
    return strings::StrCat(node->name(), ":", index);
  }
}

DataType NodeOut::dtype() const { return node->output_type(index); }

struct NodeOutHash {
  uint64 operator()(const NodeOut& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct NodeOutEq {
  bool operator()(const NodeOut& x, const NodeOut& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

static Node* AddZerosLike(Graph* g, NodeOut input) {
  DCHECK_LT(0, input.dtype());
  DCHECK_LT(input.dtype(), DT_FLOAT_REF);
  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op("ZerosLike");
  ndef.add_input(input.name());
  AddNodeAttr("T", input.dtype(), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
  const int num_x = n->num_inputs();
  const int num_y = n->num_outputs();
  CHECK_EQ(num_y, grads.size());

  NodeDef ndef;
  ndef.set_name(g->NewName(kNodeLabel));
  ndef.set_op(kGradientOp);

  // The gradient node should have num_x + num_y inputs.
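  // The first num_x inputs are the forward node's data inputs, in the same
  // order as on 'n', and the remaining num_y inputs are the incoming
  // gradients, one per forward output.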
  std::vector<NodeOut> n_inputs(num_x);
  for (const Edge* e : n->in_edges()) {
    if (e->IsControlEdge()) continue;
    n_inputs[e->dst_input()] = {e->src(), e->src_output()};
  }
  DataTypeVector in_types;
  for (const NodeOut& nout : n_inputs) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
    in_types.push_back(nout.dtype());
  }
  CHECK_EQ(ndef.input_size(), num_x + num_y);

  AddNodeAttr("Tin", in_types, &ndef);

  // The gradient node's outputs have the same types as the node 'n's
  // inputs.
  AddNodeAttr("Tout", n->input_types(), &ndef);
  NameAttrList func;
  func.set_name(n->type_string());
  for (const auto& attr : n->attrs()) {
    (*func.mutable_attr())[attr.first] = attr.second;
  }
  AddNodeAttr("f", func, &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

class SymbolicGradientBuilder {
 public:
  SymbolicGradientBuilder(gtl::ArraySlice<NodeOut> y_node_outputs,
                          gtl::ArraySlice<NodeOut> x_node_outputs,
                          gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                          std::vector<NodeOut>* x_grad_node_outputs,
                          Graph* graph);

  Status Compute();

 private:
  gtl::ArraySlice<NodeOut> y_node_outputs_;
  gtl::ArraySlice<NodeOut> x_node_outputs_;
  gtl::ArraySlice<NodeOut> y_grad_node_outputs_;
  std::vector<NodeOut>* x_grad_node_outputs_;
  Graph* graph_;  // Not owned.

  // A vector of output endpoints which represents backpropagated
  // gradients.
  typedef std::vector<NodeOut> BackpropedGradients;

  // backprops_ is a map from a node output to its accumulated
  // gradients. When a node output has accumulated all its
  // gradients, we add a node which sums them up.
  std::unordered_map<NodeOut, BackpropedGradients, NodeOutHash, NodeOutEq>
      backprops_;

  // pending_[i] is a count-down counter for the i-th node's expected
  // backprops. When pending_[i] becomes zero, we have collected all
  // backprop gradients for all outputs of the i-th node.
  std::vector<int> pending_;

  // 'ready' keeps track of nodes that have been completely
  // backpropped. Initially, for every output y of the function f, we
  // add dy as an input of the gradient function.
  std::deque<Node*> ready_;

  // The set of node ids at which to stop backprop.
  std::unordered_set<int> stop_nodes_;

  // Initialize pending_ and ready_.
  void InitBackprop();

  // In the original function body, there is a forward edge from 'src'
  // to 'dst'. When the backprop algorithm constructs the node
  // 'dst_grad' which computes the gradient, we need to propagate it
  // to 'src'.
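  // BackpropZerosAlongEdge does the same pending-count bookkeeping but
  // records no gradient; it is used when the op producing 'dst' has no
  // registered gradient function.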
  void BackpropAlongEdge(const NodeOut& dst_grad, const NodeOut& src);
  void BackpropZerosAlongEdge(const NodeOut& src);

  NodeOut SumGradients(const NodeOut& src);

  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder);
};

SymbolicGradientBuilder::SymbolicGradientBuilder(
    gtl::ArraySlice<NodeOut> y_node_outputs,
    gtl::ArraySlice<NodeOut> x_node_outputs,
    gtl::ArraySlice<NodeOut> y_grad_node_outputs,
    std::vector<NodeOut>* x_grad_node_outputs, Graph* graph)
    : y_node_outputs_(y_node_outputs),
      x_node_outputs_(x_node_outputs),
      y_grad_node_outputs_(y_grad_node_outputs),
      x_grad_node_outputs_(x_grad_node_outputs),
      graph_(graph) {
  CHECK_EQ(y_node_outputs_.size(), y_grad_node_outputs.size());
  x_grad_node_outputs_->clear();
  x_grad_node_outputs_->resize(x_node_outputs_.size());
  stop_nodes_.reserve(x_node_outputs_.size());
  for (int i = 0; i < x_node_outputs_.size(); ++i) {
    stop_nodes_.insert(x_node_outputs_[i].node->id());
  }
}

void SymbolicGradientBuilder::BackpropAlongEdge(const NodeOut& dst_grad,
                                                const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    auto* grads = &iter->second;
    grads->push_back(dst_grad);
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

void SymbolicGradientBuilder::BackpropZerosAlongEdge(const NodeOut& src) {
  CHECK_NOTNULL(src.node);
  auto iter = backprops_.find(src);
  if (iter != backprops_.end()) {
    if (--pending_[src.node->id()] == 0) {
      ready_.push_back(src.node);
    }
  }
}

void SymbolicGradientBuilder::InitBackprop() {
  pending_.resize(graph_->num_node_ids(), 0);
  {
    backprops_.clear();
    std::unordered_set<Node*> visited;
    std::deque<Node*> queue;
    for (const NodeOut& nout : x_node_outputs_) {
      queue.push_back(nout.node);
      visited.insert(nout.node);
    }

    // Go forward to figure out which endpoints need to be backprop-ed.
    // A node's endpoints need to be backprop-ed only if one of the
    // arg nodes can reach the node via data edges.
    while (!queue.empty()) {
      Node* n = queue.front();
      queue.pop_front();
      for (int i = 0; i < n->num_outputs(); ++i) {
        backprops_[{n, i}].clear();
      }
      int num_expected_backprops = 0;
      for (const Edge* e : n->out_edges()) {
        if (e->IsControlEdge()) continue;
        ++num_expected_backprops;
        if (visited.find(e->dst()) == visited.end()) {
          queue.push_back(e->dst());
          visited.insert(e->dst());
        }
      }
      pending_[n->id()] = num_expected_backprops;
    }
  }

  {
    const int num_y = y_grad_node_outputs_.size();
    for (int i = 0; i < num_y; ++i) {
      Node* y = y_node_outputs_[i].node;
      for (const Edge* e : y->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropAlongEdge(y_grad_node_outputs_[i], {e->src(), e->src_output()});
      }
    }
  }
  CHECK(!ready_.empty());
}

NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) {
  const DataType dtype = src.dtype();
  auto iter = backprops_.find(src);
  CHECK(iter != backprops_.end());
  const auto& grads = iter->second;
  if (grads.empty()) {
    // Nothing propagated back. The best we can come up with is zeros.
    Node* zero_like = AddZerosLike(graph_, src);
    return {zero_like, 0};
  }
  if (grads.size() == 1) {
    // Just one backprop edge.
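    // No AddN node is needed; the single gradient endpoint is forwarded
    // directly.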
    return grads[0];
  }
  // Otherwise, add up the backprop-ed gradients.
  NodeDef ndef;
  ndef.set_name(graph_->NewName(kNodeLabel));
  ndef.set_op("AddN");  // N-way Add
  for (const NodeOut& nout : grads) {
    ndef.add_input(nout.name());
  }
  AddNodeAttr("N", static_cast<int64>(grads.size()), &ndef);
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* add = graph_->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  for (size_t i = 0; i < grads.size(); ++i) {
    const NodeOut& nout = grads[i];
    graph_->AddEdge(nout.node, nout.index, add, i);
  }
  return {add, 0};
}

static bool IsPrimitiveOpWithNoGrad(const string& func) {
  gradient::Creator creator;
  Status s = gradient::GetOpGradientCreator(func, &creator);
  return s.ok() && (creator == nullptr);
}

Status SymbolicGradientBuilder::Compute() {
  // Initialize backprops.
  InitBackprop();

  // Backward propagation.
  gtl::InlinedVector<NodeOut, 8> dy;
  while (!ready_.empty()) {
    // n has collected all gradients.
    Node* n = ready_.front();
    ready_.pop_front();

    // "n" has num_x inputs and num_y outputs.
    const int num_x = n->num_inputs();
    const int num_y = n->num_outputs();

    auto iter = stop_nodes_.find(n->id());
    if (iter != stop_nodes_.end()) {
      // Stop backprop.
      // TODO(andydavis) Support stop nodes with more than one output.
      CHECK_EQ(1, num_y);
      continue;
    }

    // dy[i] is the sum of the i-th output's backpropped gradients.
    dy.clear();
    dy.resize(num_y, {nullptr, 0});
    for (int i = 0; i < num_y; ++i) {
      dy[i] = SumGradients({n, i});
    }

    if (IsPrimitiveOpWithNoGrad(n->type_string())) {
      // No grad defined for this op: backprop zeros along the in edges.
      for (const Edge* e : n->in_edges()) {
        if (e->IsControlEdge()) continue;
        BackpropZerosAlongEdge({e->src(), e->src_output()});
      }
      continue;
    }

    // Adds a gradient node with num_x + num_y inputs and num_x
    // outputs.
    // TODO(andydavis) Support primitive gradient ops.
    Node* grad = AddSymGrad(graph_, n, dy);
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      graph_->AddEdge(e->src(), e->src_output(), grad, e->dst_input());
    }
    for (int i = 0; i < num_y; ++i) {
      graph_->AddEdge(dy[i].node, dy[i].index, grad, num_x + i);
    }

    // Backprop along the in edges.
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      BackpropAlongEdge({grad, e->dst_input()}, {e->src(), e->src_output()});
    }
  }

  for (int i = 0; i < x_node_outputs_.size(); ++i) {
    (*x_grad_node_outputs_)[i] = SumGradients(x_node_outputs_[i]);
  }

  return Status::OK();
}

Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
                            gtl::ArraySlice<NodeOut> x_node_outputs,
                            gtl::ArraySlice<NodeOut> y_grad_node_outputs,
                            std::vector<NodeOut>* x_grad_node_outputs,
                            Graph* graph) {
  SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs,
                                  y_grad_node_outputs, x_grad_node_outputs,
                                  graph);
  return builder.Compute();
}

}  // end namespace tensorflow
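// A minimal usage sketch (hypothetical caller; the variable names below are
// illustrative only): given a Graph* g that already contains the forward
// computation, collect the output, input, and incoming-gradient endpoints as
// NodeOut values and call:
//
//   std::vector<NodeOut> x_grads;
//   Status s = AddSymbolicGradients(y_outputs, x_outputs, y_grads,
//                                   &x_grads, g);
//
// On success, x_grads[i] is the endpoint in 'g' that computes the gradient
// with respect to x_outputs[i].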