Home | History | Annotate | Download | only in jit
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
     17 #define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
     18 
     19 #include "absl/types/optional.h"
     20 #include "tensorflow/compiler/jit/encapsulate_util.h"
     21 #include "tensorflow/compiler/xla/status_macros.h"
     22 #include "tensorflow/core/graph/graph.h"
     23 
     24 namespace tensorflow {
     25 
     26 // Rewrite function for outside compilation subgraphs. It will perform the
     27 // following steps:
     28 //
     29 // 1. Add an XLA computation key placeholder node (it will be used as input
     30 //    for XlaRecvAtHost and XlaSendFromHost);
     31 // 2. Replace all _Arg nodes with one single XlaRecvAtHost node;
     32 // 3. Replace all _Retval nodes with one single XlaSendFromHost node;
     33 // 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name`
     34 //    and `outside_compilation_attr_name`;
     35 // 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a
     36 //    control edge from the node to XlaSendFromHost; for nodes marked with attr
     37 //    kXlaConnectedFromXlaComputationAttrName, add a control edge from
     38 //    XlaRecvAtHost node to the node;
     39 // 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node.
     40 // 7. Add necessary attributes to `node_def`, so we can replace it with a
     41 //    XlaHostCompute node later. If all input shapes for XlaSendFromHost are
     42 //    known, "shapes" attr will be set to the list of input shapes; otherwise
     43 //    "shape_inference_graph" attr will be set to shape inference function name.
        //
        // The object is a functor: it is configured once through the constructor and
        // then invoked through operator().
     44 class RewriteOutsideCompilationSubgraphFn {
     45  public:
          // Stores a copy of each argument in the member of the same name; no other
          // work is done here.
          //   xla_cluster_attr_name, outside_compilation_attr_name: names of the
          //     attrs that step 4 above marks nodes with.
          //   xla_cluster_name: kept in `xla_cluster_name_`. NOTE(review): presumably
          //     the value used for the `xla_cluster_attr_name` attr -- confirm in the
          //     definition of operator().
     46   RewriteOutsideCompilationSubgraphFn(
     47       const string& xla_cluster_attr_name,
     48       const string& outside_compilation_attr_name,
     49       const string& xla_cluster_name)
     50       : xla_cluster_attr_name_(xla_cluster_attr_name),
     51         outside_compilation_attr_name_(outside_compilation_attr_name),
     52         xla_cluster_name_(xla_cluster_name) {}
     53 
          // Runs the rewrite described in the class comment. Only declared here; the
          // definition lives outside this header.
          //   (unnamed first parameter): a vector of OutputTensor. Its role cannot be
          //     determined from this declaration.
          //   graph: NOTE(review): presumably the outside compilation subgraph to be
          //     rewritten -- confirm in the definition.
          //   input_permutation, output_permutation: pointers to int vectors; their
          //     meaning is not visible in this header.
          //   node_def: receives the attributes described in step 7.
          // Returns a Status.
     54   Status operator()(const std::vector<OutputTensor>&,
     55                     std::unique_ptr<Graph>* graph,
     56                     std::vector<int>* input_permutation,
     57                     std::vector<int>* output_permutation, NodeDef* node_def);
     58 
     59  private:
          // Copies of the constructor arguments, declared in the same order as the
          // constructor initializes them.
     60   string xla_cluster_attr_name_;
     61   string outside_compilation_attr_name_;
     62   string xla_cluster_name_;
     63 };
     64 
     65 // For an XLA computation function, replace all outside compilations with
     66 // XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by
     67 // `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one
     68 // single host side graph (see `host_graph_func_name` below).
     69 //
     70 // xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA
     71 //   computation and outside compilation. Required for
     72 //   `RewriteOutsideCompilationSubgraphFn`.
     73 // xla_cluster_name: XLA cluster name for this XLA computation. We need it
     74 //   because XLA cluster name might be different from `func_name`.
     75 // func_name_attrs: they will be used to instantiate the XLA computation func.
     76 // new_func_name: new function name for rewritten XLA computation func.
     77 // host_compute_core: mapping from outside compilation cluster name to XLA
     78 //   device assignment.
        // flr: FunctionLibraryRuntime object.
     79 // fld: FunctionLibraryDefinition object.
     80 // host_graph_func_name: NOTE(review): presumably the name of the function
     81 //   holding the host side graph for all outside compilations within this XLA
     82 //   computation func (said to be empty if there is no outside compilation).
        //   This comment previously described a `host_graph` Graph parameter that
        //   the declaration does not have -- confirm against the definition.
     83 // shape_inference_graphs: a list of outside compilation shape inference
     84 //   function names. These functions need to be rewritten later.
     85 // has_outside_compilation: a bool indicating whether this function has any
     86 //   outside compilation nodes.
        //
        // `shape_inference_graphs` and `has_outside_compilation` are passed by
        // pointer (output parameters). Returns a Status.
     87 Status ExtractOutsideCompilationForFunction(
     88     const string& xla_cluster_attr_name,
     89     const string& outside_compilation_attr_name, const string& xla_cluster_name,
     90     const NameAttrList& func_name_attrs, const string& new_func_name,
     91     const string& host_graph_func_name,
     92     const std::map<string, int>& host_compute_core, FunctionLibraryRuntime* flr,
     93     FunctionLibraryDefinition* fld, std::vector<string>* shape_inference_graphs,
     94     bool* has_outside_compilation);
     95 
     96 // Rewrites XLA computations in `clusters` to replace outside compilation nodes
     97 // with XlaHostCompute, and moves those outside compilations into `g`. If shapes
     98 // of outside compilation outputs cannot be determined now, we will store shape
     99 // inference graph into `fld`.
        //
        // xla_cluster_attr_name and outside_compilation_attr_name: attr names for XLA
        //   computation and outside compilation.
        // clusters: the XLA computations to rewrite, as a map from string to
        //   XlaClusterInfo. NOTE(review): the key is presumably the cluster name --
        //   confirm against the caller.
        // g: graph that receives the extracted outside compilations.
        // flr: FunctionLibraryRuntime object.
        // fld: FunctionLibraryDefinition object; receives shape inference graphs
        //   when output shapes cannot be determined.
        //
        // Returns a Status.
    100 Status ExtractOutsideCompilation(
    101     const string& xla_cluster_attr_name,
    102     const string& outside_compilation_attr_name,
    103     const std::unordered_map<string, XlaClusterInfo>& clusters, Graph* g,
    104     FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld);
    105 
    106 }  // namespace tensorflow
    107 
    108 #endif  // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_
    109