/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {
// Builds the XlaCompiler::Argument list for a function call whose inputs are
// the given XlaExpressions.
//
// For each input, classifies it as:
//   - kConstant, when the expression already carries a constant value, or when
//     backwards const-analysis of `graph` says the argument must be a
//     compile-time constant (in which case the value is evaluated via
//     ComputeConstant and converted to a host Tensor);
//   - kParameter, otherwise.
// Resource (DT_RESOURCE) arguments are rejected as unimplemented.
//
// `expressions` and `args` are parallel: on success args->size() ==
// expressions.size(). Returns an error status on unsupported inputs or if
// constant evaluation fails.
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        std::vector<XlaCompiler::Argument>* args) {
  auto builder = ctx->builder();
  // flags[i] is set when input i of `graph` must be a compile-time constant.
  std::vector<bool> compile_time_constant_flags(expressions.size());

  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*graph, &compile_time_constant_flags));

  args->resize(expressions.size());
  for (int i = 0; i < args->size(); ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    if (arg.type == DT_RESOURCE) {
      return errors::InvalidArgument(
          "Resource as function argument is not yet implemented.");
    } else if (expressions[i]->has_constant_value()) {
      // The expression already knows its constant value; use it directly.
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.constant_value = expressions[i]->constant_value();
    } else if (compile_time_constant_flags[i]) {
      // The function body requires this input to be a constant, but the
      // expression only has an XLA handle. Fold it to a literal on the
      // service side, then copy the literal into a host Tensor.
      arg.kind = XlaCompiler::Argument::kConstant;
      TF_RET_CHECK(expressions[i]->resource() == nullptr)
          << "Input with resource is not yet implemented.";
      TF_ASSIGN_OR_RETURN(auto literal,
                          builder->ComputeConstant(expressions[i]->handle()));
      TF_RETURN_IF_ERROR(
          LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
    }
  }
  return Status::OK();
}
}  // namespace

// Symbolically executes every node of graph_ on the XLA compilation device,
// in a deterministic (stable reverse post-order) node ordering. Ordinary ops
// are run through their registered kernels; function-call nodes are detected
// with IsFunctional() and compiled recursively via CompileFunctionalNode().
// Returns an error status if kernel creation, computation, or output
// production fails for any node.
Status GraphCompiler::Compile() {
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
  // The registry owns the output tensors; free them all on scope exit,
  // whether we return success or error. Ref-typed outputs are not expected
  // on this device (CHECK below).
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism, generate a stable ordering from DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    // Send/Recv/Switch imply cross-device or dead-branch dataflow, which this
    // single-device symbolic execution cannot represent.
    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Not supported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    gtl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across graph traversal. We clean up
    // and reinitialize the buffer before we visit a new node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      Node* src = e->src();
      // Reverse post-order guarantees src was computed before n.
      TF_RET_CHECK(src->id() < output_registry.size());
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

    OpKernelContext op_context(&params, n->num_outputs());
    if (IsFunctional(n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs. Also check if outputs from the previous computation is
    // valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      // release_output transfers ownership of the tensor to the registry;
      // the cleanup lambda above is responsible for deleting it.
      outputs[o] = op_context.release_output(o);
      if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                (*op_context.is_output_dead() ? "(dead)" : ""),
                                SummarizeNode(*n));
      }
    }
  }
  return Status::OK();
}

// Returns true if node n is a function call: either a symbolic-gradient node
// or an op whose name is a function defined in the function library.
bool GraphCompiler::IsFunctional(Node* n) {
  return n->type_string() == FunctionLibraryDefinition::kGradientOp ||
         (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) !=
          nullptr);
}

// Compiles the function called by node n into an XLA computation, emits a
// Call op invoking it, and binds the (tuple) call results to the node's
// outputs.
//
// Steps:
//   1. Resolve which function n refers to (a named library function, or the
//      symbolic-gradient op) and forward n's attrs to it.
//   2. Recover the XlaExpression for each input from tensor_inputs_ and
//      classify the inputs via PrepareArguments.
//   3. Compile the function body with XlaCompiler::CompileFunction.
//   4. Build a Call whose operands are only the non-constant inputs
//      (constants were folded into the compiled computation), then unpack the
//      tuple result with GetTupleElement, one element per output.
//
// Requires IsFunctional(n). Returns the builder's first recorded error, if
// any, so graph-construction failures surface here.
Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctional(n));
  // For functional nodes, compile them using compiler from the context and
  // call into the functions.
  XlaOpKernelContext xla_op_context(op_context);

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) {
    func.set_name(n->def().op());
  } else {
    func.set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func.mutable_attr() = n->def().attr();

  std::vector<const XlaExpression*> expressions;

  for (auto tensor : tensor_inputs_) {
    // NOTE(review): on the XLA compilation device a tensor's backing buffer
    // holds an XlaExpression, so reinterpreting tensor_data() recovers it —
    // this relies on the XlaCompilationDevice allocation scheme; confirm
    // against xla_compilation_device.cc before changing.
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(
      PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));

  XlaCompiler::CompilationResult result;

  TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
                                               func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

  // Constant arguments are baked into the compiled computation, so only
  // parameter inputs become operands of the Call.
  std::vector<xla::ComputationDataHandle> handles;
  for (int64 i = 0; i < expressions.size(); ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    handles.push_back(expressions[i]->handle());
  }

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  auto output_handle = b->Call(*result.computation, handles);
  // The output handle of `Call` computation is a tuple type. Unzip it so
  // that it can fit into future computations.
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      // Constant outputs bypass the Call result entirely.
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i));
    }
  }
  // ComputationBuilder defers errors; report the first one accumulated while
  // emitting the ops above.
  return b->first_error();
}

// Fills the per-step, node-independent fields of `params`. Per-node fields
// (op_kernel, output_attr_array) are set inside Compile()'s loop; `inputs`
// aliases tensor_inputs_, which Compile() rewrites before each node.
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
}

}  // namespace tensorflow