Home | History | Annotate | Download | only in common_runtime
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
     17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
     18 
     19 #include <functional>
     20 #include <memory>
     21 
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_mgr.h"
     24 #include "tensorflow/core/common_runtime/graph_optimizer.h"
     25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
     26 #include "tensorflow/core/framework/function.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 #include "tensorflow/core/protobuf/config.pb.h"
     29 
     30 namespace tensorflow {
     31 
     32 static constexpr const char* const kNoInlineAttr = "_noinline";  // Name of the '_noinline' function attribute.
     33 
     34 // Registers 'cb' as the default customizable kernel creator for function calls.
     35 //
     36 // If 'cb()' returns a non-OK status, execution still falls back to an
     37 // executor-based interpreter op kernel to execute the function. If 'cb()'
     38 // returns OK, ownership of the returned OpKernel is taken over.
     39 //
     40 // TODO(zhifengc/phawkins): b/32379046
     41 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
     42 
     43 // Creates a FunctionLibraryRuntime, which instantiates functions
     44 // defined in "lib_def" and executes functions on the "device".
     45 // "device_mgr" must contain the "device". If not nullptr,
     46 // "custom_kernel_creator" is consulted by the returned runtime to
     47 // create kernels.
     48 //
     49 // The returned object does not take ownership of "device" or
     50 // "lib_def".  The caller must ensure "device" and "lib_def" outlive
     51 // the returned object.
     52 //
     53 // The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
     54 // typically owns the created FunctionLibraryRuntime object. The parent pointer
     55 // is not owned by the FunctionLibraryRuntime object.
     56 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     57     const DeviceMgr* device_mgr, Env* env, Device* device,
     58     int graph_def_version, const FunctionLibraryDefinition* lib_def,
     59     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
     60     CustomKernelCreator custom_kernel_creator,
     61     ProcessFunctionLibraryRuntime* parent);
     62 
     63 // Same as the overload taking "custom_kernel_creator", except that the
     64 // returned runtime consults the global default custom kernel creator
     65 // registered by RegisterDefaultCustomKernelCreator.
     66 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
     67     const DeviceMgr* device_mgr, Env* env, Device* device,
     68     int graph_def_version, const FunctionLibraryDefinition* lib_def,
     69     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
     70     ProcessFunctionLibraryRuntime* parent);
     71 
     72 // Description of an instantiated function, as returned by
     73 // FunctionLibraryRuntime::GetFunctionBody: the function is represented as a
     74 // Graph with its arg/ret nodes annotated.
     75 struct FunctionBody {
     76   FunctionDef fdef;  // Presumably the definition this body was built from.
     77   Graph* graph = nullptr;  // owned.
     78   DataTypeVector arg_types;  // Presumably the types of the function arguments.
     79   DataTypeVector ret_types;  // Presumably the types of the function returns.
     80   gtl::InlinedVector<Node*, 4> arg_nodes;  // Annotated arg nodes.
     81   gtl::InlinedVector<Node*, 4> ret_nodes;  // Annotated ret nodes.
     82   gtl::InlinedVector<Node*, 4> control_ret_nodes;  // Presumably the control return nodes.
     83 
     84   FunctionBody() {}  // Leaves "graph" null; other members are default-constructed.
     85   FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
     86                DataTypeSlice ret_types, Graph* g);
     87   ~FunctionBody();  // NOTE(review): "graph" is owned but no copy ctor/assignment is declared; confirm FunctionBody is never copied.
     88 };
     89 
     90 // Debugging facility.  Returns a debug string for "instantiated_func_graph",
     91 // a graph representing an instantiated function.
     92 string DebugString(const Graph* instantiated_func_graph);
     93 
     94 // A few hand-crafted optimizations on the instantiated function body
     95 // (a Graph*).
     96 
     97 // Removes nodes that are
     98 //   1. not stateful; and
     99 //   2. not _Arg; and
    100 //   3. not reachable from _Retval.
    101 //
    102 // This function is triggered by function inlining. Unlike 'PruneFunctionBody',
    103 // it doesn't preserve nodes that are reachable from control returns. Function
    104 // inlining is responsible for connecting control return nodes with the nodes
    105 // that have input control edges from the inlined function call node.
    106 //
    107 // Assuming that automatic control dependency tracking is correct, the absence
    108 // of an outgoing control edge from the function call node means that no one
    109 // needs to observe side effects that might have been generated by the function
    110 // (see documentation in common_runtime/function.cc for details).
    111 //
    112 // Returns true iff any node is removed from "g".
    113 bool RemoveDeadNodes(Graph* g);
    114 
    115 // Finds the pattern:
    116 //   src -(in)-> node -(out)-> dst, where
    117 // 1) node is an identity node;
    118 // 2) in is the only incoming data edge;
    119 // 3) out is the only outgoing data edge.
    120 //
    121 // Rewrites the above pattern with src->dst and relevant data
    122 // dependencies updated. Repeats the process until no such pattern is
    123 // left. NOTE(review): the meaning of the bool result is not stated here.
    124 bool RemoveIdentityNodes(Graph* g);
    125 
    126 // Rewrites _ListToArray and _ArrayToList nodes in "g" to sets of Identity nodes.
    127 bool RemoveListArrayConverter(Graph* g);
    128 
    129 // Dumps the contents of the graph "g" to log files if the logging level is
    130 // sufficiently high. NOTE(review): the role of "label" is not documented here.
    131 void DumpGraph(StringPiece label, const Graph* g);
    132 
    133 // Applies graph rewrite optimizations such as inlining, dead code
    134 // removal, etc.
    135 //
    136 // **g is a graph constructed based on the runtime library 'lib'.
    137 // OptimizeGraph mutates **g extensively and replaces '*g' with a
    138 // complete copy. Therefore, the caller should not keep any references
    139 // to nodes of **g.
    140 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
    141                    const GraphOptimizer::Options& graph_optimizer_options);
    142 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);  // Overload without explicit optimizer options.
    143 
    144 // Converts the Graph "g" of a function to the GraphDef "*gdef".
    145 //
    146 // Handles renaming of nodes to avoid duplicate names, which may be present
    147 // after various rewriting operations. NOTE(review): "pretty" is undocumented.
    148 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
    149 
    150 // Given a numerical function "f", returns another numerical function
    151 // "g", such that if "f" takes N inputs and produces M outputs, "g"
    152 // takes N + M inputs and produces N outputs. I.e., if
    153 //   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
    154 // g is a function which is
    155 //   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
    156 //                                     dL/dy1, dL/dy2, ..., dL/dy_M),
    157 // where L is a scalar-valued function of (...x_i...).
    158 // NOTE(review): ownership of the returned pointer is not documented here.
    159 // TODO(zhifengc): Ask a math expert to restate this comment.
    160 FunctionBody* SymbolicGradient(const FunctionBody& f);
    161 
    162 struct InlineFunctionBodyOptions {  // Options taken by ValidateInlining and InlineFunctionBody.
    163   // All nodes that have an incoming control edge *from* the function call node
    164   // will be forwarded to the "output control node". There are two options for
    165   // choosing which nodes will have a control edge *to* the "output control
    166   // node":
    167   //   a) control returns (`control_ret` field in FunctionDef): kControlOutputs
    168   //   b) data returns    (`ret` field in FunctionDef):         kDataOutputs
    169   enum class OutputControlSource { kDataOutputs, kControlOutputs };
    170 
    171   // If 'true', the '_noinline' function attribute is ignored.
    172   bool ignore_noinline = false;
    173   // If 'true', function inlining will override explicitly specified devices
    174   // inside the function body with the caller node device.
    175   bool override_device = false;
    176   // For compatibility with Tensorflow v1, data outputs are used by default.
    177   // Control returns were added to Tensorflow v2 with automatic control
    178   // dependencies tracking in Eager mode.
    179   OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
    180 
    181   // A human-readable debug string for these options.
    182   string DebugString() const;
    183 };
    184 
    185 // Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node',
    186 // based on the type signature of 'node' and 'fbody' and on 'options':
    187 //
    188 // (1) The caller node has the same number of inputs and outputs as the function.
    189 // (2) The caller node's inputs and outputs have the same data types as the
    190 //     function's inputs and returns.
    191 // (3) The validation rules defined in InlineFunctionBodyOptions are satisfied.
    192 //
    193 // If the function can't be safely inlined, returns an error status with details
    194 // on why inlining is not possible or safe.
    195 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
    196                         const InlineFunctionBodyOptions& options);
    197 
    198 // Given a "caller" node in graph "g" that is a call of the function described
    199 // by "fbody", replaces the "caller" with fbody->graph and connects edges
    200 // properly. "options.override_device" specifies whether inlining should replace
    201 // explicitly specified devices inside fbody with the caller node's device.
    202 //
    203 // Returns 'Status::OK()' if the function was successfully inlined into the
    204 // graph. If function inlining is not possible, returns an error with a reason
    205 // and leaves the graph in an unmodified state.
    206 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
    207                           Node* caller, const FunctionBody* fbody,
    208                           const InlineFunctionBodyOptions& options);
    209 
    210 // There are three types of function calls that could be invoked during
    211 // *Tensorflow graph execution*:
    212 //
    213 // 1) Native function call (node.type_string() is the function name). These
    214 //    functions are always executed on a single device, which is the device of
    215 //    the function call node.
    216 //
    217 // 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
    218 //    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
    219 //    belong to different devices. This type of function was added in
    220 //    Tensorflow 2.0 Eager mode, and it has control outputs to represent
    221 //    side-effects that must always execute (see `control_ret` in FunctionDef).
    222 //
    223 // 3) SymbolicGradient has been deprecated for a while, but we still keep it and
    224 //    use `native` options for inlining for compatibility.
    225 //
    226 // We need to have distinct inlining rules for compatibility with Tensorflow v1,
    227 // hence the separate `native_options` and `multi_device_options` below.
    228 // There are a few other places in Tensorflow that could execute functions:
    229 //
    230 // 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
    231 //    functions directly via function library runtime, without going through
    232 //    the graph.
    233 // 2) tf.data pipelines - also execute functions directly via function library
    234 //    runtime with custom executors.
    235 struct ExpandInlineFunctionsOptions {
    236   ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
    237     using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
    238     multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
    239   }
    240 
    241   InlineFunctionBodyOptions native_options;  // Keeps the InlineFunctionBodyOptions defaults.
    242   InlineFunctionBodyOptions multi_device_options;  // Defaults, except control outputs are the output control source.
    243 };
    244 
    245 // WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
    246 // workaround that will be enabled only during the function inlining unification
    247 // (b/126811947). Contact ezhulenev@ if you think you need it.
    248 // TODO(ezhulenev): Delete this function (the overload taking explicit options).
    249 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
    250                            const ExpandInlineFunctionsOptions& options);
    251 
    252 // For each node in "graph", if "lib" indicates that the node is a
    253 // function call, inlines the function body. Returns true if at least
    254 // one node is inlined.
    255 //
    256 // This routine goes through "graph" nodes once and applies the
    257 // inlining. The caller may decide to apply the inlining on "graph"
    258 // multiple times by calling ExpandInlineFunctions a few times.
    259 // This overload uses default-constructed ExpandInlineFunctionsOptions.
    260 // Function calls that can't be safely inlined into the graph (ValidateInlining
    261 // returns an error) are ignored.
    262 //
    263 // TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need just
    264 // the FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this
    265 // (see lower_function_call.cc).
    266 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
    267   return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
    268 }
    269 
    270 // Extracts the function name and attributes from `call_def` and invokes
    271 // flr.Instantiate(name, attrs, handle).
    272 // `call_def` can be a native function call (where the op type is the function
    273 // name) or a call through PartitionedCall/StatefulPartitionedCall.
    274 Status InstantiateFunctionCall(const NodeDef& call_def,
    275                                FunctionLibraryRuntime& flr,
    276                                FunctionLibraryRuntime::Handle* handle);
    277 
    278 // Returns true iff `n` represents a function call. `n` can be a native
    279 // function call (n.type_string() is the function name; presumably looked up
    280 // in `lib_def` -- TODO confirm), a PartitionedCall/StatefulPartitionedCall, or
    281 // a SymbolicGradient (which has been deprecated for a while).
    282 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
    283 
    284 // Instantiates "fdef" into a graph and sets *fbody to point to the FunctionBody
    285 // holding the result. NOTE(review): ownership of *fbody is not stated here.
    286 Status FunctionDefToBodyHelper(
    287     const FunctionDef& fdef, const AttrSlice& attrs,
    288     const FunctionLibraryDefinition* const lib_def,
    289     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
    290     FunctionBody** fbody);
    291 }  // end namespace tensorflow
    292 
    293 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
    294