1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 17 #define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 18 19 #include "absl/types/optional.h" 20 #include "tensorflow/compiler/jit/encapsulate_util.h" 21 #include "tensorflow/compiler/xla/status_macros.h" 22 #include "tensorflow/core/graph/graph.h" 23 24 namespace tensorflow { 25 26 // Rewrite function for outside compilation subgraphs. It will perform the 27 // following steps: 28 // 29 // 1. Add a XLA computation key placeholder node (it will be used as input for 30 // XlaRecvAtHost and XlaSendFromHost); 31 // 2. Replace all _Arg nodes with one single XlaRecvAtHost node; 32 // 3. Replace all _Retval nodes with one single XlaSendFromHost node; 33 // 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name` 34 // and `outside_compilation_attr_name`; 35 // 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a 36 // control edge from the node to XlaSendFromHost; for nodes marked with attr 37 // kXlaConnectedFromXlaComputationAttrName, add a control edge from 38 // XlaRecvAtHost node to the node; 39 // 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node. 40 // 7. Add necessary attributes to `node_def`, so we can replace it with a 41 // XlaHostCompute node later. If all input shapes for XlaSendFromHost are 42 // known, "shapes" attr will be set to the list of input shapes; otherwise 43 // "shape_inference_graph" attr will be set to shape inference function name. 44 class RewriteOutsideCompilationSubgraphFn { 45 public: 46 RewriteOutsideCompilationSubgraphFn( 47 const string& xla_cluster_attr_name, 48 const string& outside_compilation_attr_name, 49 const string& xla_cluster_name) 50 : xla_cluster_attr_name_(xla_cluster_attr_name), 51 outside_compilation_attr_name_(outside_compilation_attr_name), 52 xla_cluster_name_(xla_cluster_name) {} 53 54 Status operator()(const std::vector<OutputTensor>&, 55 std::unique_ptr<Graph>* graph, 56 std::vector<int>* input_permutation, 57 std::vector<int>* output_permutation, NodeDef* node_def); 58 59 private: 60 string xla_cluster_attr_name_; 61 string outside_compilation_attr_name_; 62 string xla_cluster_name_; 63 }; 64 65 // For an XLA computation function, replace all outside compilations with 66 // XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by 67 // `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one 68 // single host side graph (`host_graph`). 69 // 70 // xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA 71 // computation and outside compilation. Required for 72 // `RewriteOutsideCompilationSubgraphFn`. 73 // xla_cluster_name: XLA cluster name for this XLA computation. We need it 74 // because XLA cluster name might be different from `func_name`. 75 // func_name_attrs: they will be used to instantiate the XLA computation func. 76 // new_func_name: new function name for rewritten XLA computation func. 77 // host_compute_core: mapping from outside compilation cluster name to XLA 78 // device assignment. 79 // fld: FunctionLibraryDefinition object. 80 // host_graph: Graph object to store host side graph for all outside 81 // compilations within this XLA computation func. If there is no outside 82 // compilation, it will be empty. 83 // shape_inference_graphs: a list of outside compilation shape inference 84 // function names. These functions need to be rewritten later. 85 // has_outside_compilation: a bool indicating whether this function has any 86 // outside compilation nodes. 87 Status ExtractOutsideCompilationForFunction( 88 const string& xla_cluster_attr_name, 89 const string& outside_compilation_attr_name, const string& xla_cluster_name, 90 const NameAttrList& func_name_attrs, const string& new_func_name, 91 const string& host_graph_func_name, 92 const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr, 93 FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs, 94 bool* has_outside_compilation); 95 96 // Rewrites XLA computation in `clusters` to replace outside compilation nodes 97 // with XlaHostCompute, and moves those outside compilations into `g`. If shapes 98 // of outside compilation outputs cannot be determined now, we will store shape 99 // inference graph into `fld`. 100 Status ExtractOutsideCompilation( 101 const string& xla_cluster_attr_name, 102 const string& outside_compilation_attr_name, 103 const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g, 104 FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); 105 106 } // namespace tensorflow 107 108 #endif // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ 109