      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
     17 #define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
     18 
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 
     21 namespace Eigen {
     22 
     23 namespace internal {
     24 
     25 // TODO: Consolidate this part of the code with the image patch extraction code
     26 // since they are both very similar.
     27 template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
     28           typename ArgType, typename Device, typename Scalar_, typename Index,
     29           typename nocontract_t, typename contract_t, int Side, int packet_size,
     30           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
     31 class TensorContractionInputMapper<
     32     Scalar_, Index, Side,
     33     TensorEvaluator<
     34         const TensorReshapingOp<NewDimension,
     35                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
     36         Device>,
     37     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
     38     inner_dim_reordered, Alignment> {
     39  public:
     40   typedef Scalar_ Scalar;
     41   typedef TensorContractionInputMapper<
     42       Scalar, Index, Side,
     43       TensorEvaluator<
     44           const TensorReshapingOp<
     45               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
     46           Device>,
     47       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
     48       inner_dim_reordered, Alignment>
     49       Self;
     50   typedef TensorContractionSubMapper<
     51       Scalar, Index, Side,
     52       TensorEvaluator<
     53           const TensorReshapingOp<
     54               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
     55           Device>,
     56       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
     57       inner_dim_reordered, Alignment>
     58       SubMapper;
     59   typedef SubMapper VectorMapper;
     60   typedef SubMapper LinearMapper;
     61   typedef typename packet_traits<Scalar>::type Packet;
     62 
     63   EIGEN_DEVICE_FUNC
     64   TensorContractionInputMapper(
     65       const TensorEvaluator<
     66           const TensorReshapingOp<
     67               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
     68           Device>& tensor,
     69       const nocontract_t&, const nocontract_t&, const contract_t&,
     70       const contract_t&)
     71       : m_impl(tensor.impl().impl()) {
     72     Index patch_rows;
     73     Index patch_depth;
     74     if (internal::traits<ArgType>::Layout == ColMajor) {
     75       patch_depth = tensor.impl().dimensions()[0];
     76       patch_rows = tensor.impl().dimensions()[1];
     77       m_patch_cols = tensor.impl().dimensions()[2];
     78       m_num_patches = tensor.impl().dimensions()[3];
     79     } else {
     80       const int NumDims = tensor.impl().dimensions().size();
     81       patch_depth = tensor.impl().dimensions()[NumDims - 1];
     82       patch_rows = tensor.impl().dimensions()[NumDims - 2];
     83       m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
     84       m_num_patches = tensor.impl().dimensions()[NumDims - 4];
     85     }
     86     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
     87     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
     88 
     89     m_colStride = patch_rows;
     90 
     91     m_outputRows = tensor.impl().outputRows();
     92     m_row_strides = tensor.impl().userRowStride();
     93     m_col_strides = tensor.impl().userColStride();
     94 
     95     m_in_row_strides = tensor.impl().userInRowStride();
     96     m_in_col_strides = tensor.impl().userInColStride();
     97 
     98     if (internal::traits<ArgType>::Layout == ColMajor) {
     99       m_inputRows = tensor.impl().impl().dimensions()[1];
    100       m_inputCols = tensor.impl().impl().dimensions()[2];
    101     } else {
    102       const int NumDims = tensor.impl().impl().dimensions().size();
    103       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
    104       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    105     }
    106 
    107     m_rowInputStride = patch_depth;
    108     m_colInputStride = patch_depth * m_inputRows;
    109     m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
    110 
    111     m_rowPaddingTop = tensor.impl().rowPaddingTop();
    112     m_colPaddingLeft = tensor.impl().colPaddingLeft();
    113 
    114     m_fastInputRowStride =
    115         internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    116     m_fastInputColStride =
    117         internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    118     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    119     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    120     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    121     m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
    122   }
    123 
    124   EIGEN_DEVICE_FUNC
    125   TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
    126       : m_impl(base_mapper.m_impl) {
    127     m_patch_cols = base_mapper.m_patch_cols;
    128     m_num_patches = base_mapper.m_num_patches;
    129     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    130     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
    131 
    132     m_colStride = base_mapper.m_colStride;
    133 
    134     m_rowInputStride = base_mapper.m_rowInputStride;
    135     m_colInputStride = base_mapper.m_colInputStride;
    136     m_patchInputStride = base_mapper.m_patchInputStride;
    137 
    138     m_inputRows = base_mapper.m_inputRows;
    139     m_inputCols = base_mapper.m_inputCols;
    140 
    141     m_outputRows = base_mapper.m_outputRows;
    142     m_row_strides = base_mapper.m_row_strides;
    143     m_col_strides = base_mapper.m_col_strides;
    144 
    145     m_in_row_strides = base_mapper.m_in_row_strides;
    146     m_in_col_strides = base_mapper.m_in_col_strides;
    147 
    148     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    149     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
    150 
    151     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    152     m_fastInputColStride = base_mapper.m_fastInputColStride;
    153     m_fastNumPatches = base_mapper.m_fastNumPatches;
    154     m_fastColStride = base_mapper.m_fastColStride;
    155     m_fastOutputRows = base_mapper.m_fastOutputRows;
    156     m_fastDimZero = base_mapper.m_fastDimZero;
    157   }
    158 
     159   // If true, some optimizations for loading packets are turned off because
     160   // the image patches are "non-standard", i.e. the input has non-trivial
     161   // in-row/in-col strides or non-trivial inflation strides.
    162   EIGEN_DEVICE_FUNC
    163   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    164     return m_in_row_strides != 1 || m_in_col_strides != 1 ||
    165            m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
    166   }
    167 
    168   EIGEN_DEVICE_FUNC
    169   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    170     return SubMapper(*this, i, j);
    171   }
    172 
    173   EIGEN_DEVICE_FUNC
    174   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    175     return LinearMapper(*this, i, j);
    176   }
    177 
    178   EIGEN_DEVICE_FUNC
    179   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    180     Index rowIndex, colIndex, otherIndex;
    181     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    182     return loadCoeff(row, rowIndex, colIndex, otherIndex);
    183   }
    184 
     185   // Load the coefficient at the patchIndex location instead of the usual
     186   // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
     187   // gpu code.
     188   EIGEN_DEVICE_FUNC
    190   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    191     Index rowIndex, colIndex, otherIndex;
    192     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    193     return loadCoeff(row, rowIndex, colIndex, otherIndex);
    194   }
    195 
    196   EIGEN_DEVICE_FUNC
    197   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    198     Index rowIndex, colIndex, otherIndex;
    199     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    200     return loadPacket(row, rowIndex, colIndex, otherIndex);
    201   }
    202 
    203   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
    204   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
    205   EIGEN_DEVICE_FUNC
    206   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    207     Index rowIndex, colIndex, otherIndex;
    208     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    209     return loadPacket(row, rowIndex, colIndex, otherIndex);
    210   }
    211 
    212   EIGEN_DEVICE_FUNC
    213   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    214     return m_impl;
    215   }
    216 
    217   EIGEN_DEVICE_FUNC
    218   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
    219   EIGEN_DEVICE_FUNC
    220   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
    221   EIGEN_DEVICE_FUNC
    222   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
    223 
    224   EIGEN_DEVICE_FUNC
    225   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
    226                                              const Index baseIndex) const {
    227     const Index inputIndex = depth + baseIndex;
    228     return m_impl.template packet<Unaligned>(inputIndex);
    229   }
    230 
    231  private:
    232   friend class TensorContractionSubMapper<
    233       Scalar, Index, Side,
    234       TensorEvaluator<
    235           const TensorReshapingOp<
    236               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    237           Device>,
    238       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    239       inner_dim_reordered, Alignment>;
    240 
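           // Within one patch, the linear coefficient index decomposes as
           //   depth     = patchId % patchDepth()
           //   rowOffset = (patchId / patchDepth()) % patchRows()
           //   colOffset = (patchId / patchDepth()) / patchRows()
           // For example (illustrative numbers only), patchDepth() == 8 and
           // patchRows() == 3 give, for patchId == 29: depth == 5, rowOffset == 0 and
           // colOffset == 1. loadCoeff() adds the (possibly negative, because of
           // padding) base indices to these offsets and returns zero for any position
           // that falls outside the input or between inflated input elements.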
    241   EIGEN_DEVICE_FUNC
    242   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
    243                                        Index colIndex, Index otherIndex) const {
    244     // Find the offset of the element wrt the location of the first element.
    245     const Index patchOffset = patchId / m_fastDimZero;
    246 
    247     const Index colOffset = patchOffset / m_fastColStride;
    248     const Index inputCol = colIndex + colOffset * m_in_col_strides;
    249     const Index origInputCol =
    250         (m_patch_col_inflate_strides == 1)
    251             ? inputCol
    252             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    253     const Index rowOffset = patchOffset - colOffset * m_colStride;
    254     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    255     const Index origInputRow =
    256         (m_patch_row_inflate_strides == 1)
    257             ? inputRow
    258             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    259     if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
    260         origInputRow >= m_inputRows ||
    261         (inputCol != origInputCol * m_patch_col_inflate_strides) ||
    262         (inputRow != origInputRow * m_patch_row_inflate_strides)) {
    263       return Scalar(0);
    264     }
    265     const Index depth = patchId - patchOffset * patchDepth();
    266     const Index inputIndex = depth + origInputRow * m_rowInputStride +
    267                              origInputCol * m_colInputStride + otherIndex;
    268     return m_impl.coeff(inputIndex);
    269   }
    270 
    271   EIGEN_DEVICE_FUNC
    272   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
    273                                                Index colIndex,
    274                                                Index otherIndex) const {
    275     eigen_assert(!nonStandardPatches());
    276 
    277     // Find the offset of the element wrt the location of the first element.
    278     const Index patchOffset = patchId / m_fastDimZero;
    279 
    280     const Index colOffset = patchOffset / m_fastColStride;
    281     const Index inputCol = colIndex + colOffset;
    282     const Index rowOffset = patchOffset - colOffset * m_colStride;
    283     const Index inputRow = rowIndex + rowOffset;
    284     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
    285         inputRow >= m_inputRows) {
    286       return Scalar(0);
    287     }
    288     const Index depth = patchId - patchOffset * patchDepth();
    289     const Index inputIndex = depth + inputRow * m_rowInputStride +
    290                              inputCol * m_colInputStride + otherIndex;
    291     return m_impl.coeff(inputIndex);
    292   }
    293 
    294   EIGEN_DEVICE_FUNC
    295   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
    296                                         Index colIndex,
    297                                         Index otherIndex) const {
    298     const Index packetSize = internal::unpacket_traits<Packet>::size;
    299     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    300     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
    301 
    302     if (nonStandardPatches()) {
    303       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    304     }
    305     return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
    306   }
    307 
    308   EIGEN_DEVICE_FUNC
    309   EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex,
    310                                                 Index colIndex,
    311                                                 Index otherIndex) const {
    312     const Index packetSize = internal::unpacket_traits<Packet>::size;
    313     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    314     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
    315 
    316     eigen_assert(!nonStandardPatches());
    317 
    318     if ((patchDepth() % packetSize) == 0) {
    319       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    320     } else {
    321       const Index patchOffsets[2] = {
    322           patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
    323 
    324       const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
    325                                    patchOffsets[1] / m_fastColStride};
    326 
    327       const Index inputCols[2] = {colIndex + colOffsets[0],
    328                                   colIndex + colOffsets[1]};
    329       if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
    330         // all zeros
    331         return internal::pset1<Packet>(Scalar(0));
    332       }
    333 
    334       if (inputCols[0] == inputCols[1]) {
    335         const Index rowOffsets[2] = {
    336             patchOffsets[0] - colOffsets[0] * m_colStride,
    337             patchOffsets[1] - colOffsets[1] * m_colStride};
    338         eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    339         const Index inputRows[2] = {rowIndex + rowOffsets[0],
    340                                     rowIndex + rowOffsets[1]};
    341 
    342         if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
    343           // all zeros
    344           return internal::pset1<Packet>(Scalar(0));
    345         }
    346 
    347         if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
    348           // no padding
    349           const Index depth = patchId - patchOffsets[0] * patchDepth();
    350           const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
    351                                    inputCols[0] * m_colInputStride + otherIndex;
    352           return m_impl.template packet<Unaligned>(inputIndex);
    353         }
    354       }
    355     }
    356     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    357   }
    358 
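           // Fast path for the case patchDepth() % packet_size == 0: a full packet
           // then always lies within the depth run of a single (row, col) position of
           // the patch, so it can be read with one contiguous (unaligned) load unless
           // that position falls in the padded region, in which case zeros are
           // returned.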
    359   EIGEN_DEVICE_FUNC
    360   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
    361                                             Index colIndex,
    362                                             Index otherIndex) const {
    363     const Index packetSize = internal::unpacket_traits<Packet>::size;
    364     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    365     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
    366 
    367     eigen_assert(!nonStandardPatches());
    368     eigen_assert((patchDepth() % packetSize) == 0);
    369     // Find the offset of the element wrt the location of the first element.
    370     const Index patchOffset = patchId / m_fastDimZero;
    371     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
    372 
    373     const Index colOffset = patchOffset / m_fastColStride;
    374     const Index inputCol = colIndex + colOffset;
    375     const Index rowOffset = patchOffset - colOffset * m_colStride;
    376     const Index inputRow = rowIndex + rowOffset;
    377     if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
    378         inputRow >= m_inputRows) {
    379       // all zeros
    380       return internal::pset1<Packet>(Scalar(0));
    381     }
    382     // no padding
    383     const Index depth = patchId - patchOffset * patchDepth();
    384     const Index inputIndex = depth + inputRow * m_rowInputStride +
    385                              inputCol * m_colInputStride + otherIndex;
    386     return m_impl.template packet<Unaligned>(inputIndex);
    387   }
    388 
    389   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
    390       Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    391     const int packetSize = internal::unpacket_traits<Packet>::size;
    392     EIGEN_ALIGN_MAX
    393     typename internal::remove_const<Scalar>::type values[packetSize];
    394     for (int i = 0; i < packetSize; ++i) {
    395       values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    396     }
    397     Packet rslt = internal::pload<Packet>(values);
    398     return rslt;
    399   }
    400 
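           // Decomposes a patch index into the base row/col indices of the patch in
           // the input (the patch's top-left corner shifted by the padding, so the
           // values may be negative) plus the offset of its "other" (e.g. batch)
           // entry in the input buffer:
           //   patch2DIndex = patchIndex % m_num_patches
           //   otherIndex   = (patchIndex / m_num_patches) * m_patchInputStride
           //   colIndex     = (patch2DIndex / m_outputRows) * m_col_strides
           //                  - m_colPaddingLeft
           //   rowIndex     = (patch2DIndex % m_outputRows) * m_row_strides
           //                  - m_rowPaddingTop
           // For a rank-3 input there is no "other" dimension and otherIndex is 0.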
    401   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
    402       Index patchIndex, Index& rowIndex, Index& colIndex,
    403       Index& otherIndex) const {
    404     const int NumInputDims = array_size<
    405         typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    406     otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    407     const Index patch2DIndex = (NumInputDims == 3)
    408                                    ? patchIndex
    409                                    : (patchIndex - otherIndex * m_num_patches);
    410     otherIndex *= m_patchInputStride;
    411     colIndex = patch2DIndex / m_fastOutputRows;
    412     rowIndex = patch2DIndex - colIndex * m_outputRows;
    413     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    414     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
    415   }
    416 
     417   Index m_patch_cols;                 // number of columns in the patch
    418   Index m_num_patches;                // number of patches to extract.
    419   Index m_patch_row_inflate_strides;  // the strides for row inflation in the
    420                                       // image patch
    421   Index m_patch_col_inflate_strides;  // the strides for col inflation in the
    422                                       // image patch
    423   // Fast representation of inflation strides.
    424   internal::TensorIntDivisor<Index> m_fastInputRowStride;
    425   internal::TensorIntDivisor<Index> m_fastInputColStride;
    426 
    427   Index m_otherStride;
    428   Index m_colStride;
    429   internal::TensorIntDivisor<Index> m_fastNumPatches;
    430   internal::TensorIntDivisor<Index> m_fastColStride;
    431 
    432   Index m_rowInputStride;    // row stride in the input tensor
    433   Index m_colInputStride;    // col stride in the input tensor
    434   Index m_patchInputStride;  // patch stride in the input tensor
    435 
    436   Index m_inputRows;  // Number of rows in the input tensor
    437   Index m_inputCols;  // Number of cols in the input tensor
    438 
     439   Index m_outputRows;  // Number of output rows (rows of patch positions)
    440 
    441   Index m_row_strides;  // User specified row stride
    442   Index m_col_strides;  // User specified col stride
    443 
    444   Index m_in_row_strides;  // User specified input row stride
    445   Index m_in_col_strides;  // User specified input col stride
    446 
    447   Index m_rowPaddingTop;   // Row padding
    448   Index m_colPaddingLeft;  // Column padding
    449 
    450   internal::TensorIntDivisor<Index> m_fastOutputRows;
    451   internal::TensorIntDivisor<Index> m_fastDimZero;
    452 
    453   const TensorEvaluator<ArgType, Device> m_impl;
    454 };
    455 
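         // A view of the input mapper anchored at a fixed (depth, column) offset into
         // the patch matrix. The column offset selects a patch, so the base row, col
         // and "other" indices of that patch are precomputed once in the constructors
         // and reused by every subsequent load.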
    456 template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
    457           typename ArgType, typename Device, typename Scalar, typename Index,
    458           typename nocontract_t, typename contract_t, int Side, int packet_size,
    459           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
    460 class TensorContractionSubMapper<
    461     Scalar, Index, Side,
    462     TensorEvaluator<
    463         const TensorReshapingOp<NewDimension,
    464                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
    465         Device>,
    466     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    467     inner_dim_reordered, Alignment> {
    468  public:
    469   typedef typename packet_traits<Scalar>::type Packet;
    470   typedef typename packet_traits<Scalar>::half HalfPacket;
    471 
    472   typedef TensorContractionInputMapper<
    473       Scalar, Index, Side,
    474       TensorEvaluator<
    475           const TensorReshapingOp<
    476               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    477           Device>,
    478       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    479       inner_dim_reordered, Alignment>
    480       ParentMapper;
    481   typedef TensorContractionSubMapper<
    482       Scalar, Index, Side,
    483       TensorEvaluator<
    484           const TensorReshapingOp<
    485               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    486           Device>,
    487       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    488       inner_dim_reordered, Alignment>
    489       Self;
    490   typedef Self LinearMapper;
    491 
    492   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
    493       const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
    494       : m_base_mapper(base_mapper),
    495         m_depth_offset(vert_offset),
    496         m_col_offset(horiz_offset) {
    497     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
    498                                      m_otherIndex);
    499   }
    500   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
    501       const Self& base_mapper, Index vert_offset, Index horiz_offset)
    502       : m_base_mapper(base_mapper.m_base_mapper),
    503         m_depth_offset(vert_offset + base_mapper.m_depth_offset),
    504         m_col_offset(horiz_offset + base_mapper.m_col_offset) {
    505     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
    506                                      m_otherIndex);
    507   }
    508   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    509     return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
    510                                    m_otherIndex);
    511   }
    512   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
    513                                                           Index j) const {
    514     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
    515   }
    516 
    517   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    518     return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
    519                                     m_otherIndex);
    520   }
    521   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
    522                                                           Index j) const {
    523     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
    524                                                         j + m_col_offset);
    525   }
    526 
    527   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
    528   loadCoeffStandard(Index i) const {
    529     return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
    530                                            m_colIndex, m_otherIndex);
    531   }
    532 
    533   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    534     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
    535                                         m_colIndex, m_otherIndex);
    536   }
    537   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
    538   loadPacketStandard(Index i) const {
    539     return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex,
    540                                             m_colIndex, m_otherIndex);
    541   }
    542   template <typename Packet>
    543   EIGEN_DEVICE_FUNC bool aligned(Index) const {
    544     return false;
    545   }
    546 
    547   EIGEN_DEVICE_FUNC
    548   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    549     return m_base_mapper.nonStandardPatches();
    550   }
    551 
    552   EIGEN_DEVICE_FUNC
    553   EIGEN_ALWAYS_INLINE Index patchDepth() const {
    554     return m_base_mapper.m_rowInputStride;
    555   }
    556   EIGEN_DEVICE_FUNC
    557   EIGEN_ALWAYS_INLINE Index patchRows() const {
    558     return m_base_mapper.m_colStride;
    559   }
    560   EIGEN_DEVICE_FUNC
    561   EIGEN_ALWAYS_INLINE Index patchCols() const {
    562     return m_base_mapper.m_patch_cols;
    563   }
    564 
    565   EIGEN_DEVICE_FUNC
    566   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
    567                                              const Index baseIndex) const {
    568     const Index inputIndex = depth + baseIndex;
    569     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
    570   }
    571 
    572   EIGEN_DEVICE_FUNC
    573   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    574     const Index r = m_rowIndex + row;
    575     return r < 0 || r >= m_base_mapper.m_inputRows;
    576   }
    577   EIGEN_DEVICE_FUNC
    578   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    579     const Index c = m_colIndex + col;
    580     return c < 0 || c >= m_base_mapper.m_inputCols;
    581   }
    582   EIGEN_DEVICE_FUNC
    583   EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    584     const Index r = m_rowIndex + row;
    585     const Index c = m_colIndex + col;
    586     return r * m_base_mapper.m_rowInputStride +
    587            c * m_base_mapper.m_colInputStride + m_otherIndex;
    588   }
    589 
    590   EIGEN_DEVICE_FUNC
    591   EIGEN_ALWAYS_INLINE Index rowOffset() const {
    592     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    593     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    594     return patchOffset - colOffset * m_base_mapper.m_colStride;
    595   }
    596 
    597   EIGEN_DEVICE_FUNC
    598   EIGEN_ALWAYS_INLINE Index colOffset() const {
    599     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    600     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    601     return colOffset;
    602   }
    603 
    604   EIGEN_DEVICE_FUNC
    605   EIGEN_ALWAYS_INLINE Index depthOffset() const {
    606     const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
    607     return patchOffset;
    608   }
    609 
    610   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
    611   getLinearMapper(Index i, Index j) const {
    612     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
    613   }
    614 
    615  private:
     616   const ParentMapper& m_base_mapper;  // Parent mapper; must outlive *this.
    617   Index m_depth_offset;               // First row in the input matrix
    618   Index m_col_offset;                 // First col in the input matrix
    619 
    620   Index m_rowIndex;  // precomputed row index corresponding to the col offset
    621   Index m_colIndex;  // precomputed col index corresponding to the col offset
    622   Index
    623       m_otherIndex;  // precomputed other index corresponding to the col offset
    624 };
    625 
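         // Packs the right-hand side of the contraction (the reshaped image-patch
         // expression) into the column-major block layout expected by Eigen's GEMM
         // kernel, four columns at a time (nr == 4). When the patches are "standard"
         // and the patch depth is a multiple of the packet size, depth slices are
         // loaded with vectorized loads and transposed into the block; otherwise the
         // code falls back to coefficient-by-coefficient loads that handle padding
         // and inflation.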
    626 template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
    627           typename ArgType, typename Device, typename Scalar, typename Index,
    628           typename nocontract_t, typename contract_t, int packet_size,
    629           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
    630           int nr>
    631 struct gemm_pack_rhs<
    632     Scalar, Index,
    633     TensorContractionSubMapper<
    634         Scalar, Index, Rhs,
    635         TensorEvaluator<
    636             const TensorReshapingOp<
    637                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    638             Device>,
    639         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    640         inner_dim_reordered, Alignment>,
    641     nr, ColMajor, false, false> {
    642   typedef TensorContractionSubMapper<
    643       Scalar, Index, Rhs,
    644       TensorEvaluator<
    645           const TensorReshapingOp<
    646               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    647           Device>,
    648       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    649       inner_dim_reordered, Alignment>
    650       SubMapper;
    651   typedef SubMapper DataMapper;
    652 
    653   EIGEN_DEVICE_FUNC
    654   static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
    655 
    656   EIGEN_DEVICE_FUNC
    657   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
    658                                     Index depth, Index cols, Index stride = 0,
    659                                     Index offset = 0) const {
    660     eigen_assert(stride == 0);
    661     eigen_assert(offset == 0);
    662 
    663     EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
    664     typedef typename packet_traits<Scalar>::type Packet;
    665 
    666     const Index packet_cols4 = (cols / 4) * 4;
    667     const Index peeled_k = (depth / packet_size) * packet_size;
    668     const bool non_standard_patches = rhs.nonStandardPatches();
    669 
    670     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
    671       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    672       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    673       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    674       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    675 
    676       Index k = 0;
    677       if ((packet_size % 4) == 0 && !non_standard_patches) {
    678         const Index patch_depth = rhs.patchDepth();
    679         if ((patch_depth % packet_size) == 0) {
    680           const Index patch_cols = rhs.patchCols();
    681           const Index patch_rows = rhs.patchRows();
    682 
    683           const Index startCol = rhs.colOffset();
    684           const Index max_cols = std::min<Index>(
    685               ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
    686               patch_cols);
    687 
    688           for (Index c = startCol; c < max_cols; ++c) {
    689             eigen_assert(k < peeled_k);
    690             const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
    691             const Index max_rows = std::min<Index>(
    692                 ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
    693                     startRow,
    694                 patch_rows);
    695 
    696             const bool pad_col0 = dm0.padCol(c);
    697             const bool pad_col1 = dm1.padCol(c);
    698             const bool pad_col2 = dm2.padCol(c);
    699             const bool pad_col3 = dm3.padCol(c);
    700             for (Index r = startRow; r < max_rows; ++r) {
    701               eigen_assert(k < peeled_k);
    702               const bool pad0 = pad_col0 || dm0.padRow(r);
    703               const bool pad1 = pad_col1 || dm1.padRow(r);
    704               const bool pad2 = pad_col2 || dm2.padRow(r);
    705               const bool pad3 = pad_col3 || dm3.padRow(r);
    706 
    707               const Index idx0 = dm0.baseIndex(r, c);
    708               const Index idx1 = dm1.baseIndex(r, c);
    709               const Index idx2 = dm2.baseIndex(r, c);
    710               const Index idx3 = dm3.baseIndex(r, c);
    711 
    712               const Index startDepth =
    713                   ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
    714               const Index max_depth =
    715                   std::min<Index>(peeled_k - c * patch_rows * patch_depth -
    716                                       r * patch_depth + startDepth,
    717                                   patch_depth);
    718               eigen_assert((max_depth - startDepth) % packet_size == 0);
    719               for (Index d = startDepth; d < max_depth; d += packet_size) {
    720                 eigen_assert(k < peeled_k);
    721                 PacketBlock<Packet, 4> kernel;
    722                 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
    723                                         : rhs.packetNoPadding(d, idx0);
    724                 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
    725                                         : rhs.packetNoPadding(d, idx1);
    726                 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
    727                                         : rhs.packetNoPadding(d, idx2);
    728                 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
    729                                         : rhs.packetNoPadding(d, idx3);
    730                 ptranspose(kernel);
    731                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
    732                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
    733                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
    734                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
    735                 block += 4 * packet_size;
    736                 k += packet_size;
    737               }
    738             }
    739           }
    740 
    741           for (; k < peeled_k; k += packet_size) {
    742             PacketBlock<Packet, 4> kernel;
    743             kernel.packet[0] = dm0.loadPacketFast(k);
    744             kernel.packet[1] = dm1.loadPacketFast(k);
    745             kernel.packet[2] = dm2.loadPacketFast(k);
    746             kernel.packet[3] = dm3.loadPacketFast(k);
    747             ptranspose(kernel);
    748             pstoreu(block + 0 * packet_size, kernel.packet[0]);
    749             pstoreu(block + 1 * packet_size, kernel.packet[1]);
    750             pstoreu(block + 2 * packet_size, kernel.packet[2]);
    751             pstoreu(block + 3 * packet_size, kernel.packet[3]);
    752             block += 4 * packet_size;
    753           }
    754         } else {
    755           for (; k < peeled_k; k += packet_size) {
    756             PacketBlock<Packet, 4> kernel;
    757             kernel.packet[0] = dm0.loadPacketStandard(k);
    758             kernel.packet[1] = dm1.loadPacketStandard(k);
    759             kernel.packet[2] = dm2.loadPacketStandard(k);
    760             kernel.packet[3] = dm3.loadPacketStandard(k);
    761             ptranspose(kernel);
    762             pstoreu(block + 0 * packet_size, kernel.packet[0]);
    763             pstoreu(block + 1 * packet_size, kernel.packet[1]);
    764             pstoreu(block + 2 * packet_size, kernel.packet[2]);
    765             pstoreu(block + 3 * packet_size, kernel.packet[3]);
    766             block += 4 * packet_size;
    767           }
    768         }
    769       }
    770       if (!rhs.nonStandardPatches()) {
    771         for (; k < depth; k++) {
    772           block[0] = dm0.loadCoeffStandard(k);
    773           block[1] = dm1.loadCoeffStandard(k);
    774           block[2] = dm2.loadCoeffStandard(k);
    775           block[3] = dm3.loadCoeffStandard(k);
    776           block += 4;
    777         }
    778       } else {
    779         for (; k < depth; k++) {
    780           block[0] = dm0(k);
    781           block[1] = dm1(k);
    782           block[2] = dm2(k);
    783           block[3] = dm3(k);
    784           block += 4;
    785         }
    786       }
    787     }
    788 
    789     // copy the remaining columns one at a time (nr==1)
    790     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    791       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
    792       for (Index k = 0; k < depth; k++) {
    793         *block = dm0(k);
    794         block += 1;
    795       }
    796     }
    797   }
    798 };
    799 
    800 // Special case for non-vectorized types such as float16.
    801 template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
    802           typename ArgType, typename Device, typename Scalar, typename Index,
    803           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
    804           bool inner_dim_reordered, int Alignment, int nr>
    805 struct gemm_pack_rhs<
    806     Scalar, Index,
    807     TensorContractionSubMapper<
    808         Scalar, Index, Rhs,
    809         TensorEvaluator<
    810             const TensorReshapingOp<
    811                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    812             Device>,
    813         nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
    814         Alignment>,
    815     nr, ColMajor, false, false> {
    816   typedef TensorContractionSubMapper<
    817       Scalar, Index, Rhs,
    818       TensorEvaluator<
    819           const TensorReshapingOp<
    820               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
    821           Device>,
    822       nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
    823       Alignment>
    824       SubMapper;
    825   typedef SubMapper DataMapper;
    826 
    827   EIGEN_DEVICE_FUNC
    828   static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
    829 
    830   EIGEN_DEVICE_FUNC
    831   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
    832                                     Index depth, Index cols, Index stride = 0,
    833                                     Index offset = 0) const {
    834     eigen_assert(stride == 0);
    835     eigen_assert(offset == 0);
    836 
    837     EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
    838 
    839     const Index packet_cols4 = (cols / 4) * 4;
    840 
    841     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
    842       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
    843       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
    844       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
    845       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
    846 
    847       if (!rhs.nonStandardPatches()) {
    848         for (Index k = 0; k < depth; k++) {
    849           block[0] = dm0.loadCoeffStandard(k);
    850           block[1] = dm1.loadCoeffStandard(k);
    851           block[2] = dm2.loadCoeffStandard(k);
    852           block[3] = dm3.loadCoeffStandard(k);
    853           block += 4;
    854         }
    855       } else {
    856         for (Index k = 0; k < depth; k++) {
    857           block[0] = dm0(k);
    858           block[1] = dm1(k);
    859           block[2] = dm2(k);
    860           block[3] = dm3(k);
    861           block += 4;
    862         }
    863       }
    864     }
    865 
    866     // copy the remaining columns one at a time (nr==1)
    867     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
    868       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
    869       for (Index k = 0; k < depth; k++) {
    870         *block = dm0(k);
    871         block += 1;
    872       }
    873     }
    874   }
    875 };
    876 
    877 }  // end namespace internal
    878 
    879 /** SpatialConvolution
    880  * \ingroup CXX11_NeuralNetworks_Module
    881  *
    882  * \brief Applies a 2D convolution over a multichannel input image.
    883  *
     884  * The input parameter is expected to be a tensor with a rank of 3 or more
     885  * (channels, height, width, and optionally others).
     886  * The kernel parameter is expected to be a 4D tensor (filters, channels,
     887  * kernel_height, kernel_width).
    888  * The input and the kernel must both be in col-major layout. The result will
    889  * also be in col-major layout.
    890  *
    891  * If col_in_stride, row_in_stride > 1, then applies convolution with holes
    892  * (aka atrous convolution), sampling every col_in_stride, row_in_stride input
    893  * pixels.
    894  *
    895  * The result can be assigned to a tensor of rank equal to the rank of the
    896  * input. The dimensions of the result will be filters, height, width (and
    897  * others if applicable).
    898  *
    899  * It is possible to swap the order of the width and height dimensions provided
    900  * that the same order is used in the input, the kernel, and the output.
    901  *
    902  */
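         // Example (a minimal sketch; the tensor sizes below are illustrative
         // assumptions, not taken from this header). With col-major tensors, a 5x5
         // convolution with 32 filters applied to a batch of 16 three-channel
         // 128x128 images could be written as:
         //
         //   Eigen::Tensor<float, 4> input(3, 128, 128, 16);   // (channels, rows, cols, batch)
         //   Eigen::Tensor<float, 4> kernel(32, 3, 5, 5);      // (filters, channels, rows, cols)
         //   Eigen::Tensor<float, 4> output(32, 128, 128, 16); // PADDING_SAME, stride 1
         //   output = Eigen::SpatialConvolution(input, kernel);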
    903 template <typename Input, typename Kernel>
    904 EIGEN_DEVICE_FUNC
    905     EIGEN_ALWAYS_INLINE static const typename internal::conditional<
    906         internal::traits<Input>::Layout == ColMajor,
    907         TensorReshapingOp<
    908             const DSizes<typename internal::traits<Input>::Index,
    909                          internal::traits<Input>::NumDimensions>,
    910             const TensorContractionOp<
    911                 const array<IndexPair<typename internal::traits<Input>::Index>,
    912                             1>,
    913                 const TensorReshapingOp<
    914                     const DSizes<typename internal::traits<Input>::Index, 2>,
    915                     const Kernel>,
    916                 const TensorReshapingOp<
    917                     const DSizes<typename internal::traits<Input>::Index, 2>,
    918                     const TensorImagePatchOp<Dynamic, Dynamic,
    919                                              const Input> > > >,
    920         TensorReshapingOp<
    921             const DSizes<typename internal::traits<Input>::Index,
    922                          internal::traits<Input>::NumDimensions>,
    923             const TensorContractionOp<
    924                 const array<IndexPair<typename internal::traits<Input>::Index>,
    925                             1>,
    926                 const TensorReshapingOp<
    927                     const DSizes<typename internal::traits<Input>::Index, 2>,
    928                     const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
    929                 const TensorReshapingOp<
    930                     const DSizes<typename internal::traits<Input>::Index, 2>,
    931                     const Kernel> > > >::type
    932     SpatialConvolution(const Input& input, const Kernel& kernel,
    933                        const DenseIndex row_stride = 1,
    934                        const DenseIndex col_stride = 1,
    935                        const PaddingType padding_type = PADDING_SAME,
    936                        const DenseIndex row_in_stride = 1,
    937                        const DenseIndex col_in_stride = 1) {
    938   typedef typename internal::traits<Input>::Index TensorIndex;
    939   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
    940                    internal::traits<Input>::NumDimensions,
    941                    internal::traits<Input>::Layout, TensorIndex> >
    942       in(input);
    943   TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
    944                    internal::traits<Kernel>::NumDimensions,
    945                    internal::traits<Kernel>::Layout, TensorIndex> >
    946       kern(kernel);
    947 
    948   EIGEN_STATIC_ASSERT(
    949       internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
    950       YOU_MADE_A_PROGRAMMING_MISTAKE);
    951   const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
    952 
    953   const int NumDims = internal::traits<Input>::NumDimensions;
    954 
    955   // Number of filters to apply. This is the same as the output depth of the
     956   // result.
    957   const TensorIndex kernelFilters =
    958       isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
    959   // Number of channels. This is the same as the input depth.
    960   const TensorIndex kernelChannels =
    961       isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
    962   const TensorIndex kernelRows =
    963       isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
    964   const TensorIndex kernelCols =
    965       isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
    966 
    967   const DenseIndex kernelRowsEff =
    968       kernelRows + (kernelRows - 1) * (row_in_stride - 1);
    969   const DenseIndex kernelColsEff =
    970       kernelCols + (kernelCols - 1) * (col_in_stride - 1);
    971 
    972   array<IndexPair<TensorIndex>, 1> contract_dims;
    973   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
    974 
    975   const TensorIndex InputRows =
    976       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
    977   const TensorIndex InputCols =
    978       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
    979 
    980   TensorIndex out_height;
    981   TensorIndex out_width;
    982   switch (padding_type) {
    983     case PADDING_VALID:
    984       out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) /
    985                                 static_cast<float>(row_stride));
    986       out_width = numext::ceil((InputCols - kernelColsEff + 1.f) /
    987                                static_cast<float>(col_stride));
    988       break;
    989     case PADDING_SAME:
    990       out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
    991       out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
    992       break;
    993     default:
    994       // Initialize unused variables to avoid a compiler warning
    995       out_height = 0;
    996       out_width = 0;
    997       eigen_assert(false && "unexpected padding");
    998   }
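           // For example (illustrative numbers only): InputRows == 10, kernelRows == 3
           // and row_in_stride == 2 give kernelRowsEff == 5, so PADDING_VALID with
           // row_stride == 2 yields out_height == ceil((10 - 5 + 1) / 2.0) == 3, while
           // PADDING_SAME yields out_height == ceil(10 / 2.0) == 5.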
    999 
   1000   // Molds the output of the patch extraction code into a 2d tensor:
   1001   // - the first dimension (dims[0]): the patch values to be multiplied with the
   1002   // kernels
   1003   // - the second dimension (dims[1]): everything else
   1004   DSizes<TensorIndex, 2> pre_contract_dims;
   1005   if (isColMajor) {
   1006     pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
   1007     pre_contract_dims[1] = out_height * out_width;
   1008     for (int i = 3; i < NumDims; ++i) {
   1009       pre_contract_dims[1] *= in.dimension(i);
   1010     }
   1011   } else {
   1012     pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
   1013     pre_contract_dims[0] = out_height * out_width;
   1014     for (int i = 0; i < NumDims - 3; ++i) {
   1015       pre_contract_dims[0] *= in.dimension(i);
   1016     }
   1017   }
   1018 
    1019   // Molds the output of the contraction into the shape expected by the user
   1020   // (assuming this is ColMajor):
   1021   // - 1st dim: kernel filters
   1022   // - 2nd dim: output height
   1023   // - 3rd dim: output width
   1024   // - 4th dim and beyond: everything else including batch size
   1025   DSizes<TensorIndex, NumDims> post_contract_dims;
   1026   if (isColMajor) {
   1027     post_contract_dims[0] = kernelFilters;
   1028     post_contract_dims[1] = out_height;
   1029     post_contract_dims[2] = out_width;
   1030     for (int i = 3; i < NumDims; ++i) {
   1031       post_contract_dims[i] = in.dimension(i);
   1032     }
   1033   } else {
   1034     post_contract_dims[NumDims - 1] = kernelFilters;
   1035     post_contract_dims[NumDims - 2] = out_height;
   1036     post_contract_dims[NumDims - 3] = out_width;
   1037     for (int i = 0; i < NumDims - 3; ++i) {
   1038       post_contract_dims[i] = in.dimension(i);
   1039     }
   1040   }
   1041 
   1042   DSizes<TensorIndex, 2> kernel_dims;
   1043   if (isColMajor) {
   1044     kernel_dims[0] = kernelFilters;
   1045     kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
   1046   } else {
   1047     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
   1048     kernel_dims[1] = kernelFilters;
   1049   }
   1050   // TODO(yangke): choose() is defined in TensorContraction.h -- consider
   1051   // moving it to somewhere more "common".
   1052   return choose(
   1053       Cond<internal::traits<Input>::Layout == ColMajor>(),
   1054       kernel.reshape(kernel_dims)
   1055           .contract(input
   1056                         .extract_image_patches(
   1057                             kernelRows, kernelCols, row_stride, col_stride,
   1058                             row_in_stride, col_in_stride, padding_type)
   1059                         .reshape(pre_contract_dims),
   1060                     contract_dims)
   1061           .reshape(post_contract_dims),
   1062       input
   1063           .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride,
   1064                                  row_in_stride, col_in_stride, padding_type)
   1065           .reshape(pre_contract_dims)
   1066           .contract(kernel.reshape(kernel_dims), contract_dims)
   1067           .reshape(post_contract_dims));
   1068 }
   1069 
   1070 }  // end namespace Eigen
   1071 
   1072 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_H_
   1073