/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/while_op.h"

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

namespace {

// Builds XlaCompiler argument descriptions `args` from `ctx`.
Status MakeXlaCompilerArgumentsFromInputs(
    XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
    bool* has_uninitialized_vars, bool* has_tensor_arrays) {
  VLOG(2) << "Num inputs " << ctx->num_inputs();
  args->resize(ctx->num_inputs());
  *has_uninitialized_vars = false;
  *has_tensor_arrays = false;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    VLOG(2) << " Input " << i
            << " type: " << DataTypeString(ctx->input_type(i))
            << " shape: " << ctx->InputShape(i).DebugString();
    XlaCompiler::Argument& arg = (*args)[i];
    DataType type = ctx->input_type(i);
    // When reading a resource input, use the type and shape of the resource's
    // current value.
    if (type == DT_RESOURCE) {
      XlaResource* resource;
      TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));

      arg.initialized = resource->initialized();
      arg.kind = XlaCompiler::Argument::kResource;
      arg.resource_kind = resource->kind();
      if (arg.resource_kind == XlaResource::kTensorArray) {
        *has_tensor_arrays = true;
      }

      arg.type = resource->type();
      arg.shape = resource->shape();
      if (!arg.initialized) {
        *has_uninitialized_vars = true;
      }
      arg.tensor_array_size = resource->tensor_array_size();
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
      arg.name = resource->name();
      VLOG(2) << "    resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.shape.DebugString()
              << " initialized: " << arg.initialized;

    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = ctx->input_type(i);
      arg.shape = ctx->InputShape(i);
    }
  }
  return Status::OK();
}

}  // anonymous namespace

XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  const NameAttrList* name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr));
  cond_name_attr_ = *name_attr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
  body_name_attr_ = *name_attr;
}

void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
  VLOG(1) << "WhileOp::Compile";

  std::vector<XlaCompiler::Argument> arguments;
  bool has_uninitialized_vars;
  bool has_tensor_arrays;
  OP_REQUIRES_OK(
      ctx, MakeXlaCompilerArgumentsFromInputs(
               ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays));

  xla::ComputationBuilder* builder = ctx->builder();
  XlaCompiler* compiler = ctx->compiler();

  VLOG(1) << "Compiling body";

  // All resources that are inputs to the loop's body must also be present as
  // loop body outputs; the signatures of the loop's input and output must
  // match. We ensure this by asking the compiler to include the current
  // values of all resources, even if they haven't been updated by the
  // computation. We must also ask the compiler to keep compile-time constant
  // outputs as part of the generated computation, for the same reason.
  // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
  // operator.
  XlaCompiler::CompileOptions body_options;
  body_options.use_tuple_arg = true;
  body_options.return_updated_values_for_all_resources = true;
  body_options.resolve_compile_time_constants = false;
  body_options.is_entry_computation = false;
  XlaCompiler::CompilationResult body;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                arguments, &body));

  // We must use static shapes for the parameters to an XLA compilation.
  // However, we may not know the shape of a resource if it is first written
  // inside the loop. Furthermore, we do not know ahead of time which gradient
  // TensorArrays will be created by the TensorArrayGradV3 operator.
  //
  // Ideally we would change TensorFlow to always provide static shapes, but
  // this is not easy to do. So if uninitialized resources or TensorArrays are
  // used by the loop body, we compile the body function twice:
  // 1) once with uninitialized resource inputs and no TensorArray gradient
  //    inputs. We discard the computation, but assume that the resource
  //    shapes and the set of gradients read or written will reach a fixpoint
  //    after one iteration. Hence we can use the output shapes and
  //    TensorArray gradients of each resource as the "true" shapes.
  // 2) again with the "correct" resource information determined by (1).
  if (has_uninitialized_vars || has_tensor_arrays) {
    VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
            << has_uninitialized_vars
            << " has_tensor_arrays: " << has_tensor_arrays;
    // Initializes any uninitialized resource with zero values of the
    // shape determined by the first compilation.
    for (int i = 0; i < body.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));

      XlaCompiler::Argument& arg = arguments[update.input_index];
      if (!arg.initialized) {
        VLOG(2) << "Update shape for argument " << update.input_index << " "
                << update.shape.DebugString();
        arg.initialized = true;

        arg.shape = update.shape;
        OP_REQUIRES_OK(ctx,
                       resource->SetTypeAndShape(update.type, update.shape));

        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
      }

      // Add any TensorArray gradients touched by the body to the enclosing
      // graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(4) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, builder, &gradient));
      }

      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
    }
    // Recompile the body with the "correct" resource shapes.
    VLOG(1) << "Recompiling body with corrected resource shapes";
    body = {};
    OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                  arguments, &body));
  }

  VLOG(1) << "Compiling condition";

  XlaCompiler::CompileOptions cond_options;
  cond_options.use_tuple_arg = true;
  cond_options.resolve_compile_time_constants = false;
  cond_options.is_entry_computation = false;
  XlaCompiler::CompilationResult cond;
  OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                arguments, &cond));

  OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape body_input_shape = body.xla_input_shapes[0];
  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape),
              errors::FailedPrecondition("Expected tuple shape"));
  OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1,
              errors::FailedPrecondition("Expected one input shape"));
  xla::Shape cond_input_shape = cond.xla_input_shapes[0];
  OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape),
              errors::FailedPrecondition("Expected tuple shape"));

  VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
  VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape)
          << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape);

  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape),
              errors::InvalidArgument(
                  "Input shapes of loop body and condition do not match: ",
                  xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
                  xla::ShapeUtil::HumanString(cond_input_shape)));
  OP_REQUIRES(
      ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape),
      errors::InvalidArgument(
          "Input and output shapes of loop body do not match: ",
          xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
          xla::ShapeUtil::HumanString(body.xla_output_shape)));

  xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::PRED, {})});
  OP_REQUIRES(ctx,
              xla::ShapeUtil::Compatible(cond.xla_output_shape,
                                         expected_cond_output_shape),
              errors::InvalidArgument(
                  "Output shape of loop condition should be (pred[]), got: ",
                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));

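  // body.input_mapping[i] is the index of the op input that feeds the i-th
  // parameter of the compiled body. Gather the corresponding values, packing
  // each resource into its single-value XLA representation.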
  int num_inputs = body.input_mapping.size();
  std::vector<xla::ComputationDataHandle> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = body.input_mapping[i];
    if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }

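  // Both the condition and the body were compiled with use_tuple_arg, so all
  // loop-carried values are packed into a single tuple.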
  xla::ComputationDataHandle init = builder->Tuple(inputs);

  VLOG(1) << "Building while loop";

  // Wraps the condition in a computation that unpacks the output tuple.
  xla::Computation cond_wrapper;
  {
    std::unique_ptr<xla::ComputationBuilder> cb =
        builder->CreateSubBuilder("cond_wrapper");
    auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
    auto outputs = cb->Call(*cond.computation, {inputs});
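    // The operation built last becomes the root of the wrapper computation,
    // so the wrapper returns the scalar predicate extracted from the
    // condition's output tuple.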
    cb->GetTupleElement(outputs, 0);
    xla::StatusOr<xla::Computation> result = cb->Build();
    OP_REQUIRES_OK(ctx, result.status());
    cond_wrapper = std::move(result.ValueOrDie());
  }

  xla::ComputationDataHandle while_result =
      builder->While(cond_wrapper, *body.computation, init);
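  // while_result has the same tuple shape as the loop body's output: the
  // regular outputs first, followed by the updated resource values.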

  // Sets non-variable outputs.
  for (int i = 0; i < ctx->num_outputs(); ++i) {
    if (ctx->input_type(i) != DT_RESOURCE) {
      ctx->SetOutput(body.input_mapping[i],
                     builder->GetTupleElement(while_result, i));
    }
  }

  // Updates the values of any resource variables modified by the loop.
  for (int i = 0; i < body.resource_updates.size(); ++i) {
    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
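      // Updated resource values are appended after the body's regular outputs
      // in the output tuple.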
      int pos = body.outputs.size() + i;
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         builder->GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name() << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
        update.input_index,
        ctx->op_kernel_context()->input(update.input_index));
  }

  VLOG(1) << "Done building while loop";
}

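// AllowResourceTypes() permits DT_RESOURCE inputs and outputs, which the loop
// uses for variables and TensorArrays.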
REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);

}  // namespace tensorflow