Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/util/strided_slice_op.h"

#include <algorithm>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mem.h"

     29 
     30 namespace tensorflow {
     31 namespace {
     32 
     33 class StridedSliceOp : public XlaOpKernel {
     34  public:
     35   explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     36     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
     37     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
     38     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
     39     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
     40     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
     41     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
     42   }
     43 
     44   void Compile(XlaOpKernelContext* ctx) override {
     45     const TensorShape input_shape = ctx->InputShape(0);
     46 
     47     TensorShape final_shape;
     48     gtl::InlinedVector<int64, 4> begin;
     49     gtl::InlinedVector<int64, 4> end;
     50     gtl::InlinedVector<int64, 4> strides;
     51 
     52     xla::Literal begin_literal, end_literal, strides_literal;
     53     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
     54     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
     55     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
     56 
     57     Tensor begin_tensor, end_tensor, strides_tensor;
     58     OP_REQUIRES_OK(
     59         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
     60     OP_REQUIRES_OK(ctx,
     61                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
     62     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
     63                                             &strides_tensor));
     64 
     65     TensorShape dummy_processing_shape;
     66     bool dummy = false;
     67     OP_REQUIRES_OK(ctx,
     68                    ValidateStridedSliceOp(
     69                        &begin_tensor, &end_tensor, strides_tensor, input_shape,
     70                        begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
     71                        shrink_axis_mask_, &dummy_processing_shape, &final_shape,
     72                        &dummy, &dummy, &dummy, &begin, &end, &strides));
     73 
     74     gtl::InlinedVector<int64, 4> dimensions_to_reverse;
     75     gtl::InlinedVector<int64, 4> slice_begin, slice_end, slice_strides;
     76 
     77     for (int i = 0; i < begin.size(); ++i) {
     78       if (strides[i] > 0) {
     79         slice_begin.push_back(begin[i]);
     80         slice_end.push_back(end[i]);
     81         slice_strides.push_back(strides[i]);
     82       } else {
     83         // Negative stride: swap begin and end, add 1 because the interval
     84         // is semi-open, and mark the dimension to be reversed.
     85         slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1);
     86         slice_end.push_back(input_shape.dim_size(i) - end[i] - 1);
     87         slice_strides.push_back(-strides[i]);
     88         dimensions_to_reverse.push_back(i);
     89       }
     90     }
     91 
     92     xla::ComputationDataHandle slice = ctx->Input(0);
     93     if (!dimensions_to_reverse.empty()) {
     94       slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
     95     }
     96 
     97     slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
     98 
     99     slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
    100     ctx->SetOutput(0, slice);
    101   }
    102 
    103  private:
    104   int32 begin_mask_, end_mask_;
    105   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
    106   DataType index_type_;
    107 };
    108 
    109 REGISTER_XLA_OP(Name("StridedSlice")
    110                     .CompileTimeConstInput("begin")
    111                     .CompileTimeConstInput("end")
    112                     .CompileTimeConstInput("strides"),
    113                 StridedSliceOp);
    114 
    115 class StridedSliceGradOp : public XlaOpKernel {
    116  public:
    117   explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    118     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    119     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    120     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    121     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    122     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    123     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
    124   }
    125 
    126   void Compile(XlaOpKernelContext* ctx) override {
    127     TensorShape processing_shape, final_shape;
    128     gtl::InlinedVector<int64, 4> begin;
    129     gtl::InlinedVector<int64, 4> end;
    130     gtl::InlinedVector<int64, 4> strides;
    131 
    132     TensorShape input_shape;
    133     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
    134 
    135     xla::Literal begin_literal, end_literal, strides_literal;
    136     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    137     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    138     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
    139 
    140     Tensor begin_tensor, end_tensor, strides_tensor;
    141     OP_REQUIRES_OK(
    142         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    143     OP_REQUIRES_OK(ctx,
    144                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    145     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
    146                                             &strides_tensor));
    147 
    148     bool dummy = false;
    149     OP_REQUIRES_OK(
    150         ctx, ValidateStridedSliceOp(
    151                  &begin_tensor, &end_tensor, strides_tensor, input_shape,
    152                  begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
    153                  shrink_axis_mask_, &processing_shape, &final_shape, &dummy,
    154                  &dummy, &dummy, &begin, &end, &strides));
    155 
    156     // Check to make sure dy is consistent with the original slice
    157     const TensorShape dy_shape = ctx->InputShape(4);
    158     OP_REQUIRES(
    159         ctx, final_shape == dy_shape,
    160         errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
    161                                 " instead of ", final_shape.DebugString()));
    162 
    163     OP_REQUIRES(
    164         ctx, input_shape.dims() == processing_shape.dims(),
    165         errors::Internal(
    166             "input shape and processing shape must have same number of dims"));
    167 
    168     auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
    169 
    170     xla::ComputationDataHandle grad = ctx->Input(4);
    171 
    172     // Undo any new/shrink axes.
    173     grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
    174 
    175     // Pad the input gradients.
    176     gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    177     xla::PaddingConfig padding_config;
    178 
    179     for (int i = 0; i < processing_shape.dims(); ++i) {
    180       auto* dims = padding_config.add_dimensions();
    181       if (strides[i] > 0) {
    182         dims->set_edge_padding_low(begin[i]);
    183         dims->set_interior_padding(strides[i] - 1);
    184 
    185         // Pad the upper dimension up to the expected input shape. (It's
    186         // not sufficient simply to use "end[i]" to compute the padding in
    187         // cases where the stride does not divide evenly into the interval
    188         // between begin[i] and end[i].)
    189         int64 size =
    190             dims->edge_padding_low() + processing_shape.dim_size(i) +
    191             (processing_shape.dim_size(i) - 1) * dims->interior_padding();
    192         dims->set_edge_padding_high(input_shape.dim_size(i) - size);
    193       } else {
    194         dimensions_to_reverse.push_back(i);
    195         dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
    196         dims->set_interior_padding(-strides[i] - 1);
    197 
    198         // Pad the lower dimension up to the expected input shape.
    199         int64 size =
    200             dims->edge_padding_high() + processing_shape.dim_size(i) +
    201             (processing_shape.dim_size(i) - 1) * dims->interior_padding();
    202         dims->set_edge_padding_low(input_shape.dim_size(i) - size);
    203       }
    204     }
    205     if (!dimensions_to_reverse.empty()) {
    206       grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
    207     }
    208     grad = ctx->builder()->Pad(grad, zero, padding_config);
    209     ctx->SetOutput(0, grad);
    210   }
    211 
    212  private:
    213   int32 begin_mask_, end_mask_;
    214   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
    215   DataType index_type_;
    216 };
    217 
    218 REGISTER_XLA_OP(Name("StridedSliceGrad")
    219                     .CompileTimeConstInput("shape")
    220                     .CompileTimeConstInput("begin")
    221                     .CompileTimeConstInput("end")
    222                     .CompileTimeConstInput("strides"),
    223                 StridedSliceGradOp);
    224 
    225 class StridedSliceAssignOp : public XlaOpKernel {
    226  public:
    227   explicit StridedSliceAssignOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    228     OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
    229     OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
    230     OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
    231     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    232     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    233     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
    234     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
    235   }
    236 
    237   void Compile(XlaOpKernelContext* ctx) override {
    238     TensorShape final_shape;
    239     gtl::InlinedVector<int64, 4> begin;
    240     gtl::InlinedVector<int64, 4> end;
    241     gtl::InlinedVector<int64, 4> strides;
    242 
    243     xla::Literal begin_literal, end_literal, strides_literal;
    244     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
    245     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
    246     OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
    247 
    248     Tensor begin_tensor, end_tensor, strides_tensor;
    249     OP_REQUIRES_OK(
    250         ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
    251     OP_REQUIRES_OK(ctx,
    252                    LiteralToHostTensor(end_literal, index_type_, &end_tensor));
    253     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
    254                                             &strides_tensor));
    255 
    256     TensorShape lhs_shape;
    257     xla::ComputationDataHandle lhs;
    258     OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
    259 
    260     const TensorShape rhs_shape = ctx->InputShape(4);
    261 
    262     TensorShape dummy_processing_shape;
    263     bool dummy = false;
    264     OP_REQUIRES_OK(ctx,
    265                    ValidateStridedSliceOp(
    266                        &begin_tensor, &end_tensor, strides_tensor, lhs_shape,
    267                        begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_,
    268                        shrink_axis_mask_, &dummy_processing_shape, &final_shape,
    269                        &dummy, &dummy, &dummy, &begin, &end, &strides));
    270 
    271     if (final_shape.num_elements() == 0 && rhs_shape.num_elements() == 0) {
    272       // DynamicUpdateSlice does not allow 0-element updates. We should probably
    273       // check that rhs_shape can be broadcast to final_shape, but that is
    274       // probably better handled when implementing broadcasting more generally.
    275       return;
    276     }
    277 
    278     // TODO(aselle): This check is too strong, we only should need
    279     // input_shape to be broadcastable to final_shape
    280     OP_REQUIRES(ctx, final_shape == rhs_shape,
    281                 errors::Unimplemented(
    282                     "sliced l-value shape ", final_shape.DebugString(),
    283                     " does not match r-value shape ", rhs_shape.DebugString(),
    284                     ". Automatic broadcasting not yet implemented."));
    285 
    286     xla::ComputationDataHandle rhs = ctx->Input(4);
    287 
    288     gtl::InlinedVector<int64, 4> dimensions_to_reverse;
    289     gtl::InlinedVector<int64, 4> slice_begin, slice_dims;
    290     for (int i = 0; i < begin.size(); ++i) {
    291       // TODO(phawkins): implement strides != 1
    292       OP_REQUIRES(
    293           ctx, strides[i] == 1 || strides[i] == -1,
    294           errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
    295       if (strides[i] > 0) {
    296         slice_begin.push_back(begin[i]);
    297         slice_dims.push_back(end[i] - begin[i]);
    298       } else {
    299         // Negative stride: swap begin and end, add 1 because the interval
    300         // is semi-open, and mark the dimension to be reversed.
    301         slice_begin.push_back(end[i] + 1);
    302         slice_dims.push_back(begin[i] - end[i]);
    303         dimensions_to_reverse.push_back(i);
    304       }
    305     }
    306 
    307     if (!dimensions_to_reverse.empty()) {
    308       rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
    309     }
    310     rhs = ctx->builder()->Reshape(rhs, slice_dims);
    311 
    312     if (lhs_shape.dims() == 0) {
    313       // TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
    314       // and remove this workaround.
    315       lhs = rhs;
    316     } else {
    317       lhs = ctx->builder()->DynamicUpdateSlice(
    318           lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
    319     }
    320 
    321     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
    322   }
    323 
    324  private:
    325   int32 begin_mask_, end_mask_;
    326   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
    327   DataType index_type_;
    328   DataType dtype_;
    329 };
    330 
    331 REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
    332                     .CompileTimeConstInput("begin")
    333                     .CompileTimeConstInput("end")
    334                     .CompileTimeConstInput("strides"),
    335                 StridedSliceAssignOp);
    336 
    337 }  // namespace
    338 }  // namespace tensorflow
    339