/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {
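// Prepares the arguments (one per input expression) for the function being
// compiled: inputs resolved to compile-time constants become kConstant
// arguments; everything else becomes a kParameter.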
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        std::vector<XlaCompiler::Argument>* args) {
  auto builder = ctx->builder();
  std::vector<bool> compile_time_constant_flags(expressions.size());

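  // Propagate constness backwards through the graph: an input must be a
  // compile-time constant if it feeds an operand that XLA requires to be
  // constant (e.g. a shape argument).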
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*graph, &compile_time_constant_flags));

  args->resize(expressions.size());
  for (int i = 0; i < args->size(); ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    if (arg.type == DT_RESOURCE) {
      return errors::InvalidArgument(
          "Resource as function argument is not yet implemented.");
    } else if (expressions[i]->has_constant_value()) {
      arg.kind = XlaCompiler::Argument::kConstant;
      arg.constant_value = expressions[i]->constant_value();
    } else if (compile_time_constant_flags[i]) {
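      // The argument must be a compile-time constant but no literal was
      // recorded on the expression; ask the XLA builder to constant-fold
      // the computation that produced it.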
      arg.kind = XlaCompiler::Argument::kConstant;
      TF_RET_CHECK(expressions[i]->resource() == nullptr)
          << "Input with resource is not yet implemented.";
      TF_ASSIGN_OR_RETURN(auto literal,
                          builder->ComputeConstant(expressions[i]->handle()));
      TF_RETURN_IF_ERROR(
          LiteralToHostTensor(*literal, arg.type, &arg.constant_value));
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
    }
  }
  return Status::OK();
}
}  // namespace
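
// Compiles the graph by "executing" each node on the XLA compilation device,
// visiting nodes in a deterministic order, recording the XLA operations each
// kernel emits, and recursively compiling function-call nodes.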
Status GraphCompiler::Compile() {
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
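  // The registry owns the output tensors; free them when compilation ends,
  // whether it succeeds or fails.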
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires deterministic compilation, so generate a stable node
  // ordering using a reverse post-order DFS with a name-based tie-breaker.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    Status s = flib_->CreateKernel(n->def(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Unsupported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    gtl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal; clear
    // and reinitialize it before visiting each node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      Node* src = e->src();
      TF_RET_CHECK(src->id() < output_registry.size());
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

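    // "Execute" the node: functional nodes are compiled recursively, while
    // regular ops run their kernels on the XLA compilation device, which
    // records symbolic XLA operations instead of computing concrete values.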
    OpKernelContext op_context(&params, n->num_outputs());
    if (IsFunctional(n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs, checking that the outputs from the computation are
    // valid (non-null and not dead).
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                (*op_context.is_output_dead() ? "(dead) " : ""),
                                SummarizeNode(*n));
      }
    }
  }
  return Status::OK();
}

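// Returns true if the node is a function call (its op is defined in the
// function library) or a SymbolicGradient node, i.e., it must be compiled
// by recursively invoking the XlaCompiler rather than through an OpKernel.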
bool GraphCompiler::IsFunctional(Node* n) {
  return n->type_string() == FunctionLibraryDefinition::kGradientOp ||
         (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) !=
          nullptr);
}

Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctional(n));
  // For functional nodes, compile the called function with the compiler from
  // the context and emit a call into the resulting computation.
  XlaOpKernelContext xla_op_context(op_context);

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) {
    func.set_name(n->def().op());
  } else {
    func.set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func.mutable_attr() = n->def().attr();

  std::vector<const XlaExpression*> expressions;

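  // On the XLA compilation device a tensor's data buffer does not hold real
  // values; it holds an XlaExpression describing the symbolic value, so the
  // cast below recovers the expression attached to each input.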
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(
      PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));

  XlaCompiler::CompilationResult result;

  TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
                                               func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

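  // Constant arguments were folded into the compiled computation, so only
  // the remaining parameters are passed to the Call below.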
  std::vector<xla::ComputationDataHandle> handles;
  for (int64 i = 0; i < expressions.size(); ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    handles.push_back(expressions[i]->handle());
  }

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  auto output_handle = b->Call(*result.computation, handles);
  // The `Call` computation returns a tuple; unpack its elements so they can
  // feed into downstream computations.
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i));
    }
  }
  return b->first_error();
}

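// Fills in the fields of `params` that are identical for every node in the
// traversal; per-node fields (op_kernel, output_attr_array) are set in
// Compile().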
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
}

}  // namespace tensorflow