// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
//                    Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H

namespace Eigen {
namespace internal {

/** \class TensorIndexTupleOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor + Index Tuple class.
  *
  * Wraps a tensor expression so that each coefficient is exposed as a
  * Tuple<Index, Scalar> pairing the coefficient's linear index with its value.
  */
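// Illustrative usage sketch (not part of the original header): the
// TensorBase::index_tuples() helper, used further down in this file,
// constructs a TensorIndexTupleOp around an existing expression.
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   // Each coefficient of the resulting expression is a Tuple<Index, float>
//   // holding the coefficient's linear index and its value.
//   auto pairs = t.index_tuples();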
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
{
  typedef const TensorIndexTupleOp<XprType>& type;
};

template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
              typename eval<TensorIndexTupleOp<XprType> >::type>
{
  typedef TensorIndexTupleOp<XprType> type;
};

}  // end namespace internal

template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
  typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
      : m_xpr(expr) {}

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
};

// Eval as rvalue
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
{
  typedef TensorIndexTupleOp<ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;

  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
  static const int NumDims = internal::array_size<Dimensions>::value;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_impl(op.expression(), device) { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    return CoeffReturnType(index, m_impl.coeff(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

 protected:
  TensorEvaluator<ArgType, Device> m_impl;
};

namespace internal {

/** \class TensorTupleReducerOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
  *
  * Wraps the input expression in a TensorIndexTupleOp, reduces the resulting
  * (index, value) tuples over the given dimensions with ReduceOp, and returns
  * either the flattened linear index of the selected coefficient or, when a
  * return dimension is specified, its coordinate along that dimension.
  */
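// Illustrative usage sketch (assumes the TensorBase argmax()/argmin() helpers,
// which construct a TensorTupleReducerOp; those helpers live outside this header):
//
//   Eigen::Tensor<float, 3> t(4, 5, 6);
//   t.setRandom();
//   // Coordinate of the maximum along dimension 1, for every (i, k) slice.
//   Eigen::Tensor<Eigen::DenseIndex, 2> arg = t.argmax(1);
//   // Flattened linear index of the global maximum (rank-0 result).
//   Eigen::Tensor<Eigen::DenseIndex, 0> flat = t.argmax();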
template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef Index Scalar;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
  static const int Layout = XprTraits::Layout;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
  typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>& type;
};

template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
              typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
};

}  // end namespace internal

template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
  typedef Index CoeffReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
                                                          const ReduceOp& reduce_op,
                                                          const int return_dim,
                                                          const Dims& reduce_dims)
      : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

  EIGEN_DEVICE_FUNC
  const ReduceOp& reduce_op() const { return m_reduce_op; }

  EIGEN_DEVICE_FUNC
  const Dims& reduce_dims() const { return m_reduce_dims; }

  EIGEN_DEVICE_FUNC
  int return_dim() const { return m_return_dim; }

  protected:
    typename XprType::Nested m_xpr;
    const ReduceOp m_reduce_op;
    const int m_return_dim;
    const Dims m_reduce_dims;
};

// Eval as rvalue
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
{
  typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
  typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
  typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>::Dimensions InputDimensions;
  static const int NumDims = internal::array_size<InputDimensions>::value;
  typedef array<Index, NumDims> StrideDims;

  enum {
    IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
    BlockAccess = false,
    Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_orig_impl(op.expression(), device),
        m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
        m_return_dim(op.return_dim()) {

    gen_strides(m_orig_impl.dimensions(), m_strides);
    if (Layout == static_cast<int>(ColMajor)) {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
    } else {
      const Index total_size = internal::array_prod(m_orig_impl.dimensions());
      m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
    }
    // When no return dimension is requested (m_return_dim < 0) the strides are
    // never used; guard the lookup so the array is not indexed with -1.
    m_stride_div = (m_return_dim >= 0) ? m_strides[m_return_dim] : 1;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
    return m_impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

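  // The reduction yields the *flattened* linear index of the selected
  // coefficient (v.first). When a return dimension was requested, that flat
  // index is converted to a coordinate via (v.first % m_stride_mod) / m_stride_div.
  // Worked example (illustrative): for a ColMajor tensor of dimensions
  // {4, 5, 6} and m_return_dim == 1, the strides are {1, 4, 20}, so
  // m_stride_div == 4, m_stride_mod == 20, and a flat index of 37 maps to
  // coordinate (37 % 20) / 4 == 4 along dimension 1.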
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    const TupleType v = m_impl.coeff(index);
    return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double compute_cost = 1.0 +
        (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
    return m_orig_impl.costPerCoeff(vectorized) +
           m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
  }

 private:
  EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
    if (m_return_dim < 0) {
      return;  // Won't be using the strides.
    }
    eigen_assert(m_return_dim < NumDims &&
                 "Asking to convert index to a dimension outside of the rank");

    // Calculate the per-dimension strides; the constructor derives
    // m_stride_div and m_stride_mod from them to recover the coordinate
    // along m_return_dim from a flattened index.
    if (Layout == static_cast<int>(ColMajor)) {
      strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        strides[i] = strides[i-1] * dims[i-1];
      }
    } else {
      strides[NumDims-1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        strides[i] = strides[i+1] * dims[i+1];
      }
    }
  }

 protected:
  TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
  TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
  const int m_return_dim;
  StrideDims m_strides;
  Index m_stride_mod;
  Index m_stride_div;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H