1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // An optimization pass that groups nodes marked with a common 17 // kXlaClusterAttr into functions, and replaces the original nodes by 18 // calls. The calls are annotated with kXlaCompiledKernelAttr. 19 20 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 21 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 22 23 #include "tensorflow/core/common_runtime/optimization_registry.h" 24 #include "tensorflow/core/framework/function.h" 25 #include "tensorflow/core/graph/graph.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 30 // A rewriting function to apply to each subgraph during encapsulation. 31 // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; 32 // 'input_permutation' is a mapping from old argument numbers to new argument 33 // numbers, whereas 'output_permutation' is the same for outputs. Both 34 // 'input_permutation' and 'output_permutation' are initialized to the identity 35 // permutation. 'nodedef' is the NodeDef for the call to the function under 36 // construction, provided to allow additional attributes to be set. 37 // The rewrite may also change the NodeDef's operator name, and that 38 // name will be used as the name of the generated function. 39 typedef std::function<Status( 40 std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation, 41 std::vector<int>* output_permutation, NodeDef* node_def)> 42 RewriteSubgraphFn; 43 44 // Transformation that finds subgraphs whose nodes are marked with 45 // 'group_attribute', splits those subgraphs into functions, and replaces 46 // the originals with function calls. 47 // 48 // 'group_attribute' must be a string valued-attribute that names the new 49 // functions to introduce. 50 // 51 // 'outside_compilation_attribute' must be a string-valued attribute that is 52 // used to tag nodes within a subgraph to be part of an 'outside_compilation' 53 // cluster within the subgraph. A cluster is formed from the set of nodes with 54 // the same value of outside_compilation_subgraph and group_attribute. The nodes 55 // in an outside_compilation cluster are left in the original graph. Edges 56 // crossing from the subgraph to an outside_compilation cluster nested in the 57 // subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges 58 // crossing from an outside_compilation cluster into its enclosing subgraph are 59 // lifted into a SendFromHost/RecvFromHost pair of nodes. 60 // 61 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before 62 // function conversion. 63 // 64 // If 'parallel_checking' is true, the unencapsulated operators are added to the 65 // output graph, together with a "ParallelCheck" operator, that verifies that 66 // the original and encapsulated subgraphs produce similar results. 67 // 68 // If 'reuse_existing_functions' is set, use an existing function with the 69 // same name, if any. 70 // 71 // TODO(phawkins): currently, some information in control edges 72 // is not preserved. Suppose you have A and B in the main 73 // graph, C and D in a subgraph. B and C have control deps from A, D has control 74 // dep from B. Originally D must run after C, post-transformation this 75 // dependency is lost. 76 Status EncapsulateSubgraphsInFunctions( 77 string group_attribute, string outside_compilation_attribute, 78 const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, 79 bool parallel_checking, bool reuse_existing_functions, 80 std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library); 81 82 // The attribute that marks function calls produced by the encapsulate 83 // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. 84 extern const char* const kXlaCompiledKernelAttr; 85 86 // Does `node` have the kXlaCompiledKernelAttr attribute? 87 bool IsXlaCompiledKernel(const Node& node); 88 89 // Functions produced by the EncapsulateSubgraphs pass have their arguments in 90 // the order: 91 // 1) compile-time constant arguments, in host memory, 92 // 2) other arguments, in device memory. 93 // 3) resource variable arguments, in host memory. Note that only the resource 94 // Tensor itself is in host memory; the underlying value may be in device 95 // memory. 96 // The functions are annotated with the following attributes that describe how 97 // many constant and resource arguments there are: 98 99 // Name of the attribute containing the number of constant arguments. 100 extern const char* const kXlaNumConstantArgsAttr; 101 102 // Name of the attribute containing the number of resource variable arguments. 103 extern const char* const kXlaNumResourceArgsAttr; 104 105 class EncapsulateSubgraphsPass : public GraphOptimizationPass { 106 public: 107 Status Run(const GraphOptimizationPassOptions& options) override; 108 }; 109 110 } // namespace tensorflow 111 112 #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ 113