Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 enum {
     18   Rhs = 0,
     19   Lhs = 1
     20 };
     21 
     22 /*
     23  * Implementation of the Eigen blas_data_mapper class for tensors.
     24  */
     25 
     26 template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
     27   enum {
     28     DirectOffsets = false
     29   };
     30 
     31   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }
     32 
     33   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
     34     eigen_assert(false && "unsupported");
     35   }
     36 
     37   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }
     38 
     39  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     40  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
     41   {
     42     return m_tensor.template packet<LoadMode>(index);
     43   }
     44 
     45 
     46  private:
     47   const Tensor m_tensor;
     48 };
     49 
     50 template <typename Tensor> struct CoeffLoader<Tensor, true> {
     51   enum {
     52     DirectOffsets = true
     53   };
     54 
     55   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
     56 
     57   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
     58     m_data += offset;
     59   }
     60 
     61   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }
     62 
     63  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
     64  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
     65   {
     66     return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
     67   }
     68  private:
     69   typedef typename Tensor::Scalar Scalar;
     70   const Scalar* m_data;
     71 };
     72 
// Presents one operand of a tensor contraction as a matrix: maps matrix
// coordinates (row, col) to a linear index into the tensor and reads the
// coefficient there.
//
// For side == Lhs, `row` enumerates the non-contracting dimensions and `col`
// the contracting ones; for side == Rhs the roles are swapped (see
// computeIndex).
//
// Template parameters:
//   side                 -- Lhs or Rhs.
//   nocontract_t         -- array type (queried via array_size) holding
//                           per-dimension strides for non-contracting dims.
//   contract_t           -- same, for contracting dims.
//   packet_size          -- not used directly by this class.
//   inner_dim_contiguous -- when true, the innermost stride on the "inner"
//                           side (nocontract for Lhs, contract for Rhs) is
//                           asserted to be 1 and the multiply is skipped.
//   Alignment            -- compared against Aligned in firstAligned().
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment>
class SimpleTensorContractionMapper {
  public:
  // nocontract_strides / contract_strides: tensor strides of the
  //   non-contracting / contracting dimensions.
  // ij_strides / k_strides: strides used to decompose a matrix coordinate into
  //   per-dimension indices (divisors in computeIndex).
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  // True when the underlying CoeffLoader reads from a raw pointer and thus
  // supports offsetBuffer().
  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
  };

  // Shifts the start of the underlying data buffer by `offset` coefficients.
  // Only valid when DirectOffsets is true.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  // No-op.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  // Reads element (row, 0).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  // Reads element (row, col) of the matrix view.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  // Converts matrix coordinates (row, col) into a linear tensor index.
  // Each coordinate is decomposed into per-dimension indices by repeated
  // division by the ij/k strides (from the last array entry down to entry 1);
  // the remainder becomes the index along entry 0. Each per-dimension index is
  // multiplied by the matching tensor stride and accumulated.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    // The innermost non-contracting term is added only when the tensor has
    // more dimensions than are being contracted.
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if(array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  // Computes, in a single pass, the linear indices of (row, col) (returned as
  // .first) and of (row + distance, col) (returned as .second), using the same
  // decomposition as computeIndex. The shifted row lands in the
  // non-contracting coordinate for the Lhs and in the contracting one for the
  // Rhs.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value> 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  // Returns 0 (aligned from the start) or `size` (never aligned).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when
    // we're dealing with the lhs with inner_dim_contiguous). This is because
    // the matrix-vector product relies on the stride when dealing with aligned
    // inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  // Distance, in coefficients, between consecutive columns when it is known
  // (Lhs with a contiguous inner dimension and at least one contracting dim);
  // returns 1 otherwise.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
    215 
    216 
    217 template<typename Scalar, typename Index, int side,
    218          typename Tensor,
    219          typename nocontract_t, typename contract_t,
    220          int packet_size, bool inner_dim_contiguous,
    221          bool inner_dim_reordered, int Alignment>
    222 class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
    223 {
    224  public:
    225   typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;
    226 
    227   EIGEN_DEVICE_FUNC
    228   BaseTensorContractionMapper(const Tensor& tensor,
    229                               const nocontract_t& nocontract_strides,
    230                               const nocontract_t& ij_strides,
    231                               const contract_t& contract_strides,
    232                               const contract_t& k_strides) :
    233   ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
    234 
    235   typedef typename Tensor::PacketReturnType Packet;
    236   typedef typename unpacket_traits<Packet>::half HalfPacket;
    237 
    238   template <int AlignmentType>
    239   EIGEN_DEVICE_FUNC
    240   EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    241     // whole method makes column major assumption
    242 
    243     // don't need to add offsets for now (because operator handles that)
    244     // current code assumes packet size must be a multiple of 2
    245     EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
    246 
    247     if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
    248       const Index index = this->computeIndex(i, j);
    249       eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
    250       return this->m_tensor.template packet<AlignmentType>(index);
    251     }
    252 
    253     const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    254     const Index first = indexPair.first;
    255     const Index last = indexPair.second;
    256 
    257     // We can always do optimized packet reads from left hand side right now, because
    258     // the vertical matrix dimension on the left hand side is never contracting.
    259     // On the right hand side we need to check if the contracting dimensions may have
    260     // been shuffled first.
    261     if (Tensor::PacketAccess &&
    262         (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
    263         (last - first) == (packet_size - 1)) {
    264 
    265       return this->m_tensor.template packet<AlignmentType>(first);
    266     }
    267 
    268     EIGEN_ALIGN_MAX Scalar data[packet_size];
    269 
    270     data[0] = this->m_tensor.coeff(first);
    271     for (Index k = 1; k < packet_size - 1; k += 2) {
    272       const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
    273       data[k] = this->m_tensor.coeff(internal_pair.first);
    274       data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    275     }
    276     data[packet_size - 1] = this->m_tensor.coeff(last);
    277 
    278     return pload<Packet>(data);
    279   }
    280 
    281   template <int AlignmentType>
    282   EIGEN_DEVICE_FUNC
    283   EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    284     // whole method makes column major assumption
    285 
    286     // don't need to add offsets for now (because operator handles that)
    287     const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    288     if (half_packet_size == packet_size) {
    289       return loadPacket<AlignmentType>(i, j);
    290     }
    291     EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    292     for (Index k = 0; k < half_packet_size; k++) {
    293       data[k] = operator()(i + k, j);
    294     }
    295     return pload<HalfPacket>(data);
    296   }
    297 };
    298 
    299 
    300 template<typename Scalar, typename Index, int side,
    301          typename Tensor,
    302          typename nocontract_t, typename contract_t,
    303          bool inner_dim_contiguous,
    304          bool inner_dim_reordered, int Alignment>
    305 class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
    306 {
    307  public:
    308   typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;
    309 
    310   EIGEN_DEVICE_FUNC
    311   BaseTensorContractionMapper(const Tensor& tensor,
    312                               const nocontract_t& nocontract_strides,
    313                               const nocontract_t& ij_strides,
    314                               const contract_t& contract_strides,
    315                               const contract_t& k_strides) :
    316   ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
    317 
    318   typedef typename Tensor::PacketReturnType Packet;
    319   template <int> EIGEN_DEVICE_FUNC
    320   EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    321     EIGEN_ALIGN_MAX Scalar data[1];
    322     data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    323     return pload<typename Tensor::PacketReturnType>(data);
    324   }
    325   template <int> EIGEN_DEVICE_FUNC
    326   EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    327     return loadPacket(i, j);
    328   }
    329 };
    330 
    331 
    332 template<typename Scalar, typename Index, int side,
    333          typename Tensor,
    334          typename nocontract_t, typename contract_t,
    335          int packet_size,
    336          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
    337 class TensorContractionSubMapper {
    338  public:
    339   typedef typename Tensor::PacketReturnType Packet;
    340   typedef typename unpacket_traits<Packet>::half HalfPacket;
    341 
    342   typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
    343   typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
    344   typedef Self LinearMapper;
    345 
    346   enum {
    347     // We can use direct offsets iff the parent mapper supports then and we can compute the strides.
    348     // TODO: we should also enable direct offsets for the Rhs case.
    349     UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
    350   };
    351 
    352   EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
    353       : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
    354     // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
    355     // this offset every time we attempt to access a coefficient.
    356     if (UseDirectOffsets) {
    357       Index stride = m_base_mapper.stride();
    358       m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
    359     }
    360   }
    361 
    362   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    363     if (UseDirectOffsets) {
    364       return m_base_mapper(i, 0);
    365     }
    366     return m_base_mapper(i + m_vert_offset, m_horiz_offset);
    367   }
    368   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    369     if (UseDirectOffsets) {
    370       return m_base_mapper(i, j);
    371     }
    372     return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
    373   }
    374 
    375   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    376     if (UseDirectOffsets) {
    377       return m_base_mapper.template loadPacket<Alignment>(i, 0);
    378     }
    379     return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
    380   }
    381   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    382     if (UseDirectOffsets) {
    383       return m_base_mapper.template loadPacket<Alignment>(i, j);
    384     }
    385     return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
    386   }
    387 
    388   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    389     if (UseDirectOffsets) {
    390       return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
    391     }
    392     return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
    393   }
    394 
    395   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    396     if (UseDirectOffsets) {
    397       m_base_mapper.storePacket(i, 0, p);
    398     }
    399     m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
    400   }
    401 
    402   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    403     if (UseDirectOffsets) {
    404       return LinearMapper(m_base_mapper, i, j);
    405     }
    406     return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
    407   }
    408 
    409   template <typename PacketT, int AlignmentType>
    410   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    411     EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    412     const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
    413     if (UseDirectOffsets) {
    414      return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
    415     }
    416     return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
    417   }
    418 
    419   template <typename Packet>
    420   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
    421     return false;
    422   }
    423 
    424  private:
    425   ParentMapper m_base_mapper;
    426   const Index m_vert_offset;
    427   const Index m_horiz_offset;
    428 };
    429 
    430 
    431 template<typename Scalar_, typename Index, int side,
    432          typename Tensor,
    433          typename nocontract_t, typename contract_t,
    434          int packet_size,
    435          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
    436 class TensorContractionInputMapper
    437   : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
    438 
    439  public:
    440   typedef Scalar_ Scalar;
    441   typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
    442   typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
    443   typedef SubMapper VectorMapper;
    444 
    445   EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
    446                                const nocontract_t& nocontract_strides,
    447                                const nocontract_t& ij_strides,
    448                                const contract_t& contract_strides,
    449                                const contract_t& k_strides)
    450       : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
    451 
    452   EIGEN_DEVICE_FUNC
    453   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    454     return SubMapper(*this, i, j);
    455   }
    456 
    457   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    458     return VectorMapper(*this, i, j);
    459   }
    460 };
    461 
    462 
    463 
    464 }  // end namespace internal
    465 }  // end namespace Eigen
    466 
    467 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
    468