Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #define EIGEN_USE_THREADS
     16 
     17 #include <algorithm>
     18 #include <cmath>
     19 #include <random>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/kernels/fractional_pool_common.h"
     23 
     24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     25 #include "tensorflow/core/framework/numeric_op.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/lib/random/random.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 #include "tensorflow/core/platform/mutex.h"
     30 #include "tensorflow/core/util/guarded_philox_random.h"
     31 
     32 namespace tensorflow {
     33 typedef Eigen::ThreadPoolDevice CPUDevice;
     34 
     35 template <typename T>
     36 class FractionalAvgPoolOp : public OpKernel {
     37  public:
     38   explicit FractionalAvgPoolOp(OpKernelConstruction* context)
     39       : OpKernel(context) {
     40     OP_REQUIRES_OK(context, context->GetAttr("pooling_ratio", &pooling_ratio_));
     41     OP_REQUIRES_OK(context, context->GetAttr("pseudo_random", &pseudo_random_));
     42     OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
     43     OP_REQUIRES(context, pooling_ratio_.size() == 4,
     44                 errors::InvalidArgument(
     45                     "pooling_ratio field must specify 4 dimensions"));
     46     OP_REQUIRES(
     47         context, pooling_ratio_[0] == 1 || pooling_ratio_[3] == 1,
     48         errors::Unimplemented("Fractional average pooling is not yet "
     49                               "supported on the batch nor channel dimension."));
     50     OP_REQUIRES_OK(context, context->GetAttr("deterministic", &deterministic_));
     51     OP_REQUIRES_OK(context, context->GetAttr("seed", &seed_));
     52     OP_REQUIRES_OK(context, context->GetAttr("seed2", &seed2_));
     53     if (deterministic_) {
     54       // If both seeds are not set when deterministic_ is true, force set seeds.
     55       if ((seed_ == 0) && (seed2_ == 0)) {
     56         seed_ = random::New64();
     57         seed2_ = random::New64();
     58       }
     59     } else {
     60       OP_REQUIRES(
     61           context, (seed_ == 0) && (seed2_ == 0),
     62           errors::InvalidArgument(
     63               "Both seed and seed2 should be 0 if deterministic is false."));
     64     }
     65   }
     66 
     67   void Compute(OpKernelContext* context) override {
     68     typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
     69         ConstEigenMatrixMap;
     70     typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
     71         EigenMatrixMap;
     72 
     73     constexpr int tensor_in_and_out_dims = 4;
     74 
     75     const Tensor& tensor_in = context->input(0);
     76     OP_REQUIRES(context, tensor_in.dims() == tensor_in_and_out_dims,
     77                 errors::InvalidArgument("tensor_in must be 4-dimensional"));
     78 
     79     std::vector<int> input_size(tensor_in_and_out_dims);
     80     std::vector<int> output_size(tensor_in_and_out_dims);
     81     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
     82       input_size[i] = tensor_in.dim_size(i);
     83     }
     84     // Output size.
     85     for (int i = 0; i < tensor_in_and_out_dims; ++i) {
     86       output_size[i] =
     87           static_cast<int>(floor(input_size[i] / pooling_ratio_[i]));
     88       DCHECK_GT(output_size[i], 0);
     89     }
     90 
     91     // Generate pooling sequence.
     92     std::vector<int64> row_cum_seq;
     93     std::vector<int64> col_cum_seq;
     94     GuardedPhiloxRandom generator;
     95     generator.Init(seed_, seed2_);
     96     row_cum_seq = GeneratePoolingSequence(input_size[1], output_size[1],
     97                                           &generator, pseudo_random_);
     98     col_cum_seq = GeneratePoolingSequence(input_size[2], output_size[2],
     99                                           &generator, pseudo_random_);
    100 
    101     // Prepare output.
    102     Tensor* output_tensor = nullptr;
    103     OP_REQUIRES_OK(context, context->allocate_output(
    104                                 0,
    105                                 TensorShape({output_size[0], output_size[1],
    106                                              output_size[2], output_size[3]}),
    107                                 &output_tensor));
    108     Tensor* output_row_seq_tensor = nullptr;
    109     OP_REQUIRES_OK(context,
    110                    context->allocate_output(
    111                        1, TensorShape({static_cast<int64>(row_cum_seq.size())}),
    112                        &output_row_seq_tensor));
    113     Tensor* output_col_seq_tensor = nullptr;
    114     OP_REQUIRES_OK(context,
    115                    context->allocate_output(
    116                        2, TensorShape({static_cast<int64>(col_cum_seq.size())}),
    117                        &output_col_seq_tensor));
    118 
    119     ConstEigenMatrixMap in_mat(tensor_in.flat<T>().data(), input_size[3],
    120                                input_size[2] * input_size[1] * input_size[0]);
    121 
    122     EigenMatrixMap out_mat(output_tensor->flat<T>().data(), output_size[3],
    123                            output_size[2] * output_size[1] * output_size[0]);
    124     // out_count corresponds to number of elements in each pooling cell.
    125     Eigen::Matrix<T, Eigen::Dynamic, 1> out_count(out_mat.cols());
    126 
    127     // Initializes the output tensor and out_count with 0.
    128     out_mat.setZero();
    129     out_count.setZero();
    130 
    131     auto output_row_seq_flat = output_row_seq_tensor->flat<int64>();
    132     auto output_col_seq_flat = output_col_seq_tensor->flat<int64>();
    133 
    134     // Set output tensors.
    135     for (int i = 0; i < row_cum_seq.size(); ++i) {
    136       output_row_seq_flat(i) = row_cum_seq[i];
    137     }
    138 
    139     for (int i = 0; i < col_cum_seq.size(); ++i) {
    140       output_col_seq_flat(i) = col_cum_seq[i];
    141     }
    142 
    143     // For both input and output,
    144     // 0: batch
    145     // 1: row / row
    146     // 2: col / col
    147     // 3: depth / channel
    148     const int64 row_max = input_size[1] - 1;
    149     const int64 col_max = input_size[2] - 1;
    150     for (int64 b = 0; b < input_size[0]; ++b) {
    151       // row sequence.
    152       for (int64 hs = 0; hs < row_cum_seq.size() - 1; ++hs) {
    153         // row start and end.
    154         const int64 row_start = row_cum_seq[hs];
    155         int64 row_end =
    156             overlapping_ ? row_cum_seq[hs + 1] : row_cum_seq[hs + 1] - 1;
    157         row_end = std::min(row_end, row_max);
    158 
    159         // col sequence.
    160         for (int64 ws = 0; ws < col_cum_seq.size() - 1; ++ws) {
    161           const int64 out_offset =
    162               (b * output_size[1] + hs) * output_size[2] + ws;
    163           // col start and end.
    164           const int64 col_start = col_cum_seq[ws];
    165           int64 col_end =
    166               overlapping_ ? col_cum_seq[ws + 1] : col_cum_seq[ws + 1] - 1;
    167           col_end = std::min(col_end, col_max);
    168           for (int64 h = row_start; h <= row_end; ++h) {
    169             for (int64 w = col_start; w <= col_end; ++w) {
    170               const int64 in_offset =
    171                   (b * input_size[1] + h) * input_size[2] + w;
    172               out_mat.col(out_offset) += in_mat.col(in_offset);
    173               out_count(out_offset)++;
    174             }
    175           }
    176         }
    177       }
    178     }
    179     DCHECK_GT(out_count.minCoeff(), 0);
    180     out_mat.array().rowwise() /= out_count.transpose().array();
    181   }
    182 
    183  private:
    184   bool deterministic_;
    185   int64 seed_;
    186   int64 seed2_;
    187   std::vector<float> pooling_ratio_;
    188   bool pseudo_random_;
    189   bool overlapping_;
    190 };
    191 
    192 #define REGISTER_FRACTIONALAVGPOOL(type)                                      \
    193   REGISTER_KERNEL_BUILDER(                                                    \
    194       Name("FractionalAvgPool").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    195       FractionalAvgPoolOp<type>)
    196 
    197 REGISTER_FRACTIONALAVGPOOL(int32);
    198 REGISTER_FRACTIONALAVGPOOL(int64);
    199 REGISTER_FRACTIONALAVGPOOL(float);
    200 REGISTER_FRACTIONALAVGPOOL(double);
    201 
    202 #undef REGISTER_FRACTIONALAVGPOOL
    203 
    204 template <class T>
    205 class FractionalAvgPoolGradOp : public OpKernel {
    206  public:
    207   explicit FractionalAvgPoolGradOp(OpKernelConstruction* context)
    208       : OpKernel(context) {
    209     OP_REQUIRES_OK(context, context->GetAttr("overlapping", &overlapping_));
    210   }
    211 
    212   void Compute(OpKernelContext* context) override {
    213     // Here's the basic idea:
    214     // Batch and depth dimension are independent from row and col dimension. And
    215     // because FractionalAvgPool currently only support pooling along row and
    216     // col, we can basically think of this 4D tensor backpropagation as
    217     // operation of a series of 2D planes.
    218     //
    219     // For each element of a 'slice' (2D plane) of output_backprop, we need to
    220     // figure out its contributors when doing FractionalAvgPool operation. This
    221     // can be done based on row_pooling_sequence, col_pooling_seq and
    222     // overlapping.
    223     // Once we figure out the original contributors, we just need to evenly
    224     // divide the value of this element among these contributors.
    225     //
    226     // Internally, we divide the out_backprop tensor and store it in a temparary
    227     // tensor of double type. And cast it to the corresponding type.
    228     typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
    229         ConstEigenMatrixMap;
    230     typedef Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>>
    231         EigenDoubleMatrixMap;
    232 
    233     // Grab the inputs.
    234     const Tensor& orig_input_tensor_shape = context->input(0);
    235     OP_REQUIRES(context,
    236                 orig_input_tensor_shape.dims() == 1 &&
    237                     orig_input_tensor_shape.NumElements() == 4,
    238                 errors::InvalidArgument("original input tensor shape must be"
    239                                         "1-dimensional and 4 elements"));
    240     const Tensor& out_backprop = context->input(1);
    241     const Tensor& row_seq_tensor = context->input(2);
    242     const Tensor& col_seq_tensor = context->input(3);
    243 
    244     const int64 out_batch = out_backprop.dim_size(0);
    245     const int64 out_rows = out_backprop.dim_size(1);
    246     const int64 out_cols = out_backprop.dim_size(2);
    247     const int64 out_depth = out_backprop.dim_size(3);
    248 
    249     auto row_seq_tensor_flat = row_seq_tensor.flat<int64>();
    250     auto col_seq_tensor_flat = col_seq_tensor.flat<int64>();
    251     auto orig_input_tensor_shape_flat = orig_input_tensor_shape.flat<int64>();
    252 
    253     const int64 in_batch = orig_input_tensor_shape_flat(0);
    254     const int64 in_rows = orig_input_tensor_shape_flat(1);
    255     const int64 in_cols = orig_input_tensor_shape_flat(2);
    256     const int64 in_depth = orig_input_tensor_shape_flat(3);
    257 
    258     constexpr int tensor_in_and_out_dims = 4;
    259     // Transform orig_input_tensor_shape into TensorShape
    260     TensorShape in_shape;
    261     for (auto i = 0; i < tensor_in_and_out_dims; ++i) {
    262       in_shape.AddDim(orig_input_tensor_shape_flat(i));
    263     }
    264 
    265     // Create intermediate in_backprop.
    266     Tensor in_backprop_tensor_temp;
    267     OP_REQUIRES_OK(context, context->forward_input_or_allocate_temp(
    268                                 {0}, DataTypeToEnum<double>::v(), in_shape,
    269                                 &in_backprop_tensor_temp));
    270     in_backprop_tensor_temp.flat<double>().setZero();
    271     // Transform 4D tensor to 2D matrix.
    272     EigenDoubleMatrixMap in_backprop_tensor_temp_mat(
    273         in_backprop_tensor_temp.flat<double>().data(), in_depth,
    274         in_cols * in_rows * in_batch);
    275     ConstEigenMatrixMap out_backprop_mat(out_backprop.flat<T>().data(),
    276                                          out_depth,
    277                                          out_cols * out_rows * out_batch);
    278     // Loop through each element of out_backprop and evenly distribute the
    279     // element to the corresponding pooling cell.
    280     const int64 in_max_row_index = in_rows - 1;
    281     const int64 in_max_col_index = in_cols - 1;
    282     for (int64 b = 0; b < out_batch; ++b) {
    283       for (int64 r = 0; r < out_rows; ++r) {
    284         const int64 in_row_start = row_seq_tensor_flat(r);
    285         int64 in_row_end = overlapping_ ? row_seq_tensor_flat(r + 1)
    286                                         : row_seq_tensor_flat(r + 1) - 1;
    287         in_row_end = std::min(in_row_end, in_max_row_index);
    288         for (int64 c = 0; c < out_cols; ++c) {
    289           const int64 in_col_start = col_seq_tensor_flat(c);
    290           int64 in_col_end = overlapping_ ? col_seq_tensor_flat(c + 1)
    291                                           : col_seq_tensor_flat(c + 1) - 1;
    292           in_col_end = std::min(in_col_end, in_max_col_index);
    293 
    294           const int64 num_elements_in_pooling_cell =
    295               (in_row_end - in_row_start + 1) * (in_col_end - in_col_start + 1);
    296           const int64 out_index = (b * out_rows + r) * out_cols + c;
    297           // Now we can evenly distribute out_backprop(b, h, w, *) to
    298           // in_backprop(b, hs:he, ws:we, *).
    299           for (int64 in_r = in_row_start; in_r <= in_row_end; ++in_r) {
    300             for (int64 in_c = in_col_start; in_c <= in_col_end; ++in_c) {
    301               const int64 in_index = (b * in_rows + in_r) * in_cols + in_c;
    302               // Walk through each channel (depth).
    303               for (int64 d = 0; d < out_depth; ++d) {
    304                 const double out_backprop_element = static_cast<double>(
    305                     out_backprop_mat.coeffRef(d, out_index));
    306                 double& in_backprop_ref =
    307                     in_backprop_tensor_temp_mat.coeffRef(d, in_index);
    308                 in_backprop_ref +=
    309                     out_backprop_element / num_elements_in_pooling_cell;
    310               }
    311             }
    312           }
    313         }
    314       }
    315     }
    316 
    317     // Depending on the type, cast double to type T.
    318     Tensor* in_backprop_tensor = nullptr;
    319     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
    320                                 {0}, 0, in_shape, &in_backprop_tensor));
    321     auto in_backprop_tensor_flat = in_backprop_tensor->flat<T>();
    322     auto in_backprop_tensor_temp_flat = in_backprop_tensor_temp.flat<double>();
    323     for (int64 i = 0; i < in_backprop_tensor_flat.size(); ++i) {
    324       in_backprop_tensor_flat(i) =
    325           static_cast<T>(in_backprop_tensor_temp_flat(i));
    326     }
    327   }
    328 
    329  private:
    330   bool overlapping_;
    331 };
    332 
    333 #define REGISTER_FRACTIONALAVGPOOLGRAD(type)              \
    334   REGISTER_KERNEL_BUILDER(Name("FractionalAvgPoolGrad")   \
    335                               .Device(DEVICE_CPU)         \
    336                               .TypeConstraint<type>("T"), \
    337                           FractionalAvgPoolGradOp<type>)
    338 
    339 REGISTER_FRACTIONALAVGPOOLGRAD(int32);
    340 REGISTER_FRACTIONALAVGPOOLGRAD(int64);
    341 REGISTER_FRACTIONALAVGPOOLGRAD(float);
    342 REGISTER_FRACTIONALAVGPOOLGRAD(double);
    343 
    344 #undef REGISTER_FRACTIONALAVGPOOLGRAD
    345 }  // namespace tensorflow
    346