1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ 17 #define TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 21 namespace Eigen { 22 23 // Changes the interpretation of padding in TensorVolumePatchOp to be compatible 24 // with the rest of TensorFlow (odd padding is split so that more padding is put 25 // on the right end of the tensor). 
template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols, typename ArgType,
          typename Device>
struct CustomTensorEvaluator {
  typedef TensorVolumePatchOp<Planes, Rows, Cols, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumInputDims = internal::array_size<
      typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  // The output has one extra dimension (the patch index) over the input.
  static const int NumDims = NumInputDims + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef
      typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const Index PacketSize =
      internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    // Coordinate-based access (coeff(coords)) is only implemented for the
    // 6-D case: depth, patch_planes, patch_rows, patch_cols, patches, batch.
    CoordAccess = NumDims == 6,
    RawAccess = false
  };

  // Evaluates the patch-extraction expression: caches input sizes, computes
  // padding (explicit, VALID, or SAME), output dimensions, and the strides /
  // fast-divisor helpers used by coeff() and packet().
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  CustomTensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) {
    EIGEN_STATIC_ASSERT(NumDims >= 5, YOU_MADE_A_PROGRAMMING_MISTAKE);

    m_paddingValue = op.padding_value();

    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims =
        m_impl.dimensions();

    // Cache a few variables.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputDepth = input_dims[0];
      m_inputPlanes = input_dims[1];
      m_inputRows = input_dims[2];
      m_inputCols = input_dims[3];
    } else {
      m_inputDepth = input_dims[NumInputDims - 1];
      m_inputPlanes = input_dims[NumInputDims - 2];
      m_inputRows = input_dims[NumInputDims - 3];
      m_inputCols = input_dims[NumInputDims - 4];
    }

    m_plane_strides = op.plane_strides();
    m_row_strides = op.row_strides();
    m_col_strides = op.col_strides();

    // Input strides and effective input/patch size
    m_in_plane_strides = op.in_plane_strides();
    m_in_row_strides = op.in_row_strides();
    m_in_col_strides = op.in_col_strides();
    m_plane_inflate_strides = op.plane_inflate_strides();
    m_row_inflate_strides = op.row_inflate_strides();
    m_col_inflate_strides = op.col_inflate_strides();

    // The "effective" spatial size after inflating data with zeros.
    m_input_planes_eff = (m_inputPlanes - 1) * m_plane_inflate_strides + 1;
    m_input_rows_eff = (m_inputRows - 1) * m_row_inflate_strides + 1;
    m_input_cols_eff = (m_inputCols - 1) * m_col_inflate_strides + 1;
    // The "effective" patch size after dilating the patch with in_*_strides.
    m_patch_planes_eff =
        op.patch_planes() + (op.patch_planes() - 1) * (m_in_plane_strides - 1);
    m_patch_rows_eff =
        op.patch_rows() + (op.patch_rows() - 1) * (m_in_row_strides - 1);
    m_patch_cols_eff =
        op.patch_cols() + (op.patch_cols() - 1) * (m_in_col_strides - 1);

    if (op.padding_explicit()) {
      m_outputPlanes = Eigen::divup(
          m_input_planes_eff +
              static_cast<Index>(op.padding_top_z() + op.padding_bottom_z()) -
              m_patch_planes_eff + 1,
          m_plane_strides);
      m_outputRows = Eigen::divup(
          m_input_rows_eff +
              static_cast<Index>(op.padding_top() + op.padding_bottom()) -
              m_patch_rows_eff + 1,
          m_row_strides);
      m_outputCols = Eigen::divup(
          m_input_cols_eff +
              static_cast<Index>(op.padding_left() + op.padding_right()) -
              m_patch_cols_eff + 1,
          m_col_strides);

      m_planePaddingTop = op.padding_top_z();
      m_rowPaddingTop = op.padding_top();
      m_colPaddingLeft = op.padding_left();
    } else {
      // Computing padding from the type
      switch (op.padding_type()) {
        case PADDING_VALID:
          m_outputPlanes = Eigen::divup(
              m_input_planes_eff - m_patch_planes_eff + 1, m_plane_strides);
          m_outputRows = Eigen::divup(m_input_rows_eff - m_patch_rows_eff + 1,
                                      m_row_strides);
          m_outputCols = Eigen::divup(m_input_cols_eff - m_patch_cols_eff + 1,
                                      m_col_strides);
          m_planePaddingTop = 0;
          m_rowPaddingTop = 0;
          m_colPaddingLeft = 0;
          break;
        case PADDING_SAME: {
          m_outputPlanes = Eigen::divup(m_input_planes_eff, m_plane_strides);
          m_outputRows = Eigen::divup(m_input_rows_eff, m_row_strides);
          m_outputCols = Eigen::divup(m_input_cols_eff, m_col_strides);
          // Total padding needed in each dimension to realize SAME output
          // size (clamped at zero).
          const Index dz = numext::maxi<DenseIndex>(
              0, (m_outputPlanes - 1) * m_plane_strides + m_patch_planes_eff -
                     m_input_planes_eff);
          const Index dy = numext::maxi<DenseIndex>(
              0, (m_outputRows - 1) * m_row_strides + m_patch_rows_eff -
                     m_input_rows_eff);
          const Index dx = numext::maxi<DenseIndex>(
              0, (m_outputCols - 1) * m_col_strides + m_patch_cols_eff -
                     m_input_cols_eff);
          // Flooring division: when the total padding is odd, the extra
          // element implicitly ends up on the bottom/right end, which is the
          // TensorFlow convention this evaluator exists to implement.
          m_planePaddingTop = dz / 2;
          m_rowPaddingTop = dy / 2;
          m_colPaddingLeft = dx / 2;
          break;
        }
        default:
          eigen_assert(false && "unexpected padding");
      }
    }
    eigen_assert(m_outputRows > 0);
    eigen_assert(m_outputCols > 0);
    eigen_assert(m_outputPlanes > 0);

    // Dimensions for result of extraction.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // ColMajor
      // 0: depth
      // 1: patch_planes
      // 2: patch_rows
      // 3: patch_cols
      // 4: number of patches
      // 5 and beyond: anything else (such as batch).
      m_dimensions[0] = input_dims[0];
      m_dimensions[1] = op.patch_planes();
      m_dimensions[2] = op.patch_rows();
      m_dimensions[3] = op.patch_cols();
      m_dimensions[4] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = 5; i < NumDims; ++i) {
        m_dimensions[i] = input_dims[i - 1];
      }
    } else {
      // RowMajor
      // NumDims-1: depth
      // NumDims-2: patch_planes
      // NumDims-3: patch_rows
      // NumDims-4: patch_cols
      // NumDims-5: number of patches
      // NumDims-6 and beyond: anything else (such as batch).
      m_dimensions[NumDims - 1] = input_dims[NumInputDims - 1];
      m_dimensions[NumDims - 2] = op.patch_planes();
      m_dimensions[NumDims - 3] = op.patch_rows();
      m_dimensions[NumDims - 4] = op.patch_cols();
      m_dimensions[NumDims - 5] = m_outputPlanes * m_outputRows * m_outputCols;
      for (int i = NumDims - 6; i >= 0; --i) {
        m_dimensions[i] = input_dims[i];
      }
    }

    // Strides for the output tensor.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_rowStride = m_dimensions[1];
      m_colStride = m_dimensions[2] * m_rowStride;
      m_patchStride = m_colStride * m_dimensions[3] * m_dimensions[0];
      m_otherStride = m_patchStride * m_dimensions[4];
    } else {
      m_rowStride = m_dimensions[NumDims - 2];
      m_colStride = m_dimensions[NumDims - 3] * m_rowStride;
      m_patchStride =
          m_colStride * m_dimensions[NumDims - 4] * m_dimensions[NumDims - 1];
      m_otherStride = m_patchStride * m_dimensions[NumDims - 5];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_otherInputStride =
        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    // Fast representations of different variables (precomputed reciprocals
    // so the hot paths divide via multiply+shift instead of integer divide).
    m_fastOtherStride = internal::TensorIntDivisor<Index>(m_otherStride);
    m_fastPatchStride = internal::TensorIntDivisor<Index>(m_patchStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_col_inflate_strides);
    m_fastInputPlaneStride =
        internal::TensorIntDivisor<Index>(m_plane_inflate_strides);
    m_fastInputColsEff = internal::TensorIntDivisor<Index>(m_input_cols_eff);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputPlanesRows =
        internal::TensorIntDivisor<Index>(m_outputPlanesRows);

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_fastOutputDepth = internal::TensorIntDivisor<Index>(m_dimensions[0]);
    } else {
      m_fastOutputDepth =
          internal::TensorIntDivisor<Index>(m_dimensions[NumDims - 1]);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_dimensions;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(
      Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }

  // Returns the coefficient at a flat output index by decomposing it into
  // (patch, col, row, plane, depth, other) coordinates and mapping them back
  // to the input; returns m_paddingValue for locations that fall in padding
  // or between inflated (zero-stuffed) input elements.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
  coeff(Index index) const {
    // Patch index corresponding to the passed in index.
    const Index patchIndex = index / m_fastPatchStride;

    // Spatial offset within the patch. This has to be translated into 3D
    // coordinates within the patch.
    const Index patchOffset =
        (index - patchIndex * m_patchStride) / m_fastOutputDepth;

    // Batch, etc.
    const Index otherIndex = (NumDims == 5) ? 0 : index / m_fastOtherStride;
    const Index patch3DIndex =
        (NumDims == 5)
            ? patchIndex
            : (index - otherIndex * m_otherStride) / m_fastPatchStride;

    // Calculate column index in the input original tensor.
    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex * m_col_strides +
                           colOffset * m_in_col_strides - m_colPaddingLeft;
    const Index origInputCol =
        (m_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    if (inputCol < 0 || inputCol >= m_input_cols_eff ||
        ((m_col_inflate_strides != 1) &&
         (inputCol != origInputCol * m_col_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate row index in the original input tensor.
    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index inputRow = rowIndex * m_row_strides +
                           rowOffset * m_in_row_strides - m_rowPaddingTop;
    const Index origInputRow =
        (m_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (inputRow < 0 || inputRow >= m_input_rows_eff ||
        ((m_row_inflate_strides != 1) &&
         (inputRow != origInputRow * m_row_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Calculate plane index in the original input tensor.
    const Index planeIndex =
        (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
    const Index inputPlane = planeIndex * m_plane_strides +
                             planeOffset * m_in_plane_strides -
                             m_planePaddingTop;
    const Index origInputPlane =
        (m_plane_inflate_strides == 1)
            ? inputPlane
            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
        ((m_plane_inflate_strides != 1) &&
         (inputPlane != origInputPlane * m_plane_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    // Depth (channel) coordinate, located at the innermost dimension.
    const int depth_index =
        static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
                                                               : NumDims - 1;
    const Index depth =
        index - (index / m_fastOutputDepth) * m_dimensions[depth_index];

    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride +
                             origInputPlane * m_planeInputStride +
                             otherIndex * m_otherInputStride;

    return m_impl.coeff(inputIndex);
  }

  // Vectorized load of PacketSize consecutive output coefficients. Falls back
  // to scalar-by-scalar gathering (packetWithPossibleZero) whenever the packet
  // cannot be proven to map to a contiguous, fully-in-bounds input span.
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
  packet(Index index) const {
    EIGEN_STATIC_ASSERT(PacketSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    // The fast path below assumes a densely packed input: any in-patch
    // striding or inflation breaks contiguity, so gather instead.
    if (m_in_row_strides != 1 || m_in_col_strides != 1 ||
        m_row_inflate_strides != 1 || m_col_inflate_strides != 1 ||
        m_in_plane_strides != 1 || m_plane_inflate_strides != 1) {
      return packetWithPossibleZero(index);
    }

    // Examine the first and last element of the packet; if they agree on
    // every coordinate except depth, the whole packet is contiguous in depth.
    const Index indices[2] = {index, index + PacketSize - 1};
    const Index patchIndex = indices[0] / m_fastPatchStride;
    if (patchIndex != indices[1] / m_fastPatchStride) {
      return packetWithPossibleZero(index);
    }
    const Index otherIndex =
        (NumDims == 5) ? 0 : indices[0] / m_fastOtherStride;
    eigen_assert(otherIndex == indices[1] / m_fastOtherStride);

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffsets[2] = {
        (indices[0] - patchIndex * m_patchStride) / m_fastOutputDepth,
        (indices[1] - patchIndex * m_patchStride) / m_fastOutputDepth};

    const Index patch3DIndex =
        (NumDims == 5)
            ? patchIndex
            : (indices[0] - otherIndex * m_otherStride) / m_fastPatchStride;
    eigen_assert(patch3DIndex ==
                 (indices[1] - otherIndex * m_otherStride) / m_fastPatchStride);

    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};

    // Calculate col indices in the original input tensor.
    const Index inputCols[2] = {
        colIndex * m_col_strides + colOffsets[0] - m_colPaddingLeft,
        colIndex * m_col_strides + colOffsets[1] - m_colPaddingLeft};
    if (inputCols[1] < 0 || inputCols[0] >= m_inputCols) {
      // Entire packet lies in the padding region.
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputCols[0] != inputCols[1]) {
      return packetWithPossibleZero(index);
    }

    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index rowOffsets[2] = {
        (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
        (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    // Calculate row indices in the original input tensor.
    const Index inputRows[2] = {
        rowIndex * m_row_strides + rowOffsets[0] - m_rowPaddingTop,
        rowIndex * m_row_strides + rowOffsets[1] - m_rowPaddingTop};

    if (inputRows[1] < 0 || inputRows[0] >= m_inputRows) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputRows[0] != inputRows[1]) {
      return packetWithPossibleZero(index);
    }

    const Index planeIndex =
        (patch3DIndex - m_outputPlanes * (colIndex * m_outputRows + rowIndex));
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffsets[0] * m_colStride -
            rowOffsets[0] * m_rowStride,
        patchOffsets[1] - colOffsets[1] * m_colStride -
            rowOffsets[1] * m_rowStride};
    eigen_assert(planeOffsets[0] <= planeOffsets[1]);
    const Index inputPlanes[2] = {
        planeIndex * m_plane_strides + planeOffsets[0] - m_planePaddingTop,
        planeIndex * m_plane_strides + planeOffsets[1] - m_planePaddingTop};

    if (inputPlanes[1] < 0 || inputPlanes[0] >= m_inputPlanes) {
      return internal::pset1<PacketReturnType>(Scalar(m_paddingValue));
    }

    if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // no padding
      const int depth_index =
          static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 0
                                                                 : NumDims - 1;
      const Index depth =
          index - (index / m_fastOutputDepth) * m_dimensions[depth_index];
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCols[0] * m_colInputStride +
                               m_planeInputStride * inputPlanes[0] +
                               otherIndex * m_otherInputStride;
      return m_impl.template packet<Unaligned>(inputIndex);
    }

    // Packet straddles the padding boundary: gather scalars.
    return packetWithPossibleZero(index);
  }

  // Rough cost model: counts the index-arithmetic operations performed by
  // coeff() per output coefficient.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 10 * TensorOpCost::DivCost<Index>() +
                                21 * TensorOpCost::MulCost<Index>() +
                                8 * TensorOpCost::AddCost<Index>();
    return TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  // Accessors used by kernels (e.g. convolution/pooling backward passes)
  // that need the padding/striding configuration this evaluator computed.
  Index planePaddingTop() const { return m_planePaddingTop; }
  Index rowPaddingTop() const { return m_rowPaddingTop; }
  Index colPaddingLeft() const { return m_colPaddingLeft; }
  Index outputPlanes() const { return m_outputPlanes; }
  Index outputRows() const { return m_outputRows; }
  Index outputCols() const { return m_outputCols; }
  Index userPlaneStride() const { return m_plane_strides; }
  Index userRowStride() const { return m_row_strides; }
  Index userColStride() const { return m_col_strides; }
  Index userInPlaneStride() const { return m_in_plane_strides; }
  Index userInRowStride() const { return m_in_row_strides; }
  Index userInColStride() const { return m_in_col_strides; }
  Index planeInflateStride() const { return m_plane_inflate_strides; }
  Index rowInflateStride() const { return m_row_inflate_strides; }
  Index colInflateStride() const { return m_col_inflate_strides; }

  // Coordinate-based access; only valid when CoordAccess is true (NumDims==6).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType
  coeff(const array<Index, NumDims>& coords) const {
    // ColMajor
    // 0: depth, 1: patch_planes, 2: patch_rows, 3: patch_cols, 4: number of
    // patches, 5: batches
    // RowMajor
    // 0: batches, 1: number of patches, 2: patch_cols , 3: patch_rows, 4:
    // patch_planes, 5: depth
    const Index patch3DIndex =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 4 : 1];
    const Index colOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 3 : 2];
    const Index rowOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 2 : 3];
    const Index planeOffset =
        coords[static_cast<int>(Layout) == static_cast<int>(ColMajor) ? 1 : 4];

    array<Index, NumDims - 1> inputCoords;

    const Index colIndex = patch3DIndex / m_fastOutputPlanesRows;
    const Index inputCol = colIndex * m_col_strides +
                           colOffset * m_in_col_strides - m_colPaddingLeft;
    const Index origInputCol =
        (m_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
    if (inputCol < 0 || inputCol >= m_input_cols_eff ||
        ((m_col_inflate_strides != 1) &&
         (inputCol != origInputCol * m_col_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    const Index rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    const Index inputRow = rowIndex * m_row_strides +
                           rowOffset * m_in_row_strides - m_rowPaddingTop;
    const Index origInputRow =
        (m_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (inputRow < 0 || inputRow >= m_input_rows_eff ||
        ((m_row_inflate_strides != 1) &&
         (inputRow != origInputRow * m_row_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    const Index planeIndex =
        patch3DIndex - colIndex * m_outputPlanesRows - rowIndex * m_outputRows;
    const Index inputPlane = planeIndex * m_plane_strides +
                             planeOffset * m_in_plane_strides -
                             m_planePaddingTop;
    const Index origInputPlane =
        (m_plane_inflate_strides == 1)
            ? inputPlane
            : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
    if (inputPlane < 0 || inputPlane >= m_input_planes_eff ||
        ((m_plane_inflate_strides != 1) &&
         (inputPlane != origInputPlane * m_plane_inflate_strides))) {
      return Scalar(m_paddingValue);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      inputCoords[0] = coords[0];  // depth
      inputCoords[1] = origInputPlane;
      inputCoords[2] = origInputRow;
      inputCoords[3] = origInputCol;
      inputCoords[4] = coords[5];  // batch
    } else {
      inputCoords[4] = coords[5];  // depth
      inputCoords[3] = origInputPlane;
      inputCoords[2] = origInputRow;
      inputCoords[1] = origInputCol;
      inputCoords[0] = coords[0];  // batch
    }
    if (TensorEvaluator<ArgType, Device>::CoordAccess) {
      return m_impl.coeff(inputCoords);
    } else {
      // Fall back to flattening the coordinates into a linear input index.
      Index inputIndex;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        inputIndex = inputCoords[4] * m_otherInputStride +
                     inputCoords[3] * m_colInputStride +
                     inputCoords[2] * m_rowInputStride +
                     inputCoords[1] * m_planeInputStride + inputCoords[0];
      } else {
        inputIndex = inputCoords[0] * m_otherInputStride +
                     inputCoords[1] * m_colInputStride +
                     inputCoords[2] * m_rowInputStride +
                     inputCoords[3] * m_planeInputStride + inputCoords[4];
      }
      return m_impl.coeff(inputIndex);
    }
  }

 protected:
  // Scalar fallback for packet(): loads each coefficient through coeff(),
  // which handles padding/inflation, then packs them into one packet.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType
  packetWithPossibleZero(Index index) const {
    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type
        values[PacketSize];
    for (int i = 0; i < PacketSize; ++i) {
      values[i] = coeff(index + i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }

  Dimensions m_dimensions;

  // Parameters passed to the constructor.
  Index m_plane_strides;
  Index m_row_strides;
  Index m_col_strides;

  Index m_outputPlanes;
  Index m_outputRows;
  Index m_outputCols;

  Index m_planePaddingTop;
  Index m_rowPaddingTop;
  Index m_colPaddingLeft;

  Index m_in_plane_strides;
  Index m_in_row_strides;
  Index m_in_col_strides;

  Index m_plane_inflate_strides;
  Index m_row_inflate_strides;
  Index m_col_inflate_strides;

  // Cached input size.
  Index m_inputDepth;
  Index m_inputPlanes;
  Index m_inputRows;
  Index m_inputCols;

  // Other cached variables.
  Index m_outputPlanesRows;

  // Effective input/patch post-inflation size.
  Index m_input_planes_eff;
  Index m_input_rows_eff;
  Index m_input_cols_eff;
  Index m_patch_planes_eff;
  Index m_patch_rows_eff;
  Index m_patch_cols_eff;

  // Strides for the output tensor.
  Index m_otherStride;
  Index m_patchStride;
  Index m_rowStride;
  Index m_colStride;

  // Strides for the input tensor.
  Index m_planeInputStride;
  Index m_rowInputStride;
  Index m_colInputStride;
  Index m_otherInputStride;

  // Precomputed fast integer divisors for the strides above.
  internal::TensorIntDivisor<Index> m_fastOtherStride;
  internal::TensorIntDivisor<Index> m_fastPatchStride;
  internal::TensorIntDivisor<Index> m_fastColStride;
  internal::TensorIntDivisor<Index> m_fastRowStride;
  internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;
  internal::TensorIntDivisor<Index> m_fastInputColsEff;
  internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
  internal::TensorIntDivisor<Index> m_fastOutputPlanes;
  internal::TensorIntDivisor<Index> m_fastOutputDepth;

  // Value returned for coefficients that fall inside padding.
  Scalar m_paddingValue;

  TensorEvaluator<ArgType, Device> m_impl;
};

// Override the default TensorEvaluator for TensorVolumePatchOp for CPU.
#define OVERRIDE_EVALUATOR(Device)                                          \
  template <DenseIndex Planes, DenseIndex Rows, DenseIndex Cols,            \
            typename ArgType>                                               \
  struct TensorEvaluator<                                                   \
      const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device>       \
      : public CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device> { \
    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(                  \
        const typename CustomTensorEvaluator<Planes, Rows, Cols, ArgType,   \
                                             Device>::XprType& op,          \
        const Device& device)                                               \
        : CustomTensorEvaluator<Planes, Rows, Cols, ArgType, Device>(       \
              op, device) {}                                                \
  };

OVERRIDE_EVALUATOR(Eigen::ThreadPoolDevice);
OVERRIDE_EVALUATOR(Eigen::DefaultDevice);

#undef OVERRIDE_EVALUATOR

};  // namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_VOLUME_PATCH_H_