Home | History | Annotate | Download | only in Tensor
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog (at) gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
     12 
     13 
     14 namespace Eigen {
     15 
     16 /** \internal
     17   *
     18   * \class TensorDimensions
     19   * \ingroup CXX11_Tensor_Module
     20   *
     21   * \brief Set of classes used to encode and store the dimensions of a Tensor.
     22   *
     23   * The Sizes class encodes as part of the type the number of dimensions and the
     24   * sizes corresponding to each dimension. It uses no storage space since it is
     25   * entirely known at compile time.
     26   * The DSizes class is its dynamic sibling: the number of dimensions is known
     27   * at compile time but the sizes are set during execution.
     28   *
     29   * \sa Tensor
     30   */
     31 
     32 // Boilerplate code
     33 namespace internal {
     34 
     35 template<std::size_t n, typename Dimension> struct dget {
     36   static const std::size_t value = get<n, Dimension>::value;
     37 };
     38 
     39 
     40 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
     41 struct fixed_size_tensor_index_linearization_helper
     42 {
     43   template <typename Dimensions> EIGEN_DEVICE_FUNC
     44   static inline Index run(array<Index, NumIndices> const& indices,
     45                           const Dimensions& dimensions)
     46   {
     47     return array_get<RowMajor ? n - 1 : (NumIndices - n)>(indices) +
     48         dget<RowMajor ? n - 1 : (NumIndices - n), Dimensions>::value *
     49         fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
     50   }
     51 };
     52 
     53 template<typename Index, std::size_t NumIndices, bool RowMajor>
     54 struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
     55 {
     56   template <typename Dimensions> EIGEN_DEVICE_FUNC
     57   static inline Index run(array<Index, NumIndices> const&, const Dimensions&)
     58   {
     59     return 0;
     60   }
     61 };
     62 
     63 template<typename Index, std::size_t n>
     64 struct fixed_size_tensor_index_extraction_helper
     65 {
     66   template <typename Dimensions> EIGEN_DEVICE_FUNC
     67   static inline Index run(const Index index,
     68                           const Dimensions& dimensions)
     69   {
     70     const Index mult = (index == n-1) ? 1 : 0;
     71     return array_get<n-1>(dimensions) * mult +
     72         fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
     73   }
     74 };
     75 
     76 template<typename Index>
     77 struct fixed_size_tensor_index_extraction_helper<Index, 0>
     78 {
     79   template <typename Dimensions> EIGEN_DEVICE_FUNC
     80   static inline Index run(const Index,
     81                           const Dimensions&)
     82   {
     83     return 0;
     84   }
     85   };
     86 
     87 }  // end namespace internal
     88 
     89 
     90 // Fixed size
     91 #ifndef EIGEN_EMULATE_CXX11_META_H
     92 template <typename std::ptrdiff_t... Indices>
     93 struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
     94   typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
     95   static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
     96 
     97   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const {
     98     return Base::count;
     99   }
    100 
    101   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() {
    102     return internal::arg_prod(Indices...);
    103   }
    104 
    105   EIGEN_DEVICE_FUNC Sizes() { }
    106   template <typename DenseIndex>
    107   explicit EIGEN_DEVICE_FUNC Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
    108     // todo: add assertion
    109   }
    110 #if EIGEN_HAS_VARIADIC_TEMPLATES
    111   template <typename... DenseIndex> EIGEN_DEVICE_FUNC Sizes(DenseIndex...) { }
    112   explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) {
    113     // todo: add assertion
    114   }
    115 #endif
    116 
    117   template <typename T> Sizes& operator = (const T& /*other*/) {
    118     // add assertion failure if the size of other is different
    119     return *this;
    120   }
    121 
    122   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::size_t index) const {
    123     return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, *this);
    124   }
    125 
    126   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    127   size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
    128     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *static_cast<const Base*>(this));
    129   }
    130   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    131   size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
    132     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *static_cast<const Base*>(this));
    133   }
    134 };
    135 
    136 namespace internal {
    137 template <typename std::ptrdiff_t... Indices>
    138 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
    139   return Sizes<Indices...>::total_size;
    140 }
    141 }
    142 
    143 #else
    144 
    145 template <std::size_t n>
    146 struct non_zero_size {
    147   typedef internal::type2val<std::size_t, n> type;
    148 };
    149 template <>
    150 struct non_zero_size<0> {
    151   typedef internal::null_type type;
    152 };
    153 
    154 template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0, std::size_t V5=0> struct Sizes {
    155   typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base;
    156   static const size_t count = Base::count;
    157   static const std::size_t total_size = internal::arg_prod<Base>::value;
    158 
    159   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
    160     return count;
    161   }
    162 
    163   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() {
    164     return internal::arg_prod<Base>::value;
    165   }
    166 
    167   Sizes() { }
    168   template <typename DenseIndex>
    169   explicit Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
    170     // todo: add assertion
    171   }
    172   template <typename T> Sizes& operator = (const T& /*other*/) {
    173     // add assertion failure if the size of other is different
    174     return *this;
    175   }
    176 
    177 #if EIGEN_HAS_VARIADIC_TEMPLATES
    178   template <typename... DenseIndex> Sizes(DenseIndex... /*indices*/) { }
    179   explicit Sizes(std::initializer_list<std::size_t>) {
    180     // todo: add assertion
    181   }
    182 #else
    183   EIGEN_DEVICE_FUNC explicit Sizes(const DenseIndex) {
    184   }
    185   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex) {
    186   }
    187   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex) {
    188   }
    189   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
    190   }
    191   EIGEN_DEVICE_FUNC Sizes(const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex, const DenseIndex) {
    192   }
    193 #endif
    194 
    195   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex operator[] (const int index) const {
    196     switch (index) {
    197       case 0:
    198         return internal::get<0, Base>::value;
    199       case 1:
    200         return internal::get<1, Base>::value;
    201       case 2:
    202         return internal::get<2, Base>::value;
    203       case 3:
    204         return internal::get<3, Base>::value;
    205       case 4:
    206         return internal::get<4, Base>::value;
    207       default:
    208         eigen_assert(false && "index overflow");
    209         return static_cast<DenseIndex>(-1);
    210     }
    211   }
    212 
    213   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    214   size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
    215     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *reinterpret_cast<const Base*>(this));
    216   }
    217   template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    218   size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
    219     return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *reinterpret_cast<const Base*>(this));
    220   }
    221 };
    222 
    223 namespace internal {
    224 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5>
    225 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_prod(const Sizes<V1, V2, V3, V4, V5>&) {
    226   return Sizes<V1, V2, V3, V4, V5>::total_size;
    227 }
    228 }
    229 
    230 #endif
    231 
    232 // Boilerplate
    233 namespace internal {
    234 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
    235 struct tensor_index_linearization_helper
    236 {
    237   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    238   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const& dimensions)
    239   {
    240     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
    241       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
    242         tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
    243   }
    244 };
    245 
    246 template<typename Index, std::size_t NumIndices, bool RowMajor>
    247 struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
    248 {
    249   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    250   Index run(array<Index, NumIndices> const& indices, array<Index, NumIndices> const&)
    251   {
    252     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
    253   }
    254 };
    255 }  // end namespace internal
    256 
    257 
    258 
    259 // Dynamic size
    260 template <typename DenseIndex, int NumDims>
    261 struct DSizes : array<DenseIndex, NumDims> {
    262   typedef array<DenseIndex, NumDims> Base;
    263   static const int count = NumDims;
    264 
    265   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
    266     return NumDims;
    267   }
    268 
    269   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const {
    270     return (NumDims == 0) ? 1 : internal::array_prod(*static_cast<const Base*>(this));
    271   }
    272 
    273   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() {
    274     for (int i = 0 ; i < NumDims; ++i) {
    275       (*this)[i] = 0;
    276     }
    277   }
    278   EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
    279 
    280   EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
    281     eigen_assert(NumDims == 1);
    282     (*this)[0] = i0;
    283   }
    284 
    285 #if EIGEN_HAS_VARIADIC_TEMPLATES
    286   template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    287   EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension, IndexTypes... otherDimensions) : Base({{firstDimension, secondDimension, otherDimensions...}}) {
    288     EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
    289   }
    290 #else
    291   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1) {
    292     eigen_assert(NumDims == 2);
    293     (*this)[0] = i0;
    294     (*this)[1] = i1;
    295   }
    296   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
    297     eigen_assert(NumDims == 3);
    298     (*this)[0] = i0;
    299     (*this)[1] = i1;
    300     (*this)[2] = i2;
    301   }
    302   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
    303     eigen_assert(NumDims == 4);
    304     (*this)[0] = i0;
    305     (*this)[1] = i1;
    306     (*this)[2] = i2;
    307     (*this)[3] = i3;
    308   }
    309   EIGEN_DEVICE_FUNC DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
    310     eigen_assert(NumDims == 5);
    311     (*this)[0] = i0;
    312     (*this)[1] = i1;
    313     (*this)[2] = i2;
    314     (*this)[3] = i3;
    315     (*this)[4] = i4;
    316   }
    317 #endif
    318 
    319   EIGEN_DEVICE_FUNC DSizes& operator = (const array<DenseIndex, NumDims>& other) {
    320     *static_cast<Base*>(this) = other;
    321     return *this;
    322   }
    323 
    324   // A constexpr would be so much better here
    325   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
    326     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
    327   }
    328   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
    329     return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
    330   }
    331 };
    332 
    333 
    334 
    335 
    336 // Boilerplate
    337 namespace internal {
    338 template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
    339 struct tensor_vsize_index_linearization_helper
    340 {
    341   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    342   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const& dimensions)
    343   {
    344     return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
    345       array_get<RowMajor ? n : (NumIndices - n - 1)>(dimensions) *
    346         tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
    347   }
    348 };
    349 
    350 template<typename Index, std::size_t NumIndices, bool RowMajor>
    351 struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
    352 {
    353   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    354   Index run(array<Index, NumIndices> const& indices, std::vector<DenseIndex> const&)
    355   {
    356     return array_get<RowMajor ? 0 : NumIndices - 1>(indices);
    357   }
    358 };
    359 }  // end namespace internal
    360 
    361 
    362 namespace internal {
    363 
    364 template <typename DenseIndex, int NumDims> struct array_size<const DSizes<DenseIndex, NumDims> > {
    365   static const size_t value = NumDims;
    366 };
    367 template <typename DenseIndex, int NumDims> struct array_size<DSizes<DenseIndex, NumDims> > {
    368   static const size_t value = NumDims;
    369 };
    370 #ifndef EIGEN_EMULATE_CXX11_META_H
    371 template <typename std::ptrdiff_t... Indices> struct array_size<const Sizes<Indices...> > {
    372 static const std::ptrdiff_t value = Sizes<Indices...>::count;
    373 };
    374 template <typename std::ptrdiff_t... Indices> struct array_size<Sizes<Indices...> > {
    375 static const std::ptrdiff_t value = Sizes<Indices...>::count;
    376 };
    377 template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) {
    378   return get<n, internal::numeric_list<std::size_t, Indices...> >::value;
    379 }
    380 template <std::ptrdiff_t n> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) {
    381   eigen_assert(false && "should never be called");
    382   return -1;
    383 }
    384 #else
    385 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
    386   static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
    387 };
    388 template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<Sizes<V1,V2,V3,V4,V5> > {
    389   static const size_t value = Sizes<V1,V2,V3,V4,V5>::count;
    390 };
    391 template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<V1,V2,V3,V4,V5>&) {
    392   return get<n, typename Sizes<V1,V2,V3,V4,V5>::Base>::value;
    393 }
    394 
    395 #endif
    396 
    397 
    398 template <typename Dims1, typename Dims2, size_t n, size_t m>
    399 struct sizes_match_below_dim {
    400   static EIGEN_DEVICE_FUNC  inline bool run(Dims1&, Dims2&) {
    401     return false;
    402   }
    403 };
    404 template <typename Dims1, typename Dims2, size_t n>
    405 struct sizes_match_below_dim<Dims1, Dims2, n, n> {
    406   static EIGEN_DEVICE_FUNC  inline bool run(Dims1& dims1, Dims2& dims2) {
    407     return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &
    408         sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
    409   }
    410 };
    411 template <typename Dims1, typename Dims2>
    412 struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
    413   static EIGEN_DEVICE_FUNC  inline bool run(Dims1&, Dims2&) {
    414     return true;
    415   }
    416 };
    417 
    418 } // end namespace internal
    419 
    420 
    421 template <typename Dims1, typename Dims2>
    422 EIGEN_DEVICE_FUNC bool dimensions_match(Dims1& dims1, Dims2& dims2) {
    423   return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
    424 }
    425 
    426 } // end namespace Eigen
    427 
    428 #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
    429