Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
     17 #define TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
     18 
#include <vector>

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
     22 
     23 namespace tensorflow {
     24 
     25 // Represents the output of 'node' at 'index'.
     26 struct NodeOut {
     27   Node* node;
     28   int index;
     29 
     30   // Returns the string name that represents the output of this node.
     31   string name() const;
     32   // Returns the data type of the output of this node.
     33   DataType dtype() const;
     34 };
     35 
     36 // NOTE: This API is a work in progress and will likely be changing frequently.
     37 //
     38 // Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the
     39 // symbolic partial derivatives of some loss function 'L' w.r.t the node outputs
     40 // 'y_node_outputs'), adds gradient nodes to 'graph' that compute the symbolic
     41 // partial derivatives of 'L' w.r.t the node outputs 'x_node_outputs'.
     42 //
     43 // REQUIRES: Each node in 'x_node_outputs' to be unique, and so to have a single
     44 // output (this restriction will be removed in a subsequent change).
     45 
     46 // TODO(andydavis) Add symbolic gradient support for general graphs (the current
     47 // implementation only supports gradients for functions). In particular,
     48 // the nodes in 'x_nodes' are currently restricted to have one output.
     49 
     50 Status AddSymbolicGradients(gtl::ArraySlice<NodeOut> y_node_outputs,
     51                             gtl::ArraySlice<NodeOut> x_node_outputs,
     52                             gtl::ArraySlice<NodeOut> y_grad_node_outputs,
     53                             std::vector<NodeOut>* x_grad_node_outputs,
     54                             Graph* graph);
     55 
     56 }  // namespace tensorflow
     57 
     58 #endif  // TENSORFLOW_CORE_GRAPH_GRADIENTS_H_
     59