Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
     12 
     13 namespace Eigen {
     14 
/** \class TensorPatch
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor patch class.
  *
  * Represents the extraction of all patches (sliding windows) of a given
  * size from a tensor. The result has one extra dimension, which indexes
  * the individual patches.
  */
     22 namespace internal {
     23 template<typename PatchDim, typename XprType>
     24 struct traits<TensorPatchOp<PatchDim, XprType> > : public traits<XprType>
     25 {
     26   typedef typename XprType::Scalar Scalar;
     27   typedef traits<XprType> XprTraits;
     28   typedef typename XprTraits::StorageKind StorageKind;
     29   typedef typename XprTraits::Index Index;
     30   typedef typename XprType::Nested Nested;
     31   typedef typename remove_reference<Nested>::type _Nested;
     32   static const int NumDimensions = XprTraits::NumDimensions + 1;
     33   static const int Layout = XprTraits::Layout;
     34 };
     35 
     36 template<typename PatchDim, typename XprType>
     37 struct eval<TensorPatchOp<PatchDim, XprType>, Eigen::Dense>
     38 {
     39   typedef const TensorPatchOp<PatchDim, XprType>& type;
     40 };
     41 
     42 template<typename PatchDim, typename XprType>
     43 struct nested<TensorPatchOp<PatchDim, XprType>, 1, typename eval<TensorPatchOp<PatchDim, XprType> >::type>
     44 {
     45   typedef TensorPatchOp<PatchDim, XprType> type;
     46 };
     47 
     48 }  // end namespace internal
     49 
     50 
     51 
     52 template<typename PatchDim, typename XprType>
     53 class TensorPatchOp : public TensorBase<TensorPatchOp<PatchDim, XprType>, ReadOnlyAccessors>
     54 {
     55   public:
     56   typedef typename Eigen::internal::traits<TensorPatchOp>::Scalar Scalar;
     57   typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     58   typedef typename XprType::CoeffReturnType CoeffReturnType;
     59   typedef typename Eigen::internal::nested<TensorPatchOp>::type Nested;
     60   typedef typename Eigen::internal::traits<TensorPatchOp>::StorageKind StorageKind;
     61   typedef typename Eigen::internal::traits<TensorPatchOp>::Index Index;
     62 
     63   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPatchOp(const XprType& expr, const PatchDim& patch_dims)
     64       : m_xpr(expr), m_patch_dims(patch_dims) {}
     65 
     66     EIGEN_DEVICE_FUNC
     67     const PatchDim& patch_dims() const { return m_patch_dims; }
     68 
     69     EIGEN_DEVICE_FUNC
     70     const typename internal::remove_all<typename XprType::Nested>::type&
     71     expression() const { return m_xpr; }
     72 
     73   protected:
     74     typename XprType::Nested m_xpr;
     75     const PatchDim m_patch_dims;
     76 };
     77 
     78 
// Eval as rvalue
//
// Evaluator for TensorPatchOp. The output has NumDims dimensions: the
// patch extent along each input dimension, plus one extra dimension that
// enumerates the patch positions. For ColMajor layout the patch index is
// the last dimension; for RowMajor it is the first.
template<typename PatchDim, typename ArgType, typename Device>
struct TensorEvaluator<const TensorPatchOp<PatchDim, ArgType>, Device>
{
  typedef TensorPatchOp<PatchDim, ArgType> XprType;
  typedef typename XprType::Index Index;
  // One more dimension than the input: the patch-index dimension.
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value + 1;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;


  enum {
    IsAligned = false,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,
    RawAccess = false
 };

  // Precomputes the output dimensions and three stride arrays:
  //  - m_inputStrides:  strides of the input tensor,
  //  - m_patchStrides:  strides over the grid of patch positions (there are
  //    input_dim - patch_dim + 1 valid positions along each dimension),
  //  - m_outputStrides: strides of the output tensor.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device)
  {
    Index num_patches = 1;
    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
    const PatchDim& patch_dims = op.patch_dims();
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      // The last output dimension indexes the patches.
      m_dimensions[NumDims-1] = num_patches;

      m_inputStrides[0] = 1;
      m_patchStrides[0] = 1;
      for (int i = 1; i < NumDims-1; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_patchStrides[i] = m_patchStrides[i-1] * (input_dims[i-1] - patch_dims[i-1] + 1);
      }
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      // RowMajor: the patch index is the first output dimension and all
      // strides run in the opposite direction.
      for (int i = 0; i < NumDims-1; ++i) {
        m_dimensions[i+1] = patch_dims[i];
        num_patches *= (input_dims[i] - patch_dims[i] + 1);
      }
      m_dimensions[0] = num_patches;

      m_inputStrides[NumDims-2] = 1;
      m_patchStrides[NumDims-2] = 1;
      for (int i = NumDims-3; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_patchStrides[i] = m_patchStrides[i+1] * (input_dims[i+1] - patch_dims[i+1] + 1);
      }
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  // Returns the coefficient at the given linear output index by mapping it
  // back to a linear index into the input tensor.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    // Find the location of the first element of the patch.
    Index patchIndex = index / m_outputStrides[output_stride_index];
    // Find the offset of the element wrt the location of the first element.
    Index patchOffset = index - patchIndex * m_outputStrides[output_stride_index];
    Index inputIndex = 0;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // Decompose patchIndex (patch position) and patchOffset (position
      // within the patch) one dimension at a time; their per-dimension
      // coordinates add up to the input coordinate along that dimension.
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i];
        patchOffset -= offsetIdx * m_outputStrides[i];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx = patchIndex / m_patchStrides[i];
        patchIndex -= patchIdx * m_patchStrides[i];
        const Index offsetIdx = patchOffset / m_outputStrides[i+1];
        patchOffset -= offsetIdx * m_outputStrides[i+1];
        inputIndex += (patchIdx + offsetIdx) * m_inputStrides[i];
      }
    }
    // What remains is the contribution of the innermost dimension, whose
    // input stride is 1.
    inputIndex += (patchIndex + patchOffset);
    return m_impl.coeff(inputIndex);
  }

  // Packet load: computes the input indices of the first and last elements
  // of the packet. If those are contiguous in the input, a single packet
  // load is issued; otherwise the coefficients are gathered one by one.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    Index output_stride_index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? NumDims - 1 : 0;
    // Output indices of the first and last coefficient of the packet.
    Index indices[2] = {index, index + PacketSize - 1};
    Index patchIndices[2] = {indices[0] / m_outputStrides[output_stride_index],
                             indices[1] / m_outputStrides[output_stride_index]};
    Index patchOffsets[2] = {indices[0] - patchIndices[0] * m_outputStrides[output_stride_index],
                             indices[1] - patchIndices[1] * m_outputStrides[output_stride_index]};

    Index inputIndices[2] = {0, 0};
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // Same index decomposition as in coeff(), carried out for both ends
      // of the packet at once.
      for (int i = NumDims - 2; i > 0; --i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i],
                                    patchOffsets[1] / m_outputStrides[i]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    } else {
      for (int i = 0; i < NumDims - 2; ++i) {
        const Index patchIdx[2] = {patchIndices[0] / m_patchStrides[i],
                                   patchIndices[1] / m_patchStrides[i]};
        patchIndices[0] -= patchIdx[0] * m_patchStrides[i];
        patchIndices[1] -= patchIdx[1] * m_patchStrides[i];

        const Index offsetIdx[2] = {patchOffsets[0] / m_outputStrides[i+1],
                                    patchOffsets[1] / m_outputStrides[i+1]};
        patchOffsets[0] -= offsetIdx[0] * m_outputStrides[i+1];
        patchOffsets[1] -= offsetIdx[1] * m_outputStrides[i+1];

        inputIndices[0] += (patchIdx[0] + offsetIdx[0]) * m_inputStrides[i];
        inputIndices[1] += (patchIdx[1] + offsetIdx[1]) * m_inputStrides[i];
      }
    }
    inputIndices[0] += (patchIndices[0] + patchOffsets[0]);
    inputIndices[1] += (patchIndices[1] + patchOffsets[1]);

    if (inputIndices[1] - inputIndices[0] == PacketSize - 1) {
      // Fast path: the whole packet is contiguous in the input.
      PacketReturnType rslt = m_impl.template packet<Unaligned>(inputIndices[0]);
      return rslt;
    }
    else {
      // Slow path: gather the coefficients individually. The two endpoints
      // are already known; the interior ones go through coeff().
      EIGEN_ALIGN_MAX CoeffReturnType values[PacketSize];
      values[0] = m_impl.coeff(inputIndices[0]);
      values[PacketSize-1] = m_impl.coeff(inputIndices[1]);
      for (int i = 1; i < PacketSize-1; ++i) {
        values[i] = coeff(index+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  // Cost estimate: roughly one div/mul/add sequence per dimension on top of
  // the cost of evaluating the input expression.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (TensorOpCost::DivCost<Index>() +
                                           TensorOpCost::MulCost<Index>() +
                                           2 * TensorOpCost::AddCost<Index>());
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  // The result is computed lazily; there is no backing buffer.
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  Dimensions m_dimensions;                 // output dimensions
  array<Index, NumDims> m_outputStrides;   // strides of the output tensor
  array<Index, NumDims-1> m_inputStrides;  // strides of the input tensor
  array<Index, NumDims-1> m_patchStrides;  // strides over patch positions

  TensorEvaluator<ArgType, Device> m_impl; // evaluator for the input expression
};
    266 
    267 } // end namespace Eigen
    268 
    269 #endif // EIGEN_CXX11_TENSOR_TENSOR_PATCH_H
    270