// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  * Expression returned by TensorBase::broadcast(): every dimension of the
  * input tensor is tiled by the corresponding broadcast factor, so the i-th
  * output dimension has size input_dimension[i] * broadcast[i].
  *
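  * Example usage (a minimal sketch; the tensor sizes and broadcast factors
  * below are purely illustrative):
  * \code
  * Eigen::Tensor<float, 2> input(2, 3);
  * input.setRandom();
  * Eigen::array<Eigen::Index, 2> bcast{{3, 4}};
  * Eigen::Tensor<float, 2> output = input.broadcast(bcast);  // 6 x 12 result
  * \endcode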
  */
namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <std::size_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal



template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
      : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};


// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = true,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    : m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
    // and store the result in a scalar. Instead one should first reshape the scalar into an
    // N-dimensional tensor (N >= 1) with a single element, and then broadcast that.
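    // For example, a minimal sketch of the reshape-then-broadcast pattern
    // (names and sizes are illustrative):
    //   Eigen::Tensor<float, 0> s;                                    // rank-0 scalar
    //   auto b = s.reshape(Eigen::array<Eigen::Index, 1>{{1}})        // rank-1, one element
    //             .broadcast(Eigen::array<Eigen::Index, 1>{{5}});     // five copies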
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

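    // For instance (illustrative numbers): a column-major input of dimensions (2, 3)
    // broadcast by (3, 4) yields m_dimensions = (6, 12), m_inputStrides = (1, 2)
    // and m_outputStrides = (1, 6).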
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }

  // TODO: attempt to speed this up. The integer divisions and modulo are slow.
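  // The loops below reduce each output coordinate modulo the corresponding input
  // dimension. Worked example (illustrative, column-major): a 2x3 input broadcast
  // by {2, 2} produces a 4x6 output, and output coordinate (3, 4) reads input
  // coordinate (3 % 2, 4 % 3) = (1, 1).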
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }

  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
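    // If the whole packet fits within the innermost input dimension we can issue a
    // single (unaligned) vectorized load; otherwise the coefficients are gathered
    // one by one below and assembled into a packet.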
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (NumDims > 0) {
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 protected:
  const Broadcast m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H