Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/tf2xla/type_util.h"
     17 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
     18 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
     19 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
     20 #include "tensorflow/core/util/tensor_format.h"
     21 
     22 namespace tensorflow {
     23 
     24 namespace {
     25 
     26 class ExtractImagePatchesOp : public XlaOpKernel {
     27  public:
     28   explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
     29     OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
     30     OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
     31     OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
     32     OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
     33   }
     34 
     35   void Compile(XlaOpKernelContext* ctx) override {
     36     const TensorFormat data_format = FORMAT_NHWC;
     37     const int num_dims = ksizes_.size();
     38 
     39     OP_REQUIRES(
     40         ctx, num_dims >= 3,
     41         errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
     42     const int num_spatial_dims = num_dims - 2;
     43 
     44     OP_REQUIRES(ctx, strides_.size() == num_dims,
     45                 errors::InvalidArgument("Sliding window strides field must "
     46                                         "specify ",
     47                                         num_dims, " dimensions"));
     48     OP_REQUIRES(ctx, dilations_.size() == num_dims,
     49                 errors::InvalidArgument("Dilations field must "
     50                                         "specify ",
     51                                         num_dims, " dimensions"));
     52 
     53     int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
     54     int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
     55     OP_REQUIRES(
     56         ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
     57         errors::Unimplemented("Current implementation does not yet support "
     58                               "kernel sizes > 1 in the batch and depth "
     59                               "dimensions."));
     60     OP_REQUIRES(
     61         ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
     62         errors::Unimplemented("Current implementation does not yet support "
     63                               "strides in the batch and depth dimensions."));
     64     OP_REQUIRES(
     65         ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
     66         errors::Unimplemented("Current implementation does not support "
     67                               "dilations in the batch and depth dimensions."));
     68 
     69     for (int i = 0; i < num_spatial_dims; ++i) {
     70       int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
     71       OP_REQUIRES(
     72           ctx, ksizes_[input_dim] >= 0,
     73           errors::Unimplemented("Kernel size values must be non-negative; ", i,
     74                                 "th spatial dimension had dilation ",
     75                                 dilations_[input_dim]));
     76       OP_REQUIRES(ctx, strides_[input_dim] >= 1,
     77                   errors::Unimplemented("Stride values must be positive; ", i,
     78                                         "th spatial dimension had dilation ",
     79                                         dilations_[input_dim]));
     80       OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
     81                   errors::Unimplemented("Dilation values must be positive; ", i,
     82                                         "th spatial dimension had dilation ",
     83                                         dilations_[input_dim]));
     84     }
     85 
     86     xla::PrimitiveType type;
     87     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
     88 
     89     const TensorShape input_shape = ctx->InputShape(0);
     90     OP_REQUIRES(
     91         ctx, input_shape.dims() == num_dims,
     92         errors::InvalidArgument("input must be ", num_dims, "-dimensional",
     93                                 input_shape.DebugString()));
     94     const int64 depth = input_shape.dim_size(feature_dim);
     95 
     96     xla::ComputationBuilder* builder = ctx->builder();
     97 
     98     // The following code is equivalent to:
     99     // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
    100     int64 kernel_size = 1;
    101     std::vector<int64> lhs_shape(num_dims, 1);
    102     for (int i = 0; i < num_spatial_dims; ++i) {
    103       int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
    104       lhs_shape[i] = ksizes_[input_dim];
    105       kernel_size *= ksizes_[input_dim];
    106     }
    107     lhs_shape[num_spatial_dims] = depth;
    108     lhs_shape[num_spatial_dims + 1] = 1;
    109 
    110     // Builds an identity matrix as a broadcast equality of iotas.
    111     // iota = np.arange(np.prod(ksize), depth)
    112     // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
    113     xla::ComputationDataHandle iota;
    114     TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
    115                                  kernel_size * depth, &iota));
    116 
    117     auto lhs = builder->Reshape(iota, lhs_shape);
    118     auto filter = builder->ConvertElementType(
    119         builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
    120 
    121     xla::ConvolutionDimensionNumbers dims;
    122     std::vector<int64> window_strides(num_spatial_dims);
    123     std::vector<int64> lhs_dilation(num_spatial_dims, 1);
    124     std::vector<int64> rhs_dilation(num_spatial_dims);
    125     std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
    126 
    127     dims.set_input_batch_dimension(batch_dim);
    128     dims.set_output_batch_dimension(batch_dim);
    129     dims.set_input_feature_dimension(feature_dim);
    130     dims.set_output_feature_dimension(feature_dim);
    131     dims.set_kernel_input_feature_dimension(num_spatial_dims);
    132     dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
    133 
    134     for (int i = 0; i < num_spatial_dims; ++i) {
    135       const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
    136       dims.add_input_spatial_dimensions(dim);
    137       dims.add_kernel_spatial_dimensions(i);
    138       dims.add_output_spatial_dimensions(dim);
    139       window_strides[i] = strides_.at(dim);
    140       rhs_dilation[i] = dilations_.at(dim);
    141 
    142       int64 unused_output_size;
    143       OP_REQUIRES_OK(
    144           ctx, GetWindowedOutputSizeVerboseV2(
    145                    input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
    146                    window_strides[i], padding_, &unused_output_size,
    147                    &padding[i].first, &padding[i].second));
    148     }
    149 
    150     xla::ComputationDataHandle conv =
    151         builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
    152                                     padding, lhs_dilation, rhs_dilation, dims);
    153     ctx->SetOutput(0, conv);
    154   }
    155 
    156  protected:
    157   std::vector<int32> ksizes_;
    158   std::vector<int32> dilations_;
    159   std::vector<int32> strides_;
    160   Padding padding_;
    161 
    162  private:
    163   TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
    164 };
    165 
    // Registers ExtractImagePatchesOp under the op name "ExtractImagePatches".
    166 REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
    167 
    168 }  // namespace
    169 }  // namespace tensorflow
    170