Home | History | Annotate | Download | only in jit
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // An optimization pass that groups nodes marked with a common
     17 // kXlaClusterAttr into functions, and replaces the original nodes by
     18 // calls. The calls are annotated with kXlaCompiledKernelAttr.
     19 
     20 #ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
     21 #define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
     22 
#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
     27 
     28 namespace tensorflow {
     29 
     30 // A rewriting function to apply to each subgraph during encapsulation.
     31 // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
     32 // 'input_permutation' is a mapping from old argument numbers to new argument
     33 // numbers, whereas 'output_permutation' is the same for outputs. Both
     34 // 'input_permutation' and 'output_permutation' are initialized to the identity
     35 // permutation. 'nodedef' is the NodeDef for the call to the function under
     36 // construction, provided to allow additional attributes to be set.
     37 // The rewrite may also change the NodeDef's operator name, and that
     38 // name will be used as the name of the generated function.
     39 typedef std::function<Status(
     40     std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
     41     std::vector<int>* output_permutation, NodeDef* node_def)>
     42     RewriteSubgraphFn;
     43 
     44 // Transformation that finds subgraphs whose nodes are marked with
     45 // 'group_attribute', splits those subgraphs into functions, and replaces
     46 // the originals with function calls.
     47 //
     48 // 'group_attribute' must be a string valued-attribute that names the new
     49 // functions to introduce.
     50 //
     51 // 'outside_compilation_attribute' must be a string-valued attribute that is
     52 // used to tag nodes within a subgraph to be part of an 'outside_compilation'
     53 // cluster within the subgraph. A cluster is formed from the set of nodes with
     54 // the same value of outside_compilation_subgraph and group_attribute. The nodes
     55 // in an outside_compilation cluster are left in the original graph. Edges
     56 // crossing from the subgraph to an outside_compilation cluster nested in the
     57 // subgraph are lifted into a SendToHost/RecvAtHost pair of nodes, and edges
     58 // crossing from an outside_compilation cluster into its enclosing subgraph are
     59 // lifted into a SendFromHost/RecvFromHost pair of nodes.
     60 //
     61 // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
     62 // function conversion.
     63 //
     64 // If 'parallel_checking' is true, the unencapsulated operators are added to the
     65 // output graph, together with a "ParallelCheck" operator, that verifies that
     66 // the original and encapsulated subgraphs produce similar results.
     67 //
     68 // If 'reuse_existing_functions' is set, use an existing function with the
     69 // same name, if any.
     70 //
     71 // TODO(phawkins): currently, some information in control edges
     72 // is not preserved. Suppose you have A and B in the main
     73 // graph, C and D in a subgraph. B and C have control deps from A, D has control
     74 // dep from B. Originally D must run after C, post-transformation this
     75 // dependency is lost.
     76 Status EncapsulateSubgraphsInFunctions(
     77     string group_attribute, string outside_compilation_attribute,
     78     const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
     79     bool parallel_checking, bool reuse_existing_functions,
     80     std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
     81 
     82 // The attribute that marks function calls produced by the encapsulate
     83 // subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
     84 extern const char* const kXlaCompiledKernelAttr;
     85 
     86 // Does `node` have the kXlaCompiledKernelAttr attribute?
     87 bool IsXlaCompiledKernel(const Node& node);
     88 
     89 // Functions produced by the EncapsulateSubgraphs pass have their arguments in
     90 // the order:
     91 // 1) compile-time constant arguments, in host memory,
     92 // 2) other arguments, in device memory.
     93 // 3) resource variable arguments, in host memory. Note that only the resource
     94 //    Tensor itself is in host memory; the underlying value may be in device
     95 //    memory.
     96 // The functions are annotated with the following attributes that describe how
     97 // many constant and resource arguments there are:
     98 
     99 // Name of the attribute containing the number of constant arguments.
    100 extern const char* const kXlaNumConstantArgsAttr;
    101 
    102 // Name of the attribute containing the number of resource variable arguments.
    103 extern const char* const kXlaNumResourceArgsAttr;
    104 
    105 class EncapsulateSubgraphsPass : public GraphOptimizationPass {
    106  public:
    107   Status Run(const GraphOptimizationPassOptions& options) override;
    108 };
    109 
    110 }  // namespace tensorflow
    111 
    112 #endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
    113