1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" 17 #include "tensorflow/compiler/jit/defs.h" 18 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" 19 #include "tensorflow/compiler/tf2xla/dump_graph.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/core/common_runtime/function.h" 22 #include "tensorflow/core/common_runtime/optimization_registry.h" 23 #include "tensorflow/core/framework/graph_def_util.h" 24 #include "tensorflow/core/framework/node_def_builder.h" 25 #include "tensorflow/core/framework/node_def_util.h" 26 #include "tensorflow/core/graph/algorithm.h" 27 #include "tensorflow/core/graph/graph.h" 28 #include "tensorflow/core/graph/graph_constructor.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/hash/hash.h" 31 #include "tensorflow/core/public/version.h" 32 33 namespace tensorflow { 34 35 static Status BuildLaunchNode( 36 const string& nodename, const string& function_name, 37 const AttrValueMap& function_attr, const string& device_name, 38 const DataTypeVector& constant_dtypes, int num_resources, 39 const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, 40 Graph* graph, Node** node) { 41 NodeDef def; 42 def.set_name(graph->NewName(nodename)); 43 def.set_op("_XlaLaunch"); 44 def.set_device(device_name); 45 AddNodeAttr("Tconstants", constant_dtypes, &def); 46 AddNodeAttr("Targs", arg_dtypes, &def); 47 AddNodeAttr("Nresources", num_resources, &def); 48 AddNodeAttr("Tresults", result_dtypes, &def); 49 NameAttrList function; 50 function.set_name(function_name); 51 *function.mutable_attr() = function_attr; 52 AddNodeAttr("function", function, &def); 53 54 Status status; 55 *node = graph->AddNode(def, &status); 56 return status; 57 } 58 59 static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { 60 VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; 61 62 int num_constant_args, num_resource_args; 63 TF_RETURN_IF_ERROR( 64 GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); 65 TF_RETURN_IF_ERROR( 66 GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); 67 68 if (num_constant_args < 0 || num_resource_args < 0 || 69 num_constant_args + num_resource_args > node->num_inputs()) { 70 return errors::InvalidArgument( 71 "Invalid number of constant/resource arguments to XLA kernel."); 72 } 73 const int num_nonconst_args = 74 node->num_inputs() - num_constant_args - num_resource_args; 75 76 DataTypeVector const_dtypes(node->input_types().begin(), 77 node->input_types().begin() + num_constant_args); 78 DataTypeVector arg_dtypes( 79 node->input_types().begin() + num_constant_args, 80 node->input_types().begin() + num_constant_args + num_nonconst_args); 81 82 // Build a _XlaLaunch operator to execute the function body. 83 Node* launch_node; 84 TF_RETURN_IF_ERROR(BuildLaunchNode( 85 graph->NewName(node->name()), node->type_string(), node->def().attr(), 86 node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, 87 node->output_types(), graph, &launch_node)); 88 launch_node->set_assigned_device_name(node->assigned_device_name()); 89 90 // Copy incoming edges to the launch node. 91 for (const Edge* edge : node->in_edges()) { 92 if (edge->IsControlEdge()) { 93 graph->AddControlEdge(edge->src(), launch_node); 94 } else { 95 graph->AddEdge(edge->src(), edge->src_output(), launch_node, 96 edge->dst_input()); 97 } 98 } 99 100 // Copy outgoing edges to the launch node. 101 std::vector<const Edge*> out_edges(node->out_edges().begin(), 102 node->out_edges().end()); 103 for (const Edge* edge : out_edges) { 104 Node* dst = edge->dst(); 105 int src_output = edge->src_output(); 106 int dst_input = edge->dst_input(); 107 graph->RemoveEdge(edge); 108 109 if (edge->IsControlEdge()) { 110 graph->AddControlEdge(launch_node, dst); 111 } else { 112 graph->AddEdge(launch_node, src_output, dst, dst_input); 113 } 114 } 115 graph->RemoveNode(node); 116 117 return Status::OK(); 118 } 119 120 Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { 121 Graph* graph = options.graph->get(); 122 123 for (Node* n : graph->op_nodes()) { 124 // In all cases, only try to compile computational nodes. 125 if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { 126 continue; 127 } 128 129 // Only compile nodes that are marked for compilation by the 130 // compilation-marking pass (via 'attr_name'). 131 if (IsXlaCompiledKernel(*n)) { 132 TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); 133 } 134 } 135 136 if (VLOG_IS_ON(1)) { 137 dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, 138 options.flib_def); 139 } 140 return Status::OK(); 141 } 142 } // namespace tensorflow 143