Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
     17 #define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
     18 
     19 #include <functional>
     20 #include <memory>
     21 
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_mgr.h"
     24 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
     25 #include "tensorflow/core/framework/function.h"
     26 #include "tensorflow/core/graph/graph.h"
     27 #include "tensorflow/core/protobuf/config.pb.h"
     28 
     29 namespace tensorflow {
     30 
     31 static constexpr const char* const kNoInlineAttr = "_noinline";
     32 
     33 // Registers a default customizable kernel creator for a function call.
     34 //
     35 // If 'cb()' returns a non-OK, we still fall back to an executor-based
     36 // interpreter op kernel to execute a function. If 'cb()' returns OK,
     37 // takes ownership of the returned OpKernel.
     38 //
     39 // TODO(zhifengc/phawkins): b/32379046
     40 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
     41 
     42 // Creates a FunctionLibraryRuntime, which instantiates functions
     43 // defined in "lib_def" and executes functions on the "device".
     44 // "device_mgr" must contain the "device". If not nullptr,
     45 // "custom_kernel_creator" is consulted by the returned runtime to
     46 // create kernels.
     47 //
     48 // The returned object does not take ownerships of "device" or
     49 // "lib_def".  The caller must ensure "device" and "lib_def" outlives
     50 // the returned object.
     51 //
     52 // The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
     53 // typically owns the created FunctionLibraryRuntime object. The parent pointer
     54 // is not owned by the FunctionLibraryRuntime object.
     55 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     56     const DeviceMgr* device_mgr, Env* env, Device* device,
     57     int graph_def_version, const FunctionLibraryDefinition* lib_def,
     58     const OptimizerOptions& optimizer_options,
     59     CustomKernelCreator custom_kernel_creator,
     60     ProcessFunctionLibraryRuntime* parent);
     61 
     62 // Same as above except that the returned runtime consults with the
     63 // global default custom kernel creator registered by
     64 // RegisterDefaultCustomKernelCreator.
     65 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     66     const DeviceMgr* device_mgr, Env* env, Device* device,
     67     int graph_def_version, const FunctionLibraryDefinition* lib_def,
     68     const OptimizerOptions& optimizer_options,
     69     ProcessFunctionLibraryRuntime* parent);
     70 
     71 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
     72 // instantiated function that is represented as a Graph with arg/ret
     73 // nodes annotated.
     74 struct FunctionBody {
     75   FunctionDef fdef;
     76   Graph* graph = nullptr;  // owned.
     77   DataTypeVector arg_types;
     78   DataTypeVector ret_types;
     79   gtl::InlinedVector<Node*, 4> arg_nodes;
     80   gtl::InlinedVector<Node*, 4> ret_nodes;
     81 
     82   FunctionBody() {}
     83   FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
     84                DataTypeSlice ret_types, Graph* g);
     85   ~FunctionBody();
     86 };
     87 
     88 // Debugging facility.  Returns a debug string for a graph
     89 // representing an instantiated function.
     90 string DebugString(const Graph* instantiated_func_graph);
     91 
     92 // A few hand-crafted optimization on the instantiated function body
     93 // (a Graph*).
     94 
     95 // Removes nodes that are
     96 //   1. not stateful; and
     97 //   2. not _Arg; and
     98 //   3. not reachable from _Retval.
     99 // Returns true iff any node is removed from "g".
    100 bool RemoveDeadNodes(Graph* g);
    101 
    102 // Find a pattern:
    103 //   src -(in)-> node -(out)-> dst, where
    104 // 1) node is an identity node;
    105 // 2) in is the only incoming data edge;
    106 // 3) out is the only outgoing data edge;
    107 //
    108 // Rewrites the above pattern with src->dst and relevant data
    109 // dependencies updated. Repeat the process until no such pattern
    110 // left.
    111 bool RemoveIdentityNodes(Graph* g);
    112 
    113 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
    114 bool RemoveListArrayConverter(Graph* g);
    115 
    116 // For each node in "graph", if "lib" indicates that the node is a
    117 // function call, inline the function body.  Returns true if at least
    118 // one node is inlined.
    119 //
    120 // This routine goes through "graph" nodes once and applies the
    121 // inlining.  The caller may decide to apply the inlining on "graph"
    122 // multiple times by calling ExpandInlineFunctions a few times.
    123 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph);
    124 
    125 // Dump the contents of the "graph" to log files if the logging level is
    126 // sufficiently high.
    127 void DumpGraph(StringPiece label, const Graph* g);
    128 
    129 // Applies graph rewrite optimization such as inlining, dead code
    130 // removal, etc.
    131 //
    132 // **g is a graph constructed based on the runtime library 'lib'.
    133 // OptimizeGraph mutates **g extensively and replaces '*g' with a
    134 // complete copy. Therefore, the caller should not keep any references
    135 // to nodes *g.
    136 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
    137 
    138 // Convert the Graph of a function to a GraphDef.
    139 //
    140 // Handles renaming of nodes to avoid duplicate names which may
    141 // be present after various rewriting operations.
    142 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
    143 
    144 // Given a numerical function "f", returns another numerical function
    145 // "g", such that if "f" takes N inputs and produces M outputs, "g"
    146 // takes N + M inputs and produces N outputs. I.e., if
    147 //   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
    148 // g is a function which is
    149 //   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
    150 //                                     dL/dy1, dL/dy2, ..., dL/dy_M),
    151 // where L is a scalar-value function of (...x_i...).
    152 //
    153 // TODO(zhifengc): Asks math expert to say the comment again.
    154 FunctionBody* SymbolicGradient(const FunctionBody& f);
    155 
    156 // Given a "caller" in graph "g", which is a function call of a function
    157 // to "fbody". Replaces the "caller" with fbody->graph and connects
    158 // edges properly.
    159 void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
    160                         Node* caller, const FunctionBody* fbody);
    161 
    162 // Instantiates FunctionDef into a graph. Set *fbody to point to the
    163 // FunctionBody that holds the instantiated FunctionDef.
    164 Status FunctionDefToBodyHelper(
    165     const FunctionDef& fdef, const AttrSlice& attrs,
    166     const FunctionLibraryDefinition* const lib_def,
    167     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
    168     FunctionBody** fbody);
    169 }  // end namespace tensorflow
    170 
    171 #endif  // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
    172