/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <type_traits>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

// In depthwise convolution, each input channel is convolved into
// depth_multiplier output channels, and the outputs are not reduced (summed)
// across input channels as they are in regular convolution.
// However, filters are applied to inputs in exactly the same way as in
// regular convolution. Please refer to the regular convolution kernels for
// more details.
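//
// Illustrative shape example (assuming NHWC layout; not tied to any
// particular kernel below): an input of shape
// [batch, in_rows, in_cols, in_depth] and a filter of shape
// [filter_rows, filter_cols, in_depth, depth_multiplier] produce an output of
// shape [batch, out_rows, out_cols, in_depth * depth_multiplier], e.g.
// in_depth = 3 and depth_multiplier = 2 give out_depth = 6.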

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Computes the vectorized product of 'input_buffer' and 'filter' and stores
// result in 'output' at location specified by 'out_r' and 'out_c'.
//
// EX:
//   in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4
//   Both 'input_buffer' and 'filter' are padded to register-width boundaries.
//
//   input_buffer [rows, cols, in_depth, depth_multiplier]
//     [a0, a0, a1, a1] [a2, a2, 0, 0] [b0, b0, b1, b1] [b2, b2, 0, 0]
//     [e0, e0, e1, e1] [e2, e2, 0, 0] [f0, f0, f1, f1] [f2, f2, 0, 0]
//
//   filter [rows, cols, in_depth, depth_multiplier]
//     [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0]
//     [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0]
//
//   First output register [in_depth, depth_multiplier]
//     [q0, q1, q2, q3] = ([a0, a0, a1, a1] x [u0, v0, w0, x0]) +
//                        ([b0, b0, b1, b1] x [u1, v1, w1, x1]) +
//                        ([e0, e0, e1, e1] x [u2, v2, w2, x2]) +
//                        ([f0, f0, f1, f1] x [u3, v3, w3, x3])
//
// TODO(andydavis) Experiment with processing multiple inputs per input buffer.
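//
// Illustrative sizing sketch (assumes the 4-wide registers of the example
// above): kPacketSize = sizeof(Packet) / sizeof(T) = 4 and
// out_depth = in_depth * depth_multiplier = 3 * 2 = 6, so
//   output_vectorized_size = (6 / 4) * 4 = 4,
//   output_scalar_size     = 6 % 4 = 2,
//   padded_filter_inner_dim_size = ((6 + 4 - 1) / 4) * 4 = 8.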
template <typename T>
struct DepthwiseConv2DKernel {
  static void Run(const DepthwiseArgs& args,
                  const int64 padded_filter_inner_dim_size, const int64 out_r,
                  const int64 out_c, const T* filter, const T* input_buffer,
                  T* output, TensorFormat data_format) {
    typedef typename Eigen::internal::packet_traits<T>::type Packet;
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    const int64 out_depth = args.out_depth;
    const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
    const int64 output_scalar_size = out_depth % kPacketSize;
    const int64 output_vectorized_size =
        (out_depth / kPacketSize) * kPacketSize;
    const int64 base_output_index = (out_r * args.out_cols + out_c) * out_depth;

    for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
      // Reset accumulator.
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        // Calculate index.
        const int64 index = i + j * padded_filter_inner_dim_size;
        // Load filter.
        // TODO(andydavis) Unroll 'out_c' loop in caller so we can load
        // multiple inputs here to amortize the cost of each filter block load.
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        // Load input.
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        // Vector multiply-add.
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Store vector accumulator to output.
      Eigen::internal::pstoreu<T>(output + base_output_index + i, vaccum);
    }

    if (output_scalar_size > 0) {
      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
      for (int j = 0; j < filter_spatial_size; ++j) {
        const int64 index =
            output_vectorized_size + j * padded_filter_inner_dim_size;
        const auto filter_block =
            Eigen::internal::ploadu<Packet>(filter + index);
        const auto data_block =
            Eigen::internal::ploadu<Packet>(input_buffer + index);
        vaccum =
            Eigen::internal::pmadd<Packet>(filter_block, data_block, vaccum);
      }
      // Load accumulator into an array and loop through output.
      T out_buf[kPacketSize];
      Eigen::internal::pstoreu<T>(out_buf, vaccum);
      const int64 last_output_index =
          base_output_index + output_vectorized_size;
      for (int j = 0; j < output_scalar_size; ++j) {
        output[last_output_index + j] = out_buf[j];
      }
    }
  }
};

// Computes the depthwise conv2d of 'input' by 'depthwise_filter' and stores
// the result in 'output'. This implementation trades the cost of copying small
// patches of the input for better data alignment, which enables vectorized
// load/store and multiply-add operations (see comments at InputBufferCopyOp and
// DepthwiseConv2DKernel for details).
//
// TODO(andydavis) Evaluate the performance of processing multiple input
// patches in the inner loop.
// TODO(andydavis) Consider a zero-copy implementation for the case when
// 'in_depth' is a multiple of register width, and 'depth_multiplier' is one.
// TODO(andydavis) Evaluate the performance of alternative implementations.
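//
// Buffer-layout sketch (based on the shapes allocated below): the per-shard
// 'input_buffer' holds one input patch in
//   [filter_rows * filter_cols, padded_filter_inner_dim_size]
// layout, with each input value replicated 'depth_multiplier' times along the
// inner dimension (as in the example above) so that it lines up
// element-for-element with the padded filter, letting DepthwiseConv2DKernel
// multiply-accumulate whole registers at a time.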
template <typename T>
struct LaunchDepthwiseConvOp<CPUDevice, T> {
  typedef typename Eigen::internal::packet_traits<T>::type Packet;

  void operator()(OpKernelContext* ctx, const DepthwiseArgs& args,
                  const T* input, const T* depthwise_filter, T* output,
                  TensorFormat data_format) {
    OP_REQUIRES(
        ctx, data_format == FORMAT_NHWC,
        errors::Unimplemented(
            "Depthwise convolution on CPU is only supported for NHWC format"));
    static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));

    // Pad 'depthwise_filter' to vector register width (if needed).
    const bool pad_filter = (args.out_depth % kPacketSize) != 0;
    Tensor padded_filter;
    if (pad_filter) {
      // Allocate space for padded filter.
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &padded_filter));
      // Write out padded filter.
      functor::DepthwiseFilterPadOp<T>()(
          args, depthwise_filter, padded_filter.template flat<T>().data());
    }
    const T* filter_data =
        pad_filter ? padded_filter.template flat<T>().data() : depthwise_filter;

    // Computes one shard of depthwise conv2d output.
    auto shard = [&ctx, &args, &input, &filter_data, &output, data_format](
                     int64 start, int64 limit) {
      static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
      const int64 input_image_size =
          args.in_rows * args.in_cols * args.in_depth;
      const int64 output_image_size =
          args.out_rows * args.out_cols * args.out_depth;
      const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
      const int64 padded_filter_inner_dim_size =
          ((args.out_depth + kPacketSize - 1) / kPacketSize) * kPacketSize;

      // Allocate buffer for local input regions.
      Tensor input_buffer;
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  TensorShape({filter_spatial_size,
                                               padded_filter_inner_dim_size}),
                                  &input_buffer));
      T* input_buffer_data = input_buffer.template flat<T>().data();

      for (int64 i = start; i < limit; ++i) {
        const int64 b = i / args.out_rows;
        const int64 in_base = b * input_image_size;
        const int64 out_base = b * output_image_size;

        const int64 out_r = i % args.out_rows;

        for (int64 out_c = 0; out_c < args.out_cols; ++out_c) {
          // Populate 'input_buffer_data' with data from local input region.
          functor::DepthwiseInputCopyOp<T>()(args, padded_filter_inner_dim_size,
                                             out_r, out_c, input + in_base,
                                             input_buffer_data);

          // Process buffered input across all filters and store to output.
          DepthwiseConv2DKernel<T>::Run(
              args, padded_filter_inner_dim_size, out_r, out_c, filter_data,
              input_buffer_data, output + out_base, data_format);
        }
      }
    };

    const int64 total_shards = args.batch * args.out_rows;

    // Empirically tested to give reasonable performance boosts at batch size 1
    // without reducing throughput at batch size 32.
    const float kCostMultiplier = 2.5f;

    // TODO(andydavis): Estimate shard cost (in cycles) based on the number of
    // flops/loads/stores required to compute one shard.
    const int64 shard_cost = kCostMultiplier * args.out_cols * args.out_depth;
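    // Illustrative cost example (numbers assumed, not taken from this file):
    // with out_cols = 64 and out_depth = 32, each shard (one output row of one
    // image) is estimated at 2.5 * 64 * 32 = 5120 cost units, which the
    // sharder below uses to decide how many rows to hand to each worker.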

    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, total_shards,
          shard_cost, shard);
  }
};

// Extern template instantiated in conv_ops.cc.
extern template class LaunchConv2DOp<CPUDevice, float>;

#if GOOGLE_CUDA

// Extern template instantiated in depthwise_conv_op_gpu.cc.
extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;

// Extern template instantiated in conv_ops.cc.
extern template class LaunchConv2DOp<GPUDevice, float>;

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class DepthwiseConv2dNativeOp : public BinaryOp<T> {
 public:
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* context)
      : BinaryOp<T>(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    stride_ = GetTensorDim(strides_, data_format_, 'H');
    const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
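    // Example (illustrative values): with data_format_ == FORMAT_NHWC and
    // strides_ == {1, 2, 2, 1}, this yields stride_ == 2, stride_w == 2,
    // stride_n == 1, and stride_c == 1, which satisfies the checks below.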

    OP_REQUIRES(context, stride_ == stride_w,
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));

    // For the special case when in_depth == 1.
    use_cudnn_ = CanUseCudnn();
    cudnn_use_autotune_ = CudnnUseAutotune();
  }

  void Compute(OpKernelContext* context) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const Tensor& input = context->input(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
    const Tensor& filter = context->input(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    // in_depth for input and filter must match.
    const int64 in_depth = GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, in_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", in_depth,
                    " vs ", filter.dim_size(2)));

    // The last dimension for filter is depth multiplier.
    const int32 depth_multiplier = filter.dim_size(3);

    // The output depth is input depth x depth multiplier.
    const int32 out_depth = in_depth * depth_multiplier;

    const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int32 input_rows = static_cast<int32>(input_rows_raw);
    const int32 filter_rows = filter.dim_size(0);

    const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int32 input_cols = static_cast<int32>(input_cols_raw);
    const int32 filter_cols = filter.dim_size(1);

    // The first dimension for input is batch.
    const int32 batch = input.dim_size(0);

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_,
                                         padding_, &out_cols, &pad_cols));
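    // Output-size sketch (standard SAME/VALID formulas; numbers are
    // illustrative): with input_rows = 5, filter_rows = 3, stride_ = 2,
    //   VALID: out_rows = ceil((5 - 3 + 1) / 2) = 2, pad_rows = 0
    //   SAME:  out_rows = ceil(5 / 2) = 3,
    //          pad_rows = ((3 - 1) * 2 + 3 - 5) / 2 = 1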
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
    OP_REQUIRES(
        context,
        (!std::is_same<Device, GPUDevice>::value ||
         FastBoundsCheck(out_shape.num_elements(),
                         std::numeric_limits<int32>::max())),
        errors::InvalidArgument("Output elements too large for GPU kernel"));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));

    VLOG(2) << "DepthwiseConv2dNative: "
            << "Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; stride = " << stride_ << ", pad_rows = " << pad_rows
            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
            << out_rows << ", " << out_cols << ", " << out_depth << "]";

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      return;
    }

    // If in_depth == 1, this operation is just a standard convolution.
    // Invoke that op directly (the Conv2D launcher below is only instantiated
    // for float, hence the type check).
    if (std::is_same<T, float>::value && in_depth == 1) {
      // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
      // conv is supported.
      launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
                /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
                padding_, output, data_format_);
      return;
    }
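    // Why the shortcut above is valid: with in_depth == 1 the depthwise
    // filter has shape [filter_rows, filter_cols, 1, depth_multiplier], which
    // is exactly a regular convolution filter with depth_multiplier output
    // channels, so Conv2D produces the same result.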

    DepthwiseArgs args;
    args.batch = batch;
    args.in_rows = input_rows;
    args.in_cols = input_cols;
    args.in_depth = in_depth;
    args.filter_rows = filter_rows;
    args.filter_cols = filter_cols;
    args.depth_multiplier = depth_multiplier;
    args.stride = stride_;
    args.pad_rows = pad_rows;
    args.pad_cols = pad_cols;
    args.out_rows = out_rows;
    args.out_cols = out_cols;
    args.out_depth = out_depth;

    auto input_ptr = input.template flat<T>().data();
    auto filter_ptr = filter.template flat<T>().data();
    auto output_ptr = output->template flat<T>().data();
    LaunchDepthwiseConvOp<Device, T>()(context, args, input_ptr, filter_ptr,
                                       output_ptr, data_format_);
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  int64 stride_;  // in height/width dimension.

  // For the case in_depth == 1.
  LaunchConv2DOp<Device, T> launcher_;
  bool use_cudnn_;
  bool cudnn_use_autotune_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};

#define REGISTER_CPU_KERNEL(T)                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DepthwiseConv2dNativeOp<CPUDevice, T>);

TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
#endif

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T"),
                        DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);

REGISTER_KERNEL_BUILDER(
    Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    DepthwiseConv2dNativeOp<GPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("DepthwiseConv2dNative")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T"),
                        DepthwiseConv2dNativeOp<GPUDevice, double>);
#endif  // GOOGLE_CUDA

}  // namespace tensorflow