Home | History | Annotate | Download | only in TensorSymmetry
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2013 Christian Seiler <christian (at) iwakd.de>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
     11 #define EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template<typename list> struct tensor_static_symgroup_permutate;
     18 
     19 template<int... nn>
     20 struct tensor_static_symgroup_permutate<numeric_list<int, nn...>>
     21 {
     22   constexpr static std::size_t N = sizeof...(nn);
     23 
     24   template<typename T>
     25   constexpr static inline std::array<T, N> run(const std::array<T, N>& indices)
     26   {
     27     return {{indices[nn]...}};
     28   }
     29 };
     30 
     31 template<typename indices_, int flags_>
     32 struct tensor_static_symgroup_element
     33 {
     34   typedef indices_ indices;
     35   constexpr static int flags = flags_;
     36 };
     37 
     38 template<typename Gen, int N>
     39 struct tensor_static_symgroup_element_ctor
     40 {
     41   typedef tensor_static_symgroup_element<
     42     typename gen_numeric_list_swapped_pair<int, N, Gen::One, Gen::Two>::type,
     43     Gen::Flags
     44   > type;
     45 };
     46 
     47 template<int N>
     48 struct tensor_static_symgroup_identity_ctor
     49 {
     50   typedef tensor_static_symgroup_element<
     51     typename gen_numeric_list<int, N>::type,
     52     0
     53   > type;
     54 };
     55 
     56 template<typename iib>
     57 struct tensor_static_symgroup_multiply_helper
     58 {
     59   template<int... iia>
     60   constexpr static inline numeric_list<int, get<iia, iib>::value...> helper(numeric_list<int, iia...>) {
     61     return numeric_list<int, get<iia, iib>::value...>();
     62   }
     63 };
     64 
     65 template<typename A, typename B>
     66 struct tensor_static_symgroup_multiply
     67 {
     68   private:
     69     typedef typename A::indices iia;
     70     typedef typename B::indices iib;
     71     constexpr static int ffa = A::flags;
     72     constexpr static int ffb = B::flags;
     73 
     74   public:
     75     static_assert(iia::count == iib::count, "Cannot multiply symmetry elements with different number of indices.");
     76 
     77     typedef tensor_static_symgroup_element<
     78       decltype(tensor_static_symgroup_multiply_helper<iib>::helper(iia())),
     79       ffa ^ ffb
     80     > type;
     81 };
     82 
     83 template<typename A, typename B>
     84 struct tensor_static_symgroup_equality
     85 {
     86     typedef typename A::indices iia;
     87     typedef typename B::indices iib;
     88     constexpr static int ffa = A::flags;
     89     constexpr static int ffb = B::flags;
     90     static_assert(iia::count == iib::count, "Cannot compare symmetry elements with different number of indices.");
     91 
     92     constexpr static bool value = is_same<iia, iib>::value;
     93 
     94   private:
     95     /* this should be zero if they are identical, or else the tensor
     96      * will be forced to be pure real, pure imaginary or even pure zero
     97      */
     98     constexpr static int flags_cmp_ = ffa ^ ffb;
     99 
    100     /* either they are not equal, then we don't care whether the flags
    101      * match, or they are equal, and then we have to check
    102      */
    103     constexpr static bool is_zero      = value && flags_cmp_ == NegationFlag;
    104     constexpr static bool is_real      = value && flags_cmp_ == ConjugationFlag;
    105     constexpr static bool is_imag      = value && flags_cmp_ == (NegationFlag | ConjugationFlag);
    106 
    107   public:
    108     constexpr static int global_flags =
    109       (is_real ? GlobalRealFlag : 0) |
    110       (is_imag ? GlobalImagFlag : 0) |
    111       (is_zero ? GlobalZeroFlag : 0);
    112 };
    113 
    114 template<std::size_t NumIndices, typename... Gen>
    115 struct tensor_static_symgroup
    116 {
    117   typedef StaticSGroup<Gen...> type;
    118   constexpr static std::size_t size = type::static_size;
    119 };
    120 
    121 template<typename Index, std::size_t N, int... ii, int... jj>
    122 constexpr static inline std::array<Index, N> tensor_static_symgroup_index_permute(std::array<Index, N> idx, internal::numeric_list<int, ii...>, internal::numeric_list<int, jj...>)
    123 {
    124   return {{ idx[ii]..., idx[jj]... }};
    125 }
    126 
    127 template<typename Index, int... ii>
    128 static inline std::vector<Index> tensor_static_symgroup_index_permute(std::vector<Index> idx, internal::numeric_list<int, ii...>)
    129 {
    130   std::vector<Index> result{{ idx[ii]... }};
    131   std::size_t target_size = idx.size();
    132   for (std::size_t i = result.size(); i < target_size; i++)
    133     result.push_back(idx[i]);
    134   return result;
    135 }
    136 
    137 template<typename T> struct tensor_static_symgroup_do_apply;
    138 
    139 template<typename first, typename... next>
    140 struct tensor_static_symgroup_do_apply<internal::type_list<first, next...>>
    141 {
    142   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
    143   static inline RV run(const std::array<Index, NumIndices>& idx, RV initial, Args&&... args)
    144   {
    145     static_assert(NumIndices >= SGNumIndices, "Can only apply symmetry group to objects that have at least the required amount of indices.");
    146     typedef typename internal::gen_numeric_list<int, NumIndices - SGNumIndices, SGNumIndices>::type remaining_indices;
    147     initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices(), remaining_indices()), first::flags, initial, std::forward<Args>(args)...);
    148     return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
    149   }
    150 
    151   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
    152   static inline RV run(const std::vector<Index>& idx, RV initial, Args&&... args)
    153   {
    154     eigen_assert(idx.size() >= SGNumIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
    155     initial = Op::run(tensor_static_symgroup_index_permute(idx, typename first::indices()), first::flags, initial, std::forward<Args>(args)...);
    156     return tensor_static_symgroup_do_apply<internal::type_list<next...>>::template run<Op, RV, SGNumIndices>(idx, initial, args...);
    157   }
    158 };
    159 
    160 template<EIGEN_TPL_PP_SPEC_HACK_DEF(typename, empty)>
    161 struct tensor_static_symgroup_do_apply<internal::type_list<EIGEN_TPL_PP_SPEC_HACK_USE(empty)>>
    162 {
    163   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, std::size_t NumIndices, typename... Args>
    164   static inline RV run(const std::array<Index, NumIndices>&, RV initial, Args&&...)
    165   {
    166     // do nothing
    167     return initial;
    168   }
    169 
    170   template<typename Op, typename RV, std::size_t SGNumIndices, typename Index, typename... Args>
    171   static inline RV run(const std::vector<Index>&, RV initial, Args&&...)
    172   {
    173     // do nothing
    174     return initial;
    175   }
    176 };
    177 
    178 } // end namespace internal
    179 
    180 template<typename... Gen>
    181 class StaticSGroup
    182 {
    183     constexpr static std::size_t NumIndices = internal::tensor_symmetry_num_indices<Gen...>::value;
    184     typedef internal::group_theory::enumerate_group_elements<
    185       internal::tensor_static_symgroup_multiply,
    186       internal::tensor_static_symgroup_equality,
    187       typename internal::tensor_static_symgroup_identity_ctor<NumIndices>::type,
    188       internal::type_list<typename internal::tensor_static_symgroup_element_ctor<Gen, NumIndices>::type...>
    189     > group_elements;
    190     typedef typename group_elements::type ge;
    191   public:
    192     constexpr inline StaticSGroup() {}
    193     constexpr inline StaticSGroup(const StaticSGroup<Gen...>&) {}
    194     constexpr inline StaticSGroup(StaticSGroup<Gen...>&&) {}
    195 
    196     template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
    197     static inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args)
    198     {
    199       return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
    200     }
    201 
    202     template<typename Op, typename RV, typename Index, typename... Args>
    203     static inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args)
    204     {
    205       eigen_assert(idx.size() == NumIndices);
    206       return internal::tensor_static_symgroup_do_apply<ge>::template run<Op, RV, NumIndices>(idx, initial, args...);
    207     }
    208 
    209     constexpr static std::size_t static_size = ge::count;
    210 
    211     constexpr static inline std::size_t size() {
    212       return ge::count;
    213     }
    214     constexpr static inline int globalFlags() { return group_elements::global_flags; }
    215 
    216     template<typename Tensor_, typename... IndexTypes>
    217     inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
    218     {
    219       static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
    220       return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
    221     }
    222 
    223     template<typename Tensor_>
    224     inline internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
    225     {
    226       return internal::tensor_symmetry_value_setter<Tensor_, StaticSGroup<Gen...>>(tensor, *this, indices);
    227     }
    228 };
    229 
    230 } // end namespace Eigen
    231 
    232 #endif // EIGEN_CXX11_TENSORSYMMETRY_STATICSYMMETRY_H
    233 
    234 /*
    235  * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
    236  */
    237