Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
     17 #define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/core/kernels/eigen_volume_patch.h"
     21 
     22 namespace Eigen {
     23 
     24 /** CuboidConvolutionBackwardInput
     25  * \ingroup CXX11_NeuralNetworks_Module
     26  *
     27  * \brief Computes the backprop for the input of a 3D convolution.
     28  *
     29  * The output_backward parameter is expected to be a tensor with a rank of 4 or
     30  * more (channels, depth, height, width, and optionally others)
     31  * The kernel parameter is expected to be a 5D tensor (filters, channels,
     32  * kernel_depth, kernel_height, kernel_width)
     33  * output_backward and kernel have to be in the same layout.
     34  *
     35  * The dimensions of the result will be filters, depth, height, width (and
     36  * others if applicable).
     37  *
     38  * It is possible to swap the order of the depth, width and height dimensions
     39  * provided that the same order is used in the input, the kernel, and the
     40  * output.
     41  *
     42  * All dimension orders above are given for col-major, and should be reversed
     43  * for row-major.
     44  */
     45 
     46 template <typename OutputBackward, typename Kernel>
     47 EIGEN_ALWAYS_INLINE static const typename internal::conditional<
     48     internal::traits<OutputBackward>::Layout == ColMajor,
     49     TensorReshapingOp<
     50         const DSizes<typename internal::traits<OutputBackward>::Index,
     51                      internal::traits<OutputBackward>::NumDimensions>,
     52         const TensorContractionOp<
     53             const array<
     54                 IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
     55             const TensorReshapingOp<
     56                 const DSizes<typename internal::traits<OutputBackward>::Index,
     57                              3>,
     58                 const TensorReverseOp<const array<bool, 5>, const Kernel> >,
     59             const TensorReshapingOp<
     60                 const DSizes<typename internal::traits<OutputBackward>::Index,
     61                              3>,
     62                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
     63                                           const OutputBackward> > > >,
     64     TensorReshapingOp<
     65         const DSizes<typename internal::traits<OutputBackward>::Index,
     66                      internal::traits<OutputBackward>::NumDimensions>,
     67         const TensorContractionOp<
     68             const array<
     69                 IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
     70             const TensorReshapingOp<
     71                 const DSizes<typename internal::traits<OutputBackward>::Index,
     72                              3>,
     73                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
     74                                           const OutputBackward> >,
     75             const TensorReshapingOp<
     76                 const DSizes<typename internal::traits<OutputBackward>::Index,
     77                              3>,
     78                 const TensorReverseOp<const array<bool, 5>,
     79                                       const Kernel> > > > >::type
     80 CuboidConvolutionBackwardInput(
     81     const Kernel& kernel, const OutputBackward& output_backward,
     82     typename internal::traits<OutputBackward>::Index inputPlanes,
     83     typename internal::traits<OutputBackward>::Index inputRows,
     84     typename internal::traits<OutputBackward>::Index inputCols,
     85     const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
     86     const DenseIndex strideCols = 1) {
     87   typedef typename internal::traits<OutputBackward>::Index TensorIndex;
     88   const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
     89                                internal::traits<Kernel>::NumDimensions,
     90                                internal::traits<Kernel>::Layout, TensorIndex> >
     91       kern(kernel);
     92   const TensorRef<
     93       const Tensor<typename internal::traits<OutputBackward>::Scalar,
     94                    internal::traits<OutputBackward>::NumDimensions,
     95                    internal::traits<OutputBackward>::Layout, TensorIndex> >
     96       out(output_backward);
     97 
     98   EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
     99                           internal::traits<OutputBackward>::Layout,
    100                       YOU_MADE_A_PROGRAMMING_MISTAKE);
    101 
    102   static const bool isColMajor =
    103       (internal::traits<OutputBackward>::Layout == ColMajor);
    104 
    105   static const int NumDims = internal::traits<OutputBackward>::NumDimensions;
    106 
    107   // Number of filters to apply. This is the same as the output depth of the
    108   // result
    109   const TensorIndex kernelFilters =
    110       isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
    111   // Number of channels. This is the same as the input depth.
    112   const TensorIndex kernelChannels =
    113       isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
    114   const TensorIndex kernelPlanes =
    115       isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
    116   const TensorIndex kernelRows =
    117       isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
    118   const TensorIndex kernelCols =
    119       isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];
    120 
    121   const TensorIndex outputPlanes =
    122       isColMajor ? out.dimensions()[1] : out.dimensions()[NumDims - 2];
    123   const TensorIndex outputRows =
    124       isColMajor ? out.dimensions()[2] : out.dimensions()[NumDims - 3];
    125   const TensorIndex outputCols =
    126       isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
    127 
    128   TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
    129   const TensorIndex size_z =
    130       Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
    131   const TensorIndex size_y =
    132       Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
    133   const TensorIndex size_x =
    134       Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
    135 
    136   // Infer padding type.
    137   if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
    138     // SAME padding.
    139     const TensorIndex dz = numext::maxi<TensorIndex>(
    140         0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
    141     const TensorIndex dy = numext::maxi<TensorIndex>(
    142         0, (size_y - 1) * strideRows + kernelRows - inputRows);
    143     const TensorIndex dx = numext::maxi<TensorIndex>(
    144         0, (size_x - 1) * strideCols + kernelCols - inputCols);
    145 
    146     forward_pad_z = dz / 2;
    147     forward_pad_y = dy / 2;
    148     forward_pad_x = dx / 2;
    149   } else {
    150     // VALID padding.
    151     forward_pad_z = 0;
    152     forward_pad_y = 0;
    153     forward_pad_x = 0;
    154   }
    155   const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
    156   const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
    157   const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
    158 
    159   const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
    160                                       (outputPlanes - 1) * stridePlanes - 1 -
    161                                       padding_ztop;
    162   const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
    163                                      (outputRows - 1) * strideRows - 1 -
    164                                      padding_top;
    165   const TensorIndex padding_right = inputCols + kernelCols - 1 -
    166                                     (outputCols - 1) * strideCols - 1 -
    167                                     padding_left;
    168 
    169   eigen_assert(padding_ztop >= 0);
    170   eigen_assert(padding_zbottom >= 0);
    171   eigen_assert(padding_top >= 0);
    172   eigen_assert(padding_left >= 0);
    173   eigen_assert(padding_bottom >= 0);
    174   eigen_assert(padding_right >= 0);
    175 
    176   // The kernel has dimensions filters X channels X patch_planes X patch_rows X
    177   // patch_cols.
    178   // We need to reverse the kernel along the spatial dimensions.
    179   array<bool, 5> kernel_reverse;
    180   if (isColMajor) {
    181     kernel_reverse[0] = false;
    182     kernel_reverse[1] = false;
    183     kernel_reverse[2] = true;
    184     kernel_reverse[3] = true;
    185     kernel_reverse[4] = true;
    186   } else {
    187     kernel_reverse[0] = true;
    188     kernel_reverse[1] = true;
    189     kernel_reverse[2] = true;
    190     kernel_reverse[3] = false;
    191     kernel_reverse[4] = false;
    192   }
    193 
    194   DSizes<TensorIndex, 3> kernel_dims;
    195   if (isColMajor) {
    196     kernel_dims[0] = kernelFilters;
    197     kernel_dims[1] = kernelChannels;
    198     kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
    199   } else {
    200     kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
    201     kernel_dims[1] = kernelChannels;
    202     kernel_dims[2] = kernelFilters;
    203   }
    204 
    205   // The output_backward has dimensions out_depth X out_planes X out_rows X
    206   // out_cols X OTHERS
    207   // When we extract the image patches from output_backward, it will have
    208   // dimensions:
    209   //   out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes *
    210   //   input_rows * input_cols * OTHERS)
    211   DSizes<TensorIndex, 3> pre_contract_dims;
    212   if (isColMajor) {
    213     pre_contract_dims[0] = kernelFilters;
    214     pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
    215     pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
    216     for (int i = 4; i < NumDims; ++i) {
    217       pre_contract_dims[2] *= out.dimension(i);
    218     }
    219   } else {
    220     pre_contract_dims[2] = kernelFilters;
    221     pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
    222     pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
    223     for (int i = 0; i < NumDims - 4; ++i) {
    224       pre_contract_dims[0] *= out.dimension(i);
    225     }
    226   }
    227 
    228   // We will contract along dimensions (0, 2) in kernel and (0, 1) in
    229   // output_backward, if this is col-major, and
    230   // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this
    231   // row-major.
    232   array<IndexPair<TensorIndex>, 2> contract_dims;
    233   if (isColMajor) {
    234     // col-major: kernel.contract(output.patches)
    235     contract_dims[0] = IndexPair<TensorIndex>(0, 0);
    236     contract_dims[1] = IndexPair<TensorIndex>(2, 1);
    237   } else {
    238     // row-major: output.patches.contract(kernel)
    239     contract_dims[0] = IndexPair<TensorIndex>(1, 0);
    240     contract_dims[1] = IndexPair<TensorIndex>(2, 2);
    241   }
    242 
    243   // Post contraction, the dimensions of the input_backprop is
    244   //  channels X input_planes X input_rows X input_cols X OTHERS
    245   DSizes<TensorIndex, NumDims> post_contract_dims;
    246   if (isColMajor) {
    247     post_contract_dims[0] = kernelChannels;
    248     post_contract_dims[1] = inputPlanes;
    249     post_contract_dims[2] = inputRows;
    250     post_contract_dims[3] = inputCols;
    251     for (int i = 4; i < NumDims; ++i) {
    252       post_contract_dims[i] = out.dimension(i);
    253     }
    254   } else {
    255     post_contract_dims[NumDims - 1] = kernelChannels;
    256     post_contract_dims[NumDims - 2] = inputPlanes;
    257     post_contract_dims[NumDims - 3] = inputRows;
    258     post_contract_dims[NumDims - 4] = inputCols;
    259     for (int i = 0; i < NumDims - 4; ++i) {
    260       post_contract_dims[i] = out.dimension(i);
    261     }
    262   }
    263 
    264   DSizes<TensorIndex, NumDims> strides;
    265   for (int i = 0; i < NumDims; i++) {
    266     strides[i] = 1;
    267   }
    268   if (isColMajor) {
    269     strides[1] = stridePlanes;
    270     strides[2] = strideRows;
    271     strides[3] = strideCols;
    272   } else {
    273     strides[NumDims - 2] = stridePlanes;
    274     strides[NumDims - 3] = strideRows;
    275     strides[NumDims - 4] = strideCols;
    276   }
    277 
    278   return choose(
    279       Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
    280       kernel.reverse(kernel_reverse)
    281           .reshape(kernel_dims)
    282           .contract(output_backward
    283                         .extract_volume_patches(
    284                             kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
    285                             stridePlanes, strideRows, strideCols, padding_ztop,
    286                             padding_zbottom, padding_top, padding_bottom,
    287                             padding_left, padding_right)
    288                         .reshape(pre_contract_dims),
    289                     contract_dims)
    290           .reshape(post_contract_dims),
    291       output_backward
    292           .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
    293                                   stridePlanes, strideRows, strideCols,
    294                                   padding_ztop, padding_zbottom, padding_top,
    295                                   padding_bottom, padding_left, padding_right)
    296           .reshape(pre_contract_dims)
    297           .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
    298                     contract_dims)
    299           .reshape(post_contract_dims));
    300 }
    301 
    302 /** CuboidConvolutionBackwardKernel
    303  * \ingroup CXX11_NeuralNetworks_Module
    304  *
    305  * \brief Computes the backprop for the filter of a 3D convolution.
    306  *
    307  * The output_backward parameter is expected to be a tensor with a rank of 4 or
    308  * more (channels, depth, height, width, and optionally others)
    309  * The kernel parameter is expected to be a 4D tensor (filters, channels,
    310  * kernel_depth, kernel_height, kernel_width)
    311  * output_backward and kernel have to be in the same layout.
    312  *
    313  * The dimensions of the result will be filters, depth, height, width (and
    314  * others if applicable).
    315  *
    316  * It is possible to swap the order of the depth, width and height dimensions
    317  * provided that the same order is used in the input, the kernel, and the
    318  * output.
    319  *
    320  * All dimension orders above are given for col-major, and should be reversed
    321  * for row-major.
    322  */
    323 template <typename OutputBackward, typename Input>
    324 EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    325     internal::traits<OutputBackward>::Layout == ColMajor,
    326     const TensorShufflingOp<
    327         const array<typename internal::traits<OutputBackward>::Index, 5>,
    328         const TensorReverseOp<
    329             const array<bool, 5>,
    330             const TensorReshapingOp<
    331                 const DSizes<typename internal::traits<OutputBackward>::Index,
    332                              5>,
    333                 const TensorContractionOp<
    334                     const array<
    335                         IndexPair<typename internal::traits<Input>::Index>, 2>,
    336                     const TensorReshapingOp<
    337                         const DSizes<typename internal::traits<Input>::Index,
    338                                      3>,
    339                         const Input>,
    340                     const TensorReshapingOp<
    341                         const DSizes<
    342                             typename internal::traits<OutputBackward>::Index,
    343                             4>,
    344                         const TensorVolumePatchOp<
    345                             Dynamic, Dynamic, Dynamic,
    346                             const OutputBackward> > > > > >,
    347     const TensorShufflingOp<
    348         const array<typename internal::traits<OutputBackward>::Index, 5>,
    349         const TensorReverseOp<
    350             const array<bool, 5>,
    351             const TensorReshapingOp<
    352                 const DSizes<typename internal::traits<OutputBackward>::Index,
    353                              5>,
    354                 const TensorContractionOp<
    355                     const array<
    356                         IndexPair<typename internal::traits<Input>::Index>, 2>,
    357                     const TensorReshapingOp<
    358                         const DSizes<
    359                             typename internal::traits<OutputBackward>::Index,
    360                             4>,
    361                         const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
    362                                                   const OutputBackward> >,
    363                     const TensorReshapingOp<
    364                         const DSizes<typename internal::traits<Input>::Index,
    365                                      3>,
    366                         const Input> > > > > >::type
    367 CuboidConvolutionBackwardKernel(
    368     const Input& input, const OutputBackward& output_backward,
    369     typename internal::traits<Input>::Index kernelPlanes,
    370     typename internal::traits<Input>::Index kernelRows,
    371     typename internal::traits<Input>::Index kernelCols,
    372     const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
    373     const DenseIndex strideCols = 1) {
    374   typedef typename internal::traits<Input>::Index TensorIndex;
    375   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
    376                    internal::traits<Input>::NumDimensions,
    377                    internal::traits<Input>::Layout, TensorIndex> >
    378       in(input);
    379   TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
    380                    internal::traits<OutputBackward>::NumDimensions,
    381                    internal::traits<OutputBackward>::Layout, TensorIndex> >
    382       out(output_backward);
    383 
    384   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
    385                           internal::traits<OutputBackward>::Layout,
    386                       YOU_MADE_A_PROGRAMMING_MISTAKE);
    387 
    388   static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
    389 
    390   static const int NumDims = internal::traits<Input>::NumDimensions;
    391   EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions ==
    392                           internal::traits<OutputBackward>::NumDimensions,
    393                       YOU_MADE_A_PROGRAMMING_MISTAKE);
    394 
    395   const TensorIndex inputPlanes =
    396       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
    397   const TensorIndex inputRows =
    398       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
    399   const TensorIndex inputCols =
    400       isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
    401 
    402   const TensorIndex outputPlanes =
    403       isColMajor ? out.dimension(1) : out.dimension(NumDims - 2);
    404   const TensorIndex outputRows =
    405       isColMajor ? out.dimension(2) : out.dimension(NumDims - 3);
    406   const TensorIndex outputCols =
    407       isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
    408 
    409   const TensorIndex kernelFilters =
    410       isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
    411   const TensorIndex kernelChannels =
    412       isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
    413 
    414   TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
    415   const TensorIndex size_z =
    416       Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
    417   const TensorIndex size_y =
    418       Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
    419   const TensorIndex size_x =
    420       Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
    421 
    422   // Infer padding type.
    423   if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
    424     // SAME padding.
    425     const TensorIndex dz = numext::maxi<TensorIndex>(
    426         0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
    427     const TensorIndex dy = numext::maxi<TensorIndex>(
    428         0, (size_y - 1) * strideRows + kernelRows - inputRows);
    429     const TensorIndex dx = numext::maxi<TensorIndex>(
    430         0, (size_x - 1) * strideCols + kernelCols - inputCols);
    431 
    432     forward_pad_z = dz / 2;
    433     forward_pad_y = dy / 2;
    434     forward_pad_x = dx / 2;
    435   } else {
    436     // VALID padding.
    437     forward_pad_z = 0;
    438     forward_pad_y = 0;
    439     forward_pad_x = 0;
    440   }
    441 
    442   const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
    443   const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
    444   const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
    445 
    446   const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
    447                                       (outputPlanes - 1) * stridePlanes - 1 -
    448                                       padding_ztop;
    449   const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
    450                                      (outputRows - 1) * strideRows - 1 -
    451                                      padding_top;
    452   const TensorIndex padding_right = inputCols + kernelCols - 1 -
    453                                     (outputCols - 1) * strideCols - 1 -
    454                                     padding_left;
    455 
    456   eigen_assert(padding_ztop >= 0);
    457   eigen_assert(padding_zbottom >= 0);
    458   eigen_assert(padding_top >= 0);
    459   eigen_assert(padding_left >= 0);
    460   eigen_assert(padding_bottom >= 0);
    461   eigen_assert(padding_right >= 0);
    462 
    463   // The output_backward has dimensions out_depth X out_plaens X out_rows X
    464   // out_cols X OTHERS
    465   // When we extract the image patches from output_backward (with input as the
    466   // kernel), it will have dimensions
    467   //  (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
    468   //  kernel_rows * kernel_cols) X OTHERS
    469   DSizes<TensorIndex, 4> pre_contract_dims;
    470   if (isColMajor) {
    471     pre_contract_dims[0] = kernelFilters;
    472     pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
    473     pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
    474     pre_contract_dims[3] = 1;
    475     for (int i = 4; i < NumDims; ++i) {
    476       pre_contract_dims[3] *= out.dimension(i);
    477     }
    478   } else {
    479     pre_contract_dims[3] = kernelFilters;
    480     pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
    481     pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
    482     pre_contract_dims[0] = 1;
    483     for (int i = 0; i < NumDims - 4; ++i) {
    484       pre_contract_dims[0] *= out.dimension(i);
    485     }
    486   }
    487 
    488   // The input has dimensions in_depth X (input_planes * input_rows *
    489   // input_cols) X OTHERS
    490   DSizes<TensorIndex, 3> input_dims;
    491   if (isColMajor) {
    492     input_dims[0] = kernelChannels;
    493     input_dims[1] = inputRows * inputCols * inputPlanes;
    494     input_dims[2] = 1;
    495     for (int i = 4; i < NumDims; ++i) {
    496       input_dims[2] *= in.dimension(i);
    497     }
    498     eigen_assert(input_dims[2] == pre_contract_dims[3]);
    499   } else {
    500     input_dims[2] = kernelChannels;
    501     input_dims[1] = inputRows * inputCols * inputPlanes;
    502     input_dims[0] = 1;
    503     for (int i = 0; i < NumDims - 4; ++i) {
    504       input_dims[0] *= in.dimension(i);
    505     }
    506     eigen_assert(input_dims[0] == pre_contract_dims[0]);
    507   }
    508 
    509   // We will contract along dimensions (1, 2) in and (1, 3) in out, if
    510   // this is col-major.
    511   // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
    512   array<IndexPair<TensorIndex>, 2> contract_dims;
    513   if (isColMajor) {
    514     // col-major: in.contract(output.patches)
    515     contract_dims[0] = IndexPair<TensorIndex>(1, 1);
    516     contract_dims[1] = IndexPair<TensorIndex>(2, 3);
    517   } else {
    518     // row-major: output.patches.contract(in)
    519     contract_dims[0] = IndexPair<TensorIndex>(0, 0);
    520     contract_dims[1] = IndexPair<TensorIndex>(2, 1);
    521   }
    522 
    523   // After the contraction, the kernel will have dimension
    524   //   in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
    525   // We will need to shuffle the first two dimensions and reverse the spatial
    526   // dimensions.
    527   // The end shape is:
    528   //   out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
    529 
    530   // This is the shape of the kernel *before* the shuffling.
    531   DSizes<TensorIndex, 5> kernel_dims;
    532   if (isColMajor) {
    533     kernel_dims[0] = kernelChannels;
    534     kernel_dims[1] = kernelFilters;
    535     kernel_dims[2] = kernelPlanes;
    536     kernel_dims[3] = kernelRows;
    537     kernel_dims[4] = kernelCols;
    538   } else {
    539     kernel_dims[0] = kernelCols;
    540     kernel_dims[1] = kernelRows;
    541     kernel_dims[2] = kernelPlanes;
    542     kernel_dims[3] = kernelFilters;
    543     kernel_dims[4] = kernelChannels;
    544   }
    545 
    546   // Flip filters and channels.
    547   array<TensorIndex, 5> kernel_shuffle;
    548   if (isColMajor) {
    549     kernel_shuffle[0] = 1;
    550     kernel_shuffle[1] = 0;
    551     kernel_shuffle[2] = 2;
    552     kernel_shuffle[3] = 3;
    553     kernel_shuffle[4] = 4;
    554   } else {
    555     kernel_shuffle[0] = 0;
    556     kernel_shuffle[1] = 1;
    557     kernel_shuffle[2] = 2;
    558     kernel_shuffle[3] = 4;
    559     kernel_shuffle[4] = 3;
    560   }
    561 
    562   // Reverse the spatial dimensions.
    563   array<bool, 5> kernel_reverse;
    564   if (isColMajor) {
    565     kernel_reverse[0] = false;
    566     kernel_reverse[1] = false;
    567     kernel_reverse[2] = true;
    568     kernel_reverse[3] = true;
    569     kernel_reverse[4] = true;
    570   } else {
    571     kernel_reverse[0] = true;
    572     kernel_reverse[1] = true;
    573     kernel_reverse[2] = true;
    574     kernel_reverse[3] = false;
    575     kernel_reverse[4] = false;
    576   }
    577 
    578   DSizes<TensorIndex, NumDims> strides;
    579   for (int i = 0; i < NumDims; i++) {
    580     strides[i] = 1;
    581   }
    582   if (isColMajor) {
    583     strides[1] = stridePlanes;
    584     strides[2] = strideRows;
    585     strides[3] = strideCols;
    586   } else {
    587     strides[NumDims - 2] = stridePlanes;
    588     strides[NumDims - 3] = strideRows;
    589     strides[NumDims - 4] = strideCols;
    590   }
    591   return choose(
    592       Cond<internal::traits<Input>::Layout == ColMajor>(),
    593       input.reshape(input_dims)
    594           .contract(output_backward
    595                         .extract_volume_patches(
    596                             inputPlanes, inputRows, inputCols, 1, 1, 1,
    597                             stridePlanes, strideRows, strideCols,
    598 
    599                             padding_ztop, padding_zbottom, padding_top,
    600                             padding_bottom, padding_left, padding_right)
    601                         .reshape(pre_contract_dims),
    602                     contract_dims)
    603           .reshape(kernel_dims)
    604           .reverse(kernel_reverse)
    605           .shuffle(kernel_shuffle),
    606       output_backward
    607           .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
    608                                   stridePlanes, strideRows, strideCols,
    609                                   padding_ztop, padding_zbottom, padding_top,
    610                                   padding_bottom, padding_left, padding_right)
    611           .reshape(pre_contract_dims)
    612           .contract(input.reshape(input_dims), contract_dims)
    613           .reshape(kernel_dims)
    614           .reverse(kernel_reverse)
    615           .shuffle(kernel_shuffle));
    616 }
    617 
    618 }  // end namespace Eigen
    619 
    620 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_
    621