1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "tensorflow/core/common_runtime/device.h" 23 #include "tensorflow/core/common_runtime/device_mgr.h" 24 #include "tensorflow/core/common_runtime/graph_optimizer.h" 25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/graph/graph.h" 28 #include "tensorflow/core/protobuf/config.pb.h" 29 30 namespace tensorflow { 31 32 static constexpr const char* const kNoInlineAttr = "_noinline"; 33 34 // Registers a default customizable kernel creator for a function call. 35 // 36 // If 'cb()' returns a non-OK, we still fall back to an executor-based 37 // interpreter op kernel to execute a function. If 'cb()' returns OK, 38 // takes ownership of the returned OpKernel. 39 // 40 // TODO(zhifengc/phawkins): b/32379046 41 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb); 42 43 // Creates a FunctionLibraryRuntime, which instantiates functions 44 // defined in "lib_def" and executes functions on the "device". 45 // "device_mgr" must contain the "device". If not nullptr, 46 // "custom_kernel_creator" is consulted by the returned runtime to 47 // create kernels. 48 // 49 // The returned object does not take ownerships of "device" or 50 // "lib_def". The caller must ensure "device" and "lib_def" outlives 51 // the returned object. 52 // 53 // The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that 54 // typically owns the created FunctionLibraryRuntime object. The parent pointer 55 // is not owned by the FunctionLibraryRuntime object. 56 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( 57 const DeviceMgr* device_mgr, Env* env, Device* device, 58 int graph_def_version, const FunctionLibraryDefinition* lib_def, 59 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, 60 CustomKernelCreator custom_kernel_creator, 61 ProcessFunctionLibraryRuntime* parent); 62 63 // Same as above except that the returned runtime consults with the 64 // global default custom kernel creator registered by 65 // RegisterDefaultCustomKernelCreator. 66 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( 67 const DeviceMgr* device_mgr, Env* env, Device* device, 68 int graph_def_version, const FunctionLibraryDefinition* lib_def, 69 thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options, 70 ProcessFunctionLibraryRuntime* parent); 71 72 // FunctionLibraryRuntime::GetFunctionBody returns a description of an 73 // instantiated function that is represented as a Graph with arg/ret 74 // nodes annotated. 75 struct FunctionBody { 76 FunctionDef fdef; 77 Graph* graph = nullptr; // owned. 78 DataTypeVector arg_types; 79 DataTypeVector ret_types; 80 gtl::InlinedVector<Node*, 4> arg_nodes; 81 gtl::InlinedVector<Node*, 4> ret_nodes; 82 gtl::InlinedVector<Node*, 4> control_ret_nodes; 83 84 FunctionBody() {} 85 FunctionBody(const FunctionDef& f, DataTypeSlice arg_types, 86 DataTypeSlice ret_types, Graph* g); 87 ~FunctionBody(); 88 }; 89 90 // Debugging facility. Returns a debug string for a graph 91 // representing an instantiated function. 92 string DebugString(const Graph* instantiated_func_graph); 93 94 // A few hand-crafted optimization on the instantiated function body 95 // (a Graph*). 96 97 // Removes nodes that are 98 // 1. not stateful; and 99 // 2. not _Arg; and 100 // 3. not reachable from _Retval. 101 // 102 // This function is triggered by function inlining, unlike 'PruneFunctionBody' 103 // it doesn't preserve nodes that are reachable from control returns. Function 104 // inlining is responsible for connecting control return nodes with the nodes 105 // that have input control edges from the inlined function call node. 106 // 107 // Assuming that automatic control dependency tracking is correct, absence of 108 // outgoing control edge from the function call node means that no one needs to 109 // observe side-effect that might have been generated by the function (see 110 // documentation in common_runtime/function.cc for details). 111 // 112 // Returns true iff any node is removed from "g". 113 bool RemoveDeadNodes(Graph* g); 114 115 // Find a pattern: 116 // src -(in)-> node -(out)-> dst, where 117 // 1) node is an identity node; 118 // 2) in is the only incoming data edge; 119 // 3) out is the only outgoing data edge; 120 // 121 // Rewrites the above pattern with src->dst and relevant data 122 // dependencies updated. Repeat the process until no such pattern 123 // left. 124 bool RemoveIdentityNodes(Graph* g); 125 126 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. 127 bool RemoveListArrayConverter(Graph* g); 128 129 // Dump the contents of the "graph" to log files if the logging level is 130 // sufficiently high. 131 void DumpGraph(StringPiece label, const Graph* g); 132 133 // Applies graph rewrite optimization such as inlining, dead code 134 // removal, etc. 135 // 136 // **g is a graph constructed based on the runtime library 'lib'. 137 // OptimizeGraph mutates **g extensively and replaces '*g' with a 138 // complete copy. Therefore, the caller should not keep any references 139 // to nodes *g. 140 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, 141 const GraphOptimizer::Options& graph_optimizer_options); 142 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g); 143 144 // Convert the Graph of a function to a GraphDef. 145 // 146 // Handles renaming of nodes to avoid duplicate names which may 147 // be present after various rewriting operations. 148 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); 149 150 // Given a numerical function "f", returns another numerical function 151 // "g", such that if "f" takes N inputs and produces M outputs, "g" 152 // takes N + M inputs and produces N outputs. I.e., if 153 // (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), 154 // g is a function which is 155 // (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, 156 // dL/dy1, dL/dy2, ..., dL/dy_M), 157 // where L is a scalar-value function of (...x_i...). 158 // 159 // TODO(zhifengc): Asks math expert to say the comment again. 160 FunctionBody* SymbolicGradient(const FunctionBody& f); 161 162 struct InlineFunctionBodyOptions { 163 // All nodes that have incoming control edge *from* the function call node, 164 // will be forwarded to the "output control node". There are two options for 165 // choosing which nodes will have a control edge *to* the "output control 166 // node": 167 // a) control returns (`control_ret` field in FunctionDef) 168 // b) data returns (`ret` field in FunctionDef) 169 enum class OutputControlSource { kDataOutputs, kControlOutputs }; 170 171 // Ignore '_noinline' function attribute. 172 bool ignore_noinline = false; 173 // If 'true' function inlining will override explicitly specified devices 174 // inside function body with the caller node device. 175 bool override_device = false; 176 // For compatibility with Tensorflow v1 by default we will use data outputs. 177 // Control returns were added to Tensorflow v2 with automatic control 178 // dependencies tracking in Eager mode. 179 OutputControlSource output_control_src = OutputControlSource::kDataOutputs; 180 181 // A human-readable debug string for this options. 182 string DebugString() const; 183 }; 184 185 // Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node' 186 // based on the type signature of 'node' and 'fbody': 187 // 188 // (1) Caller node has the same number of inputs and outputs as the function. 189 // (2) Caller node inputs and outputs have the same data types as function 190 // inputs and returns. 191 // (3) Validation rules defined in InlineFunctionBodyOptions. 192 // 193 // If function can't be safely inlined, returns error message with details why 194 // inlining is not possible or safe. 195 Status ValidateInlining(const Node* node, const FunctionBody* fbody, 196 const InlineFunctionBodyOptions& options); 197 198 // Given a "caller" in graph "g", which is a function call of a function 199 // to "fbody". Replaces the "caller" with fbody->graph and connects 200 // edges properly. "override_device" specifies whether inlining should replace 201 // explicitly specified devices inside fbody with the callee's device. 202 // 203 // Returns 'Status::OK()' if function was successfully inlined into the graph. 204 // If function inlining is not possible returns a error with a reason, and 205 // leaves the graph in unmodified state. 206 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, 207 Node* caller, const FunctionBody* fbody, 208 const InlineFunctionBodyOptions& options); 209 210 // There are three types of function calls that could be invoked during 211 // *Tensorflow graph execution*: 212 // 213 // 1) Native function call (node.type_string() is the function name). These 214 // functions are always executed on a single-device, which is the device of 215 // the function call node. 216 // 217 // 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall 218 // ops) can execute on multiple devices and accept DT_RESOURCE inputs that 219 // belong to different devices. This type of functions was added in 220 // Tensorflow 2.0 Eager mode, and it has control outputs to represent 221 // side-effects that must always execute (see `control_ret` in FunctionDef). 222 // 223 // 3) SymbolicGradient has been deprecated for a while, but we still keep it and 224 // use `native` options for inlining for compatibility. 225 // 226 // We need to have distinct inlining rules for compatibility with Tensorflow v1. 227 // 228 // There are few other places in Tensorflow that could execute functions: 229 // 230 // 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level" 231 // functions directly via function library runtime, without going through 232 // the graph. 233 // 2) tf.data pipelines - also execute functions directly via function library 234 // runtime with custom executors. 235 struct ExpandInlineFunctionsOptions { 236 ExpandInlineFunctionsOptions() : native_options(), multi_device_options() { 237 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; 238 multi_device_options.output_control_src = OutputControlSrc::kControlOutputs; 239 } 240 241 InlineFunctionBodyOptions native_options; 242 InlineFunctionBodyOptions multi_device_options; 243 }; 244 245 // WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary 246 // workaround that will be enabled only during the function inlining unification 247 // (b/126811947). Contact ezhulenev@ if you think you need it. 248 // TODO(ezhulenev): Delete this function. 249 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, 250 const ExpandInlineFunctionsOptions& options); 251 252 // For each node in "graph", if "lib" indicates that the node is a 253 // function call, inline the function body. Returns true if at least 254 // one node is inlined. 255 // 256 // This routine goes through "graph" nodes once and applies the 257 // inlining. The caller may decide to apply the inlining on "graph" 258 // multiple times by calling ExpandInlineFunctions a few times. 259 // 260 // Function calls that can't be safely inlined into the graph (ValidateInlining 261 // returns error), are ignored. 262 // 263 // TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the 264 // FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see 265 // lower_function_call.cc). 266 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { 267 return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions()); 268 } 269 270 // Extracts function name and attributes from `call_def` and invokes 271 // flr->Instantiate(name, attrs, handle). 272 // `call_def` can be a native function call (where the op type is the function 273 // name) or a call through PartitionedCall/StatefulPartitionedCall. 274 Status InstantiateFunctionCall(const NodeDef& call_def, 275 FunctionLibraryRuntime& flr, 276 FunctionLibraryRuntime::Handle* handle); 277 278 // Returns true iff `n` represents a function call. `n` can be a native 279 // function call (n.type_string() is the function name), 280 // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which 281 // has been deprecated for a while). 282 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); 283 284 // Instantiates FunctionDef into a graph. Set *fbody to point to the 285 // FunctionBody that holds the instantiated FunctionDef. 286 Status FunctionDefToBodyHelper( 287 const FunctionDef& fdef, const AttrSlice& attrs, 288 const FunctionLibraryDefinition* const lib_def, 289 const std::function<Status(const string&, const OpDef**)>& get_func_sig, 290 FunctionBody** fbody); 291 } // end namespace tensorflow 292 293 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ 294