/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace Eigen {
     23 // Changes the interpretation of padding in TensorVolumePatchOp to be compatible
     24 // with the rest of TensorFlow (odd padding is split so that more padding is put
     25 // on the right end of the tensor).
     26 template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType,
     27           typename Device>
     28 struct CustomTensorEvaluator {
     29   typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
     30   typedef typename XprType::Index Index;
     31   static const int NumInputDims = internal::array_size<
     32       typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
     33   static const int NumDims = NumInputDims + 1;
     34   typedef DSizes<Index, NumDims> Dimensions;
     35   typedef
     36       typename internal::remove_const<typename XprType::Scalar>::type Scalar;
     37   typedef typename XprType::CoeffReturnType CoeffReturnType;
     38   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
     39   static const Index PacketSize =
     40       internal::unpacket_traits<PacketReturnType>::size;
     41 
     42   enum {
     43     IsAligned = false,
     44     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
     45     BlockAccess = false,
     46     Layout = TensorEvaluator<ArgType, Device>::Layout,
     47     CoordAccess = NumDims == 6,
     48     RawAccess = false
     49   };
     50 
     51   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     52   CustomTensorEvaluator(const XprType& op, const Device& device)
     53       : m_impl(op.expression(), device) {
     54     EIGEN_STATIC_ASSERT(NumDims >= 5, YOU_MADE_A_PROGRAMMING_MISTAKE);
     55 
     56     m_paddingValue = op.padding_value();
     57 
     58     const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims =
     59         m_impl.dimensions();
     60 
     61     // Cache a few variables.
     62     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
     63       m_inputDepth = input_dims[0];
     64       m_inputPlanes = input_dims[1];
     65       m_inputRows = input_dims[2];
     66       m_inputCols = input_dims[3];
     67     } else {
     68       m_inputDepth = input_dims[NumInputDims - 1];
     69       m_inputPlanes = input_dims[NumInputDims - 2];
     70       m_inputRows = input_dims[NumInputDims - 3];
     71       m_inputCols = input_dims[NumInputDims - 4];
     72     }
     73 
     74     m_plane_strides = op.plane_strides();
     75     m_row_strides = op.row_strides();
     76     m_col_strides = op.col_strides();
     77 
     78     // Input strides and effective input/patch size
     79     m_in_plane_strides = op.in_plane_strides();
     80     m_in_row_strides = op.in_row_strides();
     81     m_in_col_strides = op.in_col_strides();
     82     m_plane_inflate_strides = op.plane_inflate_strides();
     83     m_row_inflate_strides = op.row_inflate_strides();
     84     m_col_inflate_strides = op.col_inflate_strides();
     85 
     86     // The "effective" spatial size after inflating data with zeros.
     87     m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
     88     m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
     89     m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
     90     m_patch_planes_eff =
     91         op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
     92     m_patch_rows_eff =
     93         op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
     94     m_patch_cols_eff =
     95         op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);
     96 
     97     if (op.padding_explicit()) {
     98       m_outputPlanes = Eigen::divup(
     99           m_input_planes_eff +
    100               static_cast<Index>(op.padding_top_z() + op.padding_bottom_z()) -
    101               m_patch_planes_eff + 1,
    102           m_plane_strides);
    103       m_outputRows = Eigen::divup(
    104           m_input_rows_eff +
    105               static_cast<Index>(op.padding_top() + op.padding_bottom()) -
    106               m_patch_rows_eff + 1,
    107           m_row_strides);
    108       m_outputCols = Eigen::divup(
    109           m_input_cols_eff +
    110               static_cast<Index>(op.padding_left() + op.padding_right()) -
    111               m_patch_cols_eff + 1,
    112           m_col_strides);
    113       m_planePaddingTop = op.padding_top_z();
    114       m_rowPaddingTop = op.padding_top();
    115       m_colPaddingLeft = op.padding_left();
    116     } else {
    117       // Computing padding from the type
    118       switch (op.padding_type()) {
    119         case PADDING_VALID:
    120           m_outputPlanes = Eigen::divup(
    121               m_input_planes_eff - m_patch_planes_eff + 1, m_plane_strides);
    122           m_outputRows = Eigen::divup(m_input_rows_eff - m_patch_rows_eff + 1,
    123                                       m_row_strides);
    124           m_outputCols = Eigen::divup(m_input_cols_eff - m_patch_cols_eff + 1,
    125                                       m_col_strides);
    126           m_planePaddingTop = 0;
    127           m_rowPaddingTop = 0;
    128           m_colPaddingLeft = 0;
    129           break;
    130         case PADDING_SAME: {
    131           m_outputPlanes = Eigen::divup(m_input_planes_eff, m_plane_strides);
    132           m_outputRows = Eigen::divup(m_input_rows_eff, m_row_strides);
    133           m_outputCols = Eigen::divup(m_input_cols_eff, m_col_strides);
    134           const Index dz = numext::maxi<DenseIndex>(
    135               0, (m_outputPlanes - 1) * m_plane_strides + m_patch_planes_eff -
    136                      m_input_planes_eff);
    137           const Index dy = numext::maxi<DenseIndex>(
    138               0, (m_outputRows - 1) * m_row_strides + m_patch_rows_eff -
    139                      m_input_rows_eff);
    140           const Index dx = numext::maxi<DenseIndex>(
    141               0, (m_outputCols - 1) * m_col_strides + m_patch_cols_eff -
    142                      m_input_cols_eff);
    143           m_planePaddingTop = dz / 2;
    144           m_rowPaddingTop = dy / 2;
    145           m_colPaddingLeft = dx / 2;
    146           break;
    147         }
    148         default:
    149           eigen_assert(false && "unexpected padding");
    150       }
    151     }
    152     eigen_assert(m_outputRows > 0);
    153     eigen_assert(m_outputCols > 0);
    154     eigen_assert(m_outputPlanes > 0);
    155 
    156     // Dimensions for result of extraction.
    157     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    158       // ColMajor
    159       // 0: depth
    160       // 1: patch_planes
    161       // 2: patch_rows
    162       // 3: patch_cols
    163       // 4: number of patches
    164       // 5 and beyond: anything else (such as batch).
    165       m_dimensions[0] = input_dims[0];
    166       m_dimensions[1] = op.patch_planes();
    167       m_dimensions[2] = op.patch_rows();
    168       m_dimensions[3] = op.patch_cols();
    169       m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
    170       for (int i = 5; i < NumDims; ++i) {
    171         m_dimensions[i] = input_dims[i - 1];
    172       }
    173     } else {
    174       // RowMajor
    175       // NumDims-1: depth
    176       // NumDims-2: patch_planes
    177       // NumDims-3: patch_rows
    178       // NumDims-4: patch_cols
    179       // NumDims-5: number of patches
    180       // NumDims-6 and beyond: anything else (such as batch).
    181       m_dimensions[NumDims - 1] = input_dims[NumInputDims - 1];
    182       m_dimensions[NumDims - 2] = op.patch_planes();
    183       m_dimensions[NumDims - 3] = op.patch_rows();
    184       m_dimensions[NumDims - 4] = op.patch_cols();
    185       m_dimensions[NumDims - 5] = m_outputPlanes * m_outputRows * m_outputCols;
    186       for (int i = NumDims - 6; i >= 0; --i) {
    187         m_dimensions[i] = input_dims[i];
    188       }
    189     }
    190 
    191     // Strides for the output tensor.
    192     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    193       m_rowStride = m_dimensions[1];
    194       m_colStride = m_dimensions[2] * m_rowStride;
    195       m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
    196       m_otherStride = m_patchStride * m_dimensions[4];
    197     } else {
    198       m_rowStride = m_dimensions[NumDims - 2];
    199       m_colStride = m_dimensions[NumDims - 3] * m_rowStride;
    200       m_patchStride =
    201           m_colStride * m_dimensions[NumDims - 4] * m_dimensions[NumDims - 1];
    202       m_otherStride = m_patchStride * m_dimensions[NumDims - 5];
    203     }
    204 
    205     // Strides for navigating through the input tensor.
    206     m_planeInputStride = m_inputDepth;
    207     m_rowInputStride = m_inputDepth * m_inputPlanes;
    208     m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    209     m_otherInputStride =
    210         m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
    211 
    212     m_outputPlanesRows = m_outputPlanes * m_outputRows;
    213 
    214     // Fast representations of different variables.
    215     m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
    216     m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
    217     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    218     m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    219     m_fastInputRowStride =
    220         internal::TensorIntDivisor<Index>(m_row_inflate_strides);
    221     m_fastInputColStride =
    222         internal::TensorIntDivisor<Index>(m_col_inflate_strides);
    223     m_fastInputPlaneStride =
    224         internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
    225     m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
    226     m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    227     m_fastOutputPlanesRows =
    228         internal::TensorIntDivisor<Index>(m_outputPlanesRows);
    229 
    230     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    231       m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
    232     } else {
    233       m_fastOutputDepth =
    234           internal::TensorIntDivisor<Index>(m_dimensions[NumDims - 1]);
    235     }
    236   }
    237 
    238   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    239     return m_dimensions;
    240   }
    241 
    242   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(
    243       Scalar* /*data*/) {
    244     m_impl.evalSubExprsIfNeeded(NULL);
    245     return true;
    246   }
    247 
    248   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
    249 
    250   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
    251   coeff(Index index) const {
    252     // Patch index corresponding to the passed in index.
    253     const Index patchIndex = index / m_fastPatchStride;
    254 
    255     // Spatial offset within the patch. This has to be translated into 3D
    256     // coordinates within the patch.
    257     const Index patchOffset =
    258         (index - patchIndex * m_patchStride) / m_fastOutputDepth;
    259 
    260     // Batch, etc.
    261     const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
    262     const Index patch3DIndex =
    263         (NumDims == 5)
    264             ? patchIndex
    265             : (index - otherIndex * m_otherStride) / m_fastPatchStride;
    266 
    267     // Calculate column index in the input original tensor.
    268     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    269     const Index colOffset = patchOffset / m_fastColStride;
    270     const Index inputCol = colIndex * m_col_strides +
    271                            colOffset * m_in_col_strides - m_colPaddingLeft;
    272     const Index origInputCol =
    273         (m_col_inflate_strides == 1)
    274             ? inputCol
    275             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    276     if (inputCol < 0 || inputCol >= m_input_cols_eff ||
    277         ((m_col_inflate_strides != 1) &&
    278          (inputCol != origInputCol * m_col_inflate_strides))) {
    279       return Scalar(m_paddingValue);
    280     }
    281 
    282     // Calculate row index in the original input tensor.
    283     const Index rowIndex =
    284         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    285     const Index rowOffset =
    286         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    287     const Index inputRow = rowIndex * m_row_strides +
    288                            rowOffset * m_in_row_strides - m_rowPaddingTop;
    289     const Index origInputRow =
    290         (m_row_inflate_strides == 1)
    291             ? inputRow
    292             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    293     if (inputRow < 0 || inputRow >= m_input_rows_eff ||
    294         ((m_row_inflate_strides != 1) &&
    295          (inputRow != origInputRow * m_row_inflate_strides))) {
    296       return Scalar(m_paddingValue);
    297     }
    298 
    299     // Calculate plane index in the original input tensor.
    300     const Index planeIndex =
    301         (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    302     const Index planeOffset =
    303         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
    304     const Index inputPlane = planeIndex * m_plane_strides +
    305                              planeOffset * m_in_plane_strides -
    306                              m_planePaddingTop;
    307     const Index origInputPlane =
    308         (m_plane_inflate_strides == 1)
    309             ? inputPlane
    310             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    311     if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
    312         ((m_plane_inflate_strides != 1) &&
    313          (inputPlane != origInputPlane * m_plane_inflate_strides))) {
    314       return Scalar(m_paddingValue);
    315     }
    316 
    317     const int depth_index =
    318         static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
    319                                                                : NumDims - 1;
    320     const Index depth =
    321         index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
    322 
    323     const Index inputIndex = depth + origInputRow * m_rowInputStride +
    324                              origInputCol * m_colInputStride +
    325                              origInputPlane * m_planeInputStride +
    326                              otherIndex * m_otherInputStride;
    327 
    328     return m_impl.coeff(inputIndex);
    329   }
    330 
    331   template <int LoadMode>
    332   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
    333   packet(Index index) const {
    334     EIGEN_STATIC_ASSERT(PacketSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    335     eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
    336 
    337     if (m_in_row_strides != 1 || m_in_col_strides != 1 ||
    338         m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
    339         m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
    340       return packetWithPossibleZero(index);
    341     }
    342 
    343     const Index indices[2] = {index, index + PacketSize - 1};
    344     const Index patchIndex = indices[0] / m_fastPatchStride;
    345     if (patchIndex != indices[1] / m_fastPatchStride) {
    346       return packetWithPossibleZero(index);
    347     }
    348     const Index otherIndex =
    349         (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
    350     eigen_assert(otherIndex == indices[1] / m_fastOtherStride);
    351 
    352     // Find the offset of the element wrt the location of the first element.
    353     const Index patchOffsets[2] = {
    354         (indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
    355         (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};
    356 
    357     const Index patch3DIndex =
    358         (NumDims == 5)
    359             ? patchIndex
    360             : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
    361     eigen_assert(patch3DIndex ==
    362                  (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);
    363 
    364     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    365     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
    366                                  patchOffsets[1] / m_fastColStride};
    367 
    368     // Calculate col indices in the original input tensor.
    369     const Index inputCols[2] = {
    370         colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
    371         colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
    372     if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
    373       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    374     }
    375 
    376     if (inputCols[0] != inputCols[1]) {
    377       return packetWithPossibleZero(index);
    378     }
    379 
    380     const Index rowIndex =
    381         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    382     const Index rowOffsets[2] = {
    383         (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
    384         (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
    385     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    386     // Calculate col indices in the original input tensor.
    387     const Index inputRows[2] = {
    388         rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
    389         rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};
    390 
    391     if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
    392       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    393     }
    394 
    395     if (inputRows[0] != inputRows[1]) {
    396       return packetWithPossibleZero(index);
    397     }
    398 
    399     const Index planeIndex =
    400         (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    401     const Index planeOffsets[2] = {
    402         patchOffsets[0] - colOffsets[0] * m_colStride -
    403             rowOffsets[0] * m_rowStride,
    404         patchOffsets[1] - colOffsets[1] * m_colStride -
    405             rowOffsets[1] * m_rowStride};
    406     eigen_assert(planeOffsets[0] <= planeOffsets[1]);
    407     const Index inputPlanes[2] = {
    408         planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
    409         planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};
    410 
    411     if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
    412       return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    413     }
    414 
    415     if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
    416       // no padding
    417       const int depth_index =
    418           static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
    419                                                                  : NumDims - 1;
    420       const Index depth =
    421           index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
    422       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
    423                                inputCols[0] * m_colInputStride +
    424                                m_planeInputStride * inputPlanes[0] +
    425                                otherIndex * m_otherInputStride;
    426       return m_impl.template packet<Unaligned>(inputIndex);
    427     }
    428 
    429     return packetWithPossibleZero(index);
    430   }
    431 
    432   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
    433   costPerCoeff(bool vectorized) const {
    434     const double compute_cost = 10 * TensorOpCost::DivCost<Index>() +
    435                                 21 * TensorOpCost::MulCost<Index>() +
    436                                 8 * TensorOpCost::AddCost<Index>();
    437     return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
    438   }
    439 
    440   EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }
    441 
    442   const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
    443 
    444   Index planePaddingTop() const { return m_planePaddingTop; }
    445   Index rowPaddingTop() const { return m_rowPaddingTop; }
    446   Index colPaddingLeft() const { return m_colPaddingLeft; }
    447   Index outputPlanes() const { return m_outputPlanes; }
    448   Index outputRows() const { return m_outputRows; }
    449   Index outputCols() const { return m_outputCols; }
    450   Index userPlaneStride() const { return m_plane_strides; }
    451   Index userRowStride() const { return m_row_strides; }
    452   Index userColStride() const { return m_col_strides; }
    453   Index userInPlaneStride() const { return m_in_plane_strides; }
    454   Index userInRowStride() const { return m_in_row_strides; }
    455   Index userInColStride() const { return m_in_col_strides; }
    456   Index planeInflateStride() const { return m_plane_inflate_strides; }
    457   Index rowInflateStride() const { return m_row_inflate_strides; }
    458   Index colInflateStride() const { return m_col_inflate_strides; }
    459 
    460   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
    461   coeff(const array<Index, NumDims>& coords) const {
    462     // ColMajor
    463     //   0: depth, 1: patch_planes, 2: patch_rows, 3: patch_cols, 4: number of
    464     //   patches, 5: batches
    465     // RowMajor
    466     //   0: batches, 1: number of patches, 2: patch_cols , 3: patch_rows, 4:
    467     //   patch_planes, 5: depth
    468     const Index patch3DIndex =
    469         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 4 : 1];
    470     const Index colOffset =
    471         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 3 : 2];
    472     const Index rowOffset =
    473         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 3];
    474     const Index planeOffset =
    475         coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 4];
    476 
    477     array<Index, NumDims - 1> inputCoords;
    478 
    479     const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    480     const Index inputCol = colIndex * m_col_strides +
    481                            colOffset * m_in_col_strides - m_colPaddingLeft;
    482     const Index origInputCol =
    483         (m_col_inflate_strides == 1)
    484             ? inputCol
    485             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    486     if (inputCol < 0 || inputCol >= m_input_cols_eff ||
    487         ((m_col_inflate_strides != 1) &&
    488          (inputCol != origInputCol * m_col_inflate_strides))) {
    489       return Scalar(m_paddingValue);
    490     }
    491 
    492     const Index rowIndex =
    493         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    494     const Index inputRow = rowIndex * m_row_strides +
    495                            rowOffset * m_in_row_strides - m_rowPaddingTop;
    496     const Index origInputRow =
    497         (m_row_inflate_strides == 1)
    498             ? inputRow
    499             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    500     if (inputRow < 0 || inputRow >= m_input_rows_eff ||
    501         ((m_row_inflate_strides != 1) &&
    502          (inputRow != origInputRow * m_row_inflate_strides))) {
    503       return Scalar(m_paddingValue);
    504     }
    505 
    506     const Index planeIndex =
    507         patch3DIndex - colIndex * m_outputPlanesRows - rowIndex * m_outputRows;
    508     const Index inputPlane = planeIndex * m_plane_strides +
    509                              planeOffset * m_in_plane_strides -
    510                              m_planePaddingTop;
    511     const Index origInputPlane =
    512         (m_plane_inflate_strides == 1)
    513             ? inputPlane
    514             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    515     if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
    516         ((m_plane_inflate_strides != 1) &&
    517          (inputPlane != origInputPlane * m_plane_inflate_strides))) {
    518       return Scalar(m_paddingValue);
    519     }
    520 
    521     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    522       inputCoords[0] = coords[0];  // depth
    523       inputCoords[1] = origInputPlane;
    524       inputCoords[2] = origInputRow;
    525       inputCoords[3] = origInputCol;
    526       inputCoords[4] = coords[5];  // batch
    527     } else {
    528       inputCoords[4] = coords[5];  // depth
    529       inputCoords[3] = origInputPlane;
    530       inputCoords[2] = origInputRow;
    531       inputCoords[1] = origInputCol;
    532       inputCoords[0] = coords[0];  // batch
    533     }
    534     if (TensorEvaluator<ArgType, Device>::CoordAccess) {
    535       return m_impl.coeff(inputCoords);
    536     } else {
    537       Index inputIndex;
    538       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    539         inputIndex = inputCoords[4] * m_otherInputStride +
    540                      inputCoords[3] * m_colInputStride +
    541                      inputCoords[2] * m_rowInputStride +
    542                      inputCoords[1] * m_planeInputStride + inputCoords[0];
    543       } else {
    544         inputIndex = inputCoords[0] * m_otherInputStride +
    545                      inputCoords[1] * m_colInputStride +
    546                      inputCoords[2] * m_rowInputStride +
    547                      inputCoords[3] * m_planeInputStride + inputCoords[4];
    548       }
    549       return m_impl.coeff(inputIndex);
    550     }
    551   }
    552 
    553  protected:
    554   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
    555   packetWithPossibleZero(Index index) const {
    556     EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type
    557         values[PacketSize];
    558     for (int i = 0; i < PacketSize; ++i) {
    559       values[i] = coeff(index + i);
    560     }
    561     PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    562     return rslt;
    563   }
    564 
    565   Dimensions m_dimensions;
    566 
    567   // Parameters passed to the costructor.
    568   Index m_plane_strides;
    569   Index m_row_strides;
    570   Index m_col_strides;
    571 
    572   Index m_outputPlanes;
    573   Index m_outputRows;
    574   Index m_outputCols;
    575 
    576   Index m_planePaddingTop;
    577   Index m_rowPaddingTop;
    578   Index m_colPaddingLeft;
    579 
    580   Index m_in_plane_strides;
    581   Index m_in_row_strides;
    582   Index m_in_col_strides;
    583 
    584   Index m_plane_inflate_strides;
    585   Index m_row_inflate_strides;
    586   Index m_col_inflate_strides;
    587 
    588   // Cached input size.
    589   Index m_inputDepth;
    590   Index m_inputPlanes;
    591   Index m_inputRows;
    592   Index m_inputCols;
    593 
    594   // Other cached variables.
    595   Index m_outputPlanesRows;
    596 
    597   // Effective input/patch post-inflation size.
    598   Index m_input_planes_eff;
    599   Index m_input_rows_eff;
    600   Index m_input_cols_eff;
    601   Index m_patch_planes_eff;
    602   Index m_patch_rows_eff;
    603   Index m_patch_cols_eff;
    604 
    605   // Strides for the output tensor.
    606   Index m_otherStride;
    607   Index m_patchStride;
    608   Index m_rowStride;
    609   Index m_colStride;
    610 
    611   // Strides for the input tensor.
    612   Index m_planeInputStride;
    613   Index m_rowInputStride;
    614   Index m_colInputStride;
    615   Index m_otherInputStride;
    616 
    617   internal::TensorIntDivisor<Index> m_fastOtherStride;
    618   internal::TensorIntDivisor<Index> m_fastPatchStride;
    619   internal::TensorIntDivisor<Index> m_fastColStride;
    620   internal::TensorIntDivisor<Index> m_fastRowStride;
    621   internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
    622   internal::TensorIntDivisor<Index> m_fastInputRowStride;
    623   internal::TensorIntDivisor<Index> m_fastInputColStride;
    624   internal::TensorIntDivisor<Index> m_fastInputColsEff;
    625   internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
    626   internal::TensorIntDivisor<Index> m_fastOutputPlanes;
    627   internal::TensorIntDivisor<Index> m_fastOutputDepth;
    628 
    629   Scalar m_paddingValue;
    630 
    631   TensorEvaluator<ArgType, Device> m_impl;
    632 };
    633 
    634 // Override the default TensorEvaluator for TensorVolumePatchOp for CPU.
    635 #define OVERRIDE_EVALUATOR(Device)                                          \
    636   template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols,            \
    637             typename ArgType>                                               \
    638   struct TensorEvaluator<                                                   \
    639       const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>       \
    640       : public CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device> { \
    641     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(                  \
    642         const typename CustomTensorEvaluator<Planes, Rows, Cols, ArgType,   \
    643                                              Device>::XprType& op,          \
    644         const Device& device)                                               \
    645         : CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device>(       \
    646               op, device) {}                                                \
    647   };
    648 
    649 OVERRIDE_EVALUATOR(Eigen::ThreadPoolDevice);
    650 OVERRIDE_EVALUATOR(Eigen::DefaultDevice);
    651 
    652 #undef OVERRIDE_EVALUATOR

};  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_