/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

#include "tensorflow/core/kernels/fractional_pool_common.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/guarded_philox_random.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename T>
class FractionalMaxPoolOp : public OpKernel {
 public:
  explicit FractionalMaxPoolOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
    OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));

    OP_REQUIRES(context, pooling_ratio_.size() == 4,
                errors::InvalidArgument("pooling_ratio field must "
                                        "specify 4 dimensions"));

    // Pooling over batch or channels is unimplemented, so both of those
    // ratios must be 1 (the original `||` check let one of them slip
    // through).
    OP_REQUIRES(
        context, pooling_ratio_[0] == 1 && pooling_ratio_[3] == 1,
        errors::Unimplemented("Fractional max pooling is not yet "
                              "supported on the batch or channel dimension."));

    OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
    OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
    OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
    if (deterministic_) {
      // When deterministic_ is true but neither seed was set, pick random
      // seeds once at construction so that every invocation of this op
      // instance reuses the same pooling sequence.
      if ((seed_ == 0) && (seed2_ == 0)) {
        seed_ = random::New64();
        seed2_ = random::New64();
      }
    } else {
      OP_REQUIRES(
          context, (seed_ == 0) && (seed2_ == 0),
          errors::InvalidArgument(
              "Both seed and seed2 should be 0 if deterministic is false."));
    }
  }

  void Compute(OpKernelContext* context) override {
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;

    constexpr int tensor_in_and_out_dims = 4;

    const Tensor& tensor_in = context->input(0);
    OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
                errors::InvalidArgument("tensor_in must be 4-dimensional"));

    std::vector<int> input_size(tensor_in_and_out_dims);
    std::vector<int> output_size(tensor_in_and_out_dims);
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      input_size[i] = tensor_in.dim_size(i);
    }
    // Output size.
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      // This must match the same logic in the shape function in
      // core/ops/nn_ops.cc.
      output_size[i] =
          static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
      DCHECK_GT(output_size[i], 0);
    }
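    // For instance, an input height of 9 with pooling_ratio 1.5 yields an
    // output height of floor(9 / 1.5) = 6, and a ratio of 1.0 leaves the
    // dimension unchanged (illustrative numbers only).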

    // Generate pooling sequence.
    std::vector<int64> height_cum_seq;
    std::vector<int64> width_cum_seq;
    GuardedPhiloxRandom generator;
    generator.Init(seed_, seed2_);
    height_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
                                             &generator, pseudo_random_);
    width_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
                                            &generator, pseudo_random_);
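    // Each cumulative sequence is expected to hold output_size + 1 entries,
    // starting at 0 and ending at input_size; e.g. for input height 9 and
    // output height 6, one possible sequence is {0, 2, 3, 5, 6, 8, 9}
    // (illustrative only), where consecutive entries delimit one pooling
    // region.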

    // Prepare output.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0,
                                TensorShape({output_size[0], output_size[1],
                                             output_size[2], output_size[3]}),
                                &output_tensor));
    Tensor* output_height_seq_tensor = nullptr;
    OP_REQUIRES_OK(
        context,
        context->allocate_output(
            1, TensorShape({static_cast<int64>(height_cum_seq.size())}),
            &output_height_seq_tensor));
    Tensor* output_width_seq_tensor = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(
                     2, TensorShape({static_cast<int64>(width_cum_seq.size())}),
                     &output_width_seq_tensor));

    ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
                               input_size[2] * input_size[1] * input_size[0]);
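    // in_mat views the NHWC input as a (channels x batch*height*width)
    // matrix: each column holds every channel value of one spatial position,
    // so a pooling step can reduce all channels with a single column op.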

    EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
                           output_size[2] * output_size[1] * output_size[0]);

    // Initializes the output tensor with MIN<T>.
    output_tensor->flat<T>().setConstant(Eigen::NumTraits<T>::lowest());

    auto output_height_seq_flat = output_height_seq_tensor->flat<int64>();
    auto output_width_seq_flat = output_width_seq_tensor->flat<int64>();

    // Set output tensors.
    for (int i = 0; i < height_cum_seq.size(); ++i) {
      output_height_seq_flat(i) = height_cum_seq[i];
    }

    for (int i = 0; i < width_cum_seq.size(); ++i) {
      output_width_seq_flat(i) = width_cum_seq[i];
    }

    // For both input and output,
    // 0: batch
    // 1: height / row
    // 2: width / col
    // 3: depth / channel
    const int64 height_max = input_size[1] - 1;
    const int64 width_max = input_size[2] - 1;
    for (int64 b = 0; b < input_size[0]; ++b) {
      // height sequence.
      for (int64 hs = 0; hs < height_cum_seq.size() - 1; ++hs) {
        // height start and end.
        const int64 height_start = height_cum_seq[hs];
        int64 height_end =
            overlapping_ ? height_cum_seq[hs + 1] : height_cum_seq[hs + 1] - 1;
        height_end = std::min(height_end, height_max);
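        // Example: if the cumulative sequence contains {..., 2, 4, ...}, the
        // non-overlapping region covers rows [2, 3] while the overlapping one
        // covers rows [2, 4], sharing row 4 with the next region.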

        // width sequence.
        for (int64 ws = 0; ws < width_cum_seq.size() - 1; ++ws) {
          const int64 out_offset =
              (b * output_size[1] + hs) * output_size[2] + ws;
          // width start and end.
          const int64 width_start = width_cum_seq[ws];
          int64 width_end =
              overlapping_ ? width_cum_seq[ws + 1] : width_cum_seq[ws + 1] - 1;
          width_end = std::min(width_end, width_max);
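          // Reduce the pooling region: each cwiseMax below updates all
          // channels of one output cell with the columnwise max over the
          // region.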
          for (int64 h = height_start; h <= height_end; ++h) {
            for (int64 w = width_start; w <= width_end; ++w) {
              const int64 in_offset =
                  (b * input_size[1] + h) * input_size[2] + w;
              out_mat.col(out_offset) =
                  out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset));
            }
          }
        }
      }
    }
  }

 private:
  bool deterministic_;
  int64 seed_;
  int64 seed2_;
  std::vector<float> pooling_ratio_;
  bool pseudo_random_;
  bool overlapping_;
};

#define REGISTER_FRACTIONALMAXPOOL(type)                                      \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("FractionalMaxPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      FractionalMaxPoolOp<type>)

REGISTER_FRACTIONALMAXPOOL(int32);
REGISTER_FRACTIONALMAXPOOL(int64);
REGISTER_FRACTIONALMAXPOOL(float);
REGISTER_FRACTIONALMAXPOOL(double);

#undef REGISTER_FRACTIONALMAXPOOL

static const int kInvalidMaxPoolingIndex = -1;

template <class T>
class FractionalMaxPoolGradOp : public OpKernel {
 public:
  explicit FractionalMaxPoolGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
  }

  void Compute(OpKernelContext* context) override {
    // Computing the gradient of FractionalMaxPool takes two steps.
    // 1) Replay the forward fractional pooling over the given pooling
    //    regions, recording for each output element where its max element
    //    came from (arg_max).
    // 2) Scatter each out_backprop value to the location arg_max indicates.
    //    With overlapping pooling, multiple out_backprop[i] values may
    //    propagate back to the same arg_max location.
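    // For example, with overlapping regions two adjacent output cells can
    // both pick the same input element as their max, and that element then
    // accumulates both incoming gradients in step 2.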
    typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        ConstEigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
        EigenMatrixMap;
    typedef Eigen::Map<Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic>>
        EigenIndexMatrixMap;

    const Tensor& tensor_in = context->input(0);
    const Tensor& tensor_out = context->input(1);
    const Tensor& out_backprop = context->input(2);
    const Tensor& height_seq_tensor = context->input(3);
    const Tensor& width_seq_tensor = context->input(4);

    // Just to make it similar to FractionalMaxPoolOp.
    constexpr int tensor_in_and_out_dims = 4;
    std::vector<int64> input_size(tensor_in_and_out_dims);
    std::vector<int64> output_size(tensor_in_and_out_dims);
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      input_size[i] = tensor_in.dim_size(i);
    }
    for (int i = 0; i < tensor_in_and_out_dims; ++i) {
      output_size[i] = tensor_out.dim_size(i);
    }

    // ---------
    // Step 1
    // ---------
    Tensor tensor_out_dup;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
                                {1}, DataTypeToEnum<T>::v(), tensor_out.shape(),
                                &tensor_out_dup));
    Tensor tensor_out_arg_max;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<int64>::v(),
                                                   tensor_out.shape(),
                                                   &tensor_out_arg_max));
    // Find arg_max for each element of tensor_out.
    ConstEigenMatrixMap tensor_in_mat(
        tensor_in.flat<T>().data(), input_size[3],
        input_size[2] * input_size[1] * input_size[0]);
    EigenMatrixMap tensor_out_dup_mat(
        tensor_out_dup.flat<T>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);
    EigenIndexMatrixMap tensor_out_arg_max_mat(
        tensor_out_arg_max.flat<int64>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);

    tensor_out_arg_max.flat<int64>().setConstant(kInvalidMaxPoolingIndex);
    // Initializes the duplicate output tensor with MIN<T>.
    tensor_out_dup.flat<T>().setConstant(Eigen::NumTraits<T>::lowest());
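    // Replaying the pooling from these initial values reproduces tensor_out
    // while recording, for every output cell, the flat index of the input
    // element that produced its max.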

    auto height_seq_tensor_flat = height_seq_tensor.flat<int64>();
    auto width_seq_tensor_flat = width_seq_tensor.flat<int64>();

    // Now walk through the process of fractional max pooling again.
    // For both input and output,
    // 0: batch
    // 1: height / row
    // 2: width / col
    // 3: depth / channel
    const int64 height_max = input_size[1] - 1;
    const int64 width_max = input_size[2] - 1;
    for (int64 b = 0; b < input_size[0]; ++b) {
      // height sequence.
      for (int64 hs = 0; hs < height_seq_tensor.dim_size(0) - 1; ++hs) {
        // height start and end.
        const int64 height_start = height_seq_tensor_flat(hs);
        int64 height_end = overlapping_ ? height_seq_tensor_flat(hs + 1)
                                        : height_seq_tensor_flat(hs + 1) - 1;
        height_end = std::min(height_end, height_max);

        // width sequence.
        for (int64 ws = 0; ws < width_seq_tensor.dim_size(0) - 1; ++ws) {
          const int64 out_index =
              (b * output_size[1] + hs) * output_size[2] + ws;
          // width start and end.
          const int64 width_start = width_seq_tensor_flat(ws);
          int64 width_end = overlapping_ ? width_seq_tensor_flat(ws + 1)
                                         : width_seq_tensor_flat(ws + 1) - 1;
          width_end = std::min(width_end, width_max);
          for (int64 h = height_start; h <= height_end; ++h) {
            for (int64 w = width_start; w <= width_end; ++w) {
              const int64 in_index =
                  (b * input_size[1] + h) * input_size[2] + w;
              // Walk through each channel (depth).
              for (int64 d = 0; d < input_size[3]; ++d) {
                const T& input_ref = tensor_in_mat.coeffRef(d, in_index);
                T& output_ref = tensor_out_dup_mat.coeffRef(d, out_index);
                int64& out_arg_max_ref =
                    tensor_out_arg_max_mat.coeffRef(d, out_index);
                if (output_ref < input_ref ||
                    out_arg_max_ref == kInvalidMaxPoolingIndex) {
                  output_ref = input_ref;
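                  // Record the flattened NHWC index of the new max:
                  // ((b * height + h) * width + w) * channels + d.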
                  int input_offset = in_index * input_size[3] + d;
                  out_arg_max_ref = input_offset;
                }
              }
            }
          }
        }
      }
    }

    // Verify that the replayed tensor_out_dup matches the tensor_out
    // produced by the forward pass.
    ConstEigenMatrixMap tensor_out_mat(
        tensor_out.flat<T>().data(), output_size[3],
        output_size[2] * output_size[1] * output_size[0]);
    const int64 num_reshaped_cols =
        output_size[2] * output_size[1] * output_size[0];
    for (int64 i = 0; i < num_reshaped_cols; ++i) {
      for (int64 j = 0; j < output_size[3]; ++j) {
        DCHECK_EQ(tensor_out_dup_mat(j, i), tensor_out_mat(j, i));
      }
    }
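    // The equality above implies the arg_max indices describe the actual
    // forward pass; had the replay diverged, the scatter below would send
    // gradients to the wrong input elements.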

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, tensor_in.shape(), &output));
    output->flat<T>().setZero();

    auto out_backprop_flat = out_backprop.flat<T>();
    auto input_backprop_flat = output->flat<T>();
    auto out_arg_max_flat = tensor_out_arg_max.flat<int64>();
    int num_total_outputs = out_backprop_flat.size();
    int num_total_inputs = input_backprop_flat.size();

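    // Step 2: scatter each out_backprop value to the input position recorded
    // in arg_max. With overlapping pooling, several output indices may map to
    // the same input index, so gradients are accumulated with +=.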
    for (int index = 0; index < num_total_outputs; ++index) {
      int input_backprop_index = out_arg_max_flat(index);
      // As in maxpooling_op.cc, the performance impact of this bounds check
      // is small.
      CHECK(input_backprop_index >= 0 &&
            input_backprop_index < num_total_inputs)
          << "Invalid input backprop index: " << input_backprop_index << ", "
          << num_total_inputs;
      input_backprop_flat(input_backprop_index) += out_backprop_flat(index);
    }
  }

 private:
  bool overlapping_;
};

#define REGISTER_FRACTIONALMAXPOOLGRAD(type)              \
  REGISTER_KERNEL_BUILDER(Name("FractionalMaxPoolGrad")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          FractionalMaxPoolGradOp<type>)

REGISTER_FRACTIONALMAXPOOLGRAD(int32);
REGISTER_FRACTIONALMAXPOOLGRAD(int64);
REGISTER_FRACTIONALMAXPOOLGRAD(float);
REGISTER_FRACTIONALMAXPOOLGRAD(double);

#undef REGISTER_FRACTIONALMAXPOOLGRAD
}  // namespace tensorflow