Home | History | Annotate | Download | only in jit
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
     17 #include "tensorflow/compiler/jit/defs.h"
     18 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
     19 #include "tensorflow/compiler/tf2xla/dump_graph.h"
     20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     21 #include "tensorflow/core/common_runtime/function.h"
     22 #include "tensorflow/core/common_runtime/optimization_registry.h"
     23 #include "tensorflow/core/framework/graph_def_util.h"
     24 #include "tensorflow/core/framework/node_def_builder.h"
     25 #include "tensorflow/core/framework/node_def_util.h"
     26 #include "tensorflow/core/graph/algorithm.h"
     27 #include "tensorflow/core/graph/graph.h"
     28 #include "tensorflow/core/graph/graph_constructor.h"
     29 #include "tensorflow/core/lib/core/status.h"
     30 #include "tensorflow/core/lib/hash/hash.h"
     31 #include "tensorflow/core/public/version.h"
     32 
     33 namespace tensorflow {
     34 
     35 static Status BuildLaunchNode(
     36     const string& nodename, const string& function_name,
     37     const AttrValueMap& function_attr, const string& device_name,
     38     const DataTypeVector& constant_dtypes, int num_resources,
     39     const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
     40     Graph* graph, Node** node) {
     41   NodeDef def;
     42   def.set_name(graph->NewName(nodename));
     43   def.set_op("_XlaLaunch");
     44   def.set_device(device_name);
     45   AddNodeAttr("Tconstants", constant_dtypes, &def);
     46   AddNodeAttr("Targs", arg_dtypes, &def);
     47   AddNodeAttr("Nresources", num_resources, &def);
     48   AddNodeAttr("Tresults", result_dtypes, &def);
     49   NameAttrList function;
     50   function.set_name(function_name);
     51   *function.mutable_attr() = function_attr;
     52   AddNodeAttr("function", function, &def);
     53 
     54   Status status;
     55   *node = graph->AddNode(def, &status);
     56   return status;
     57 }
     58 
     59 static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
     60   VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
     61 
     62   int num_constant_args, num_resource_args;
     63   TF_RETURN_IF_ERROR(
     64       GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
     65   TF_RETURN_IF_ERROR(
     66       GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
     67 
     68   if (num_constant_args < 0 || num_resource_args < 0 ||
     69       num_constant_args + num_resource_args > node->num_inputs()) {
     70     return errors::InvalidArgument(
     71         "Invalid number of constant/resource arguments to XLA kernel.");
     72   }
     73   const int num_nonconst_args =
     74       node->num_inputs() - num_constant_args - num_resource_args;
     75 
     76   DataTypeVector const_dtypes(node->input_types().begin(),
     77                               node->input_types().begin() + num_constant_args);
     78   DataTypeVector arg_dtypes(
     79       node->input_types().begin() + num_constant_args,
     80       node->input_types().begin() + num_constant_args + num_nonconst_args);
     81 
     82   // Build a _XlaLaunch operator to execute the function body.
     83   Node* launch_node;
     84   TF_RETURN_IF_ERROR(BuildLaunchNode(
     85       graph->NewName(node->name()), node->type_string(), node->def().attr(),
     86       node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
     87       node->output_types(), graph, &launch_node));
     88   launch_node->set_assigned_device_name(node->assigned_device_name());
     89 
     90   // Copy incoming edges to the launch node.
     91   for (const Edge* edge : node->in_edges()) {
     92     if (edge->IsControlEdge()) {
     93       graph->AddControlEdge(edge->src(), launch_node);
     94     } else {
     95       graph->AddEdge(edge->src(), edge->src_output(), launch_node,
     96                      edge->dst_input());
     97     }
     98   }
     99 
    100   // Copy outgoing edges to the launch node.
    101   std::vector<const Edge*> out_edges(node->out_edges().begin(),
    102                                      node->out_edges().end());
    103   for (const Edge* edge : out_edges) {
    104     Node* dst = edge->dst();
    105     int src_output = edge->src_output();
    106     int dst_input = edge->dst_input();
    107     graph->RemoveEdge(edge);
    108 
    109     if (edge->IsControlEdge()) {
    110       graph->AddControlEdge(launch_node, dst);
    111     } else {
    112       graph->AddEdge(launch_node, src_output, dst, dst_input);
    113     }
    114   }
    115   graph->RemoveNode(node);
    116 
    117   return Status::OK();
    118 }
    119 
    120 Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
    121   Graph* graph = options.graph->get();
    122 
    123   for (Node* n : graph->op_nodes()) {
    124     // In all cases, only try to compile computational nodes.
    125     if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
    126       continue;
    127     }
    128 
    129     // Only compile nodes that are marked for compilation by the
    130     // compilation-marking pass (via 'attr_name').
    131     if (IsXlaCompiledKernel(*n)) {
    132       TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
    133     }
    134   }
    135 
    136   if (VLOG_IS_ON(1)) {
    137     dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
    138                                 options.flib_def);
    139   }
    140   return Status::OK();
    141 }
    142 }  // namespace tensorflow
    143