/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

Status XlaGather(const xla::ComputationDataHandle& input,
                 const TensorShape& input_shape,
                 const xla::ComputationDataHandle& indices,
                 TensorShape indices_shape, int64 axis, bool indices_are_nd,
                 DataType dtype, DataType index_type,
                 xla::ComputationBuilder* builder,
                 xla::ComputationDataHandle* gather_output) {
  // If the indices are N-dimensional, then the minor dimension of indices
  // should be of size N and correspond to the N indices.
  int64 num_index_dims = 1;
  if (indices_are_nd) {
    CHECK_GE(indices_shape.dims(), 1);
    num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
    indices_shape.RemoveLastDims(1);
  }
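
  // For example, N-d indices of shape [2, 3, 2] yield num_index_dims == 2,
  // and indices_shape is trimmed to [2, 3]; each of the 6 remaining entries
  // addresses a 2-d coordinate in the input.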

  // Although the indices Tensor is flattened into rank 1 during the lookup,
  // and each scalar entry is used as an index into dimension `axis` of the
  // input, the output is returned with shape:
  // input.shape[:axis] + indices.shape + input.shape[axis+1:]
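  // For example, input.shape = [4, 5, 6] with axis = 1 and
  // indices.shape = [2, 3] produces an output of shape [4, 2, 3, 6].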

  const int64 num_indices = indices_shape.num_elements();
  TensorShape input_shape_pre_axis(input_shape);
  input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
  TensorShape input_shape_post_axis(input_shape);
  input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
  // Each slice of the input tensor has shape:
  // [<input_shape_pre_axis>, 1, ..., 1, <input_shape_post_axis>]
  TensorShape slice_shape(input_shape);
  for (int64 i = 0; i < num_index_dims; ++i) {
    slice_shape.set_dim(axis + i, 1);
  }

  TensorShape loop_out_shape;
  loop_out_shape.AppendShape(input_shape_pre_axis);
  loop_out_shape.AddDim(num_indices);
  loop_out_shape.AppendShape(input_shape_post_axis);
  TensorShape loop_out_slice_shape;
  loop_out_slice_shape.AppendShape(input_shape_pre_axis);
  loop_out_slice_shape.AddDim(1);
  loop_out_slice_shape.AppendShape(input_shape_post_axis);
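  // Continuing the example above: input_shape_pre_axis = [4],
  // input_shape_post_axis = [6], and num_indices = 6, so loop_out_shape is
  // [4, 6, 6] (indices flattened) and loop_out_slice_shape is [4, 1, 6].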

  TensorShape out_shape;
  out_shape.AppendShape(input_shape_pre_axis);
  out_shape.AppendShape(indices_shape);
  out_shape.AppendShape(input_shape_post_axis);

  // Degenerate case: empty indices.
  if (num_indices == 0) {
    *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
                                        out_shape.dim_sizes());
    return Status::OK();
  }

  for (int64 i = 0; i < num_index_dims; ++i) {
    if (input_shape.dim_size(axis + i) == 0) {
      return errors::InvalidArgument("Gather dimension ", axis + i,
                                     " is of size zero in tensor with shape ",
                                     input_shape.DebugString());
    }
  }

  // Flatten the major dimensions of indices into a single dimension for ease
  // of iteration. If the indices are N-dimensional, the trailing dimension
  // holding the index components must be left intact.
  std::vector<int64> flat_indices_shape = {num_indices};
  if (indices_are_nd) {
    flat_indices_shape.push_back(num_index_dims);
  }

  // Construct the initial values of the loop-carried Tensors: the input, the
  // flattened indices, and a zero-initialized output buffer.
  auto flat_indices = builder->Reshape(indices, flat_indices_shape);
  auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
                                     loop_out_shape.dim_sizes());
  auto init = {input, flat_indices, init_out};

  // Construct the while loop body's function. The implementation of gather is:
  // for i in range(num_indices):
  //   index = dynamic-slice(indices, i)
  //   xi = dynamic-slice(input, index)
  //   output = dynamic-update-slice(output, xi, i)
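  // The loop runs a compile-time-constant number of iterations (num_indices);
  // dynamic-slice is needed because the index values themselves are only
  // known at runtime.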
  auto body_fn = [&](xla::ComputationDataHandle i,
                     gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
                     xla::ComputationBuilder* bodyb) {
    auto input = loop_vars[0];
    auto indices = loop_vars[1];
    auto output = loop_vars[2];

    auto zero_index = XlaHelpers::Zero(bodyb, index_type);

    // Slice the i-th index from the indices array.
    xla::ComputationDataHandle index;
    auto indices_offset = bodyb->Reshape(i, {1});
    if (indices_are_nd) {
      // Slice out the entire nd index, if applicable.
      indices_offset = bodyb->Pad(indices_offset, zero_index,
                                  xla::MakeEdgePaddingConfig({{0, 1}}));
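      // The pad extends the offset from [i] to [i, 0], so the slice below
      // starts at row i of the flattened [num_indices, num_index_dims]
      // indices and spans the full row.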
      index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
      index = bodyb->Collapse(index, {0, 1});
    } else {
      index = bodyb->DynamicSlice(indices, indices_offset, {1});
    }

    // Slice the corresponding data from the input array.
    auto start_indices = bodyb->Pad(
        index, zero_index,
        xla::MakeEdgePaddingConfig(
            {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}}));
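    // start_indices now spells out a full-rank starting coordinate
    // [0, ..., 0, <index>, 0, ..., 0]: zeros for the dimensions before and
    // after the gathered axis, with the dynamic index components in between.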
    auto slice_i = bodyb->Reshape(
        bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()),
        loop_out_slice_shape.dim_sizes());

    // Construct the index into the output Tensor: [0, ..., 0, i, 0, ..., 0],
    // with the loop counter i in the flattened indices dimension.
    std::vector<xla::ComputationDataHandle> out_index_vals(
        loop_out_shape.dims(), bodyb->Reshape(zero_index, {1}));
    out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1});
    auto out_index = bodyb->ConcatInDim(out_index_vals, 0);

    // Update the output Tensor.
    auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);

    return std::vector<xla::ComputationDataHandle>{input, indices,
                                                   updated_output};
  };

  // Construct the While loop, extract and reshape the output.
  xla::PrimitiveType ptype;
  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
  TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
                                                    init, "gather", builder));
  *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
  return Status::OK();
}

class GatherOp : public XlaOpKernel {
 public:
  explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    xla::ComputationBuilder* builder = context->builder();
    auto input = context->Input(0);
    auto input_shape = context->InputShape(0);
    auto indices = context->Input(1);
    auto indices_shape = context->InputShape(1);
    int64 axis = 0;
    if (context->num_inputs() == 3) {
      const TensorShape axis_shape = context->InputShape(2);
      OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape),
                  errors::InvalidArgument("axis must be scalar"));
      DataType axis_type = input_type(2);
      OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64,
                  errors::InvalidArgument("axis must be int32 or int64"));

      OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis));
      const auto params_dims = input_shape.dims();
      if (axis < 0) {
        axis += params_dims;
      }
      OP_REQUIRES(
          context, 0 <= axis && axis < params_dims,
          errors::InvalidArgument("Expected axis in the range [", -params_dims,
                                  ", ", params_dims, "), but got ", axis));
    }

    DataType index_type = input_type(1);
    OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64,
                errors::InvalidArgument("indices must be int32 or int64"));

    xla::ComputationDataHandle gather;
    OP_REQUIRES_OK(
        context, XlaGather(input, input_shape, indices, indices_shape, axis,
                           /*indices_are_nd=*/false, input_type(0), index_type,
                           builder, &gather));
    context->SetOutput(0, gather);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
};

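// GatherV2 takes an explicit axis input, which must be a compile-time
// constant: Compile() reads it with ConstantInputAsIntScalar to do static
// shape arithmetic in XlaGather.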
REGISTER_XLA_OP(Name("Gather"), GatherOp);
REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp);

class GatherNdOp : public XlaOpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType params_type = context->input_type(0);
    DataType indices_type = context->input_type(1);

    TensorShape params_shape = context->InputShape(0);
    TensorShape indices_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
                errors::InvalidArgument("indices must be at least a vector"));
    const int64 num_index_dims =
        indices_shape.dim_size(indices_shape.dims() - 1);
    OP_REQUIRES(
        context, num_index_dims <= params_shape.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must be <= params rank; saw: ",
            num_index_dims, " vs. ", params_shape.dims()));

    xla::ComputationBuilder* builder = context->builder();
    auto params = context->Input(0);
    auto indices = context->Input(1);
    xla::ComputationDataHandle gather;
    OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices,
                                      indices_shape, /*axis=*/0,
                                      /*indices_are_nd=*/true, params_type,
                                      indices_type, builder, &gather));
    context->SetOutput(0, gather);
  }
};

REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp);

}  // namespace tensorflow