Home | History | Annotate | Download | only in TensorSymmetry
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2013 Christian Seiler <christian (at) iwakd.de>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSORSYMMETRY_DYNAMICSYMMETRY_H
     11 #define EIGEN_CXX11_TENSORSYMMETRY_DYNAMICSYMMETRY_H
     12 
     13 namespace Eigen {
     14 
     15 class DynamicSGroup
     16 {
     17   public:
     18     inline explicit DynamicSGroup() : m_numIndices(1), m_elements(), m_generators(), m_globalFlags(0) { m_elements.push_back(ge(Generator(0, 0, 0))); }
     19     inline DynamicSGroup(const DynamicSGroup& o) : m_numIndices(o.m_numIndices), m_elements(o.m_elements), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { }
     20     inline DynamicSGroup(DynamicSGroup&& o) : m_numIndices(o.m_numIndices), m_elements(), m_generators(o.m_generators), m_globalFlags(o.m_globalFlags) { std::swap(m_elements, o.m_elements); }
     21     inline DynamicSGroup& operator=(const DynamicSGroup& o) { m_numIndices = o.m_numIndices; m_elements = o.m_elements; m_generators = o.m_generators; m_globalFlags = o.m_globalFlags; return *this; }
     22     inline DynamicSGroup& operator=(DynamicSGroup&& o) { m_numIndices = o.m_numIndices; std::swap(m_elements, o.m_elements); m_generators = o.m_generators; m_globalFlags = o.m_globalFlags; return *this; }
     23 
     24     void add(int one, int two, int flags = 0);
     25 
     26     template<typename Gen_>
     27     inline void add(Gen_) { add(Gen_::One, Gen_::Two, Gen_::Flags); }
     28     inline void addSymmetry(int one, int two) { add(one, two, 0); }
     29     inline void addAntiSymmetry(int one, int two) { add(one, two, NegationFlag); }
     30     inline void addHermiticity(int one, int two) { add(one, two, ConjugationFlag); }
     31     inline void addAntiHermiticity(int one, int two) { add(one, two, NegationFlag | ConjugationFlag); }
     32 
     33     template<typename Op, typename RV, typename Index, std::size_t N, typename... Args>
     34     inline RV apply(const std::array<Index, N>& idx, RV initial, Args&&... args) const
     35     {
     36       eigen_assert(N >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
     37       for (std::size_t i = 0; i < size(); i++)
     38         initial = Op::run(h_permute(i, idx, typename internal::gen_numeric_list<int, N>::type()), m_elements[i].flags, initial, std::forward<Args>(args)...);
     39       return initial;
     40     }
     41 
     42     template<typename Op, typename RV, typename Index, typename... Args>
     43     inline RV apply(const std::vector<Index>& idx, RV initial, Args&&... args) const
     44     {
     45       eigen_assert(idx.size() >= m_numIndices && "Can only apply symmetry group to objects that have at least the required amount of indices.");
     46       for (std::size_t i = 0; i < size(); i++)
     47         initial = Op::run(h_permute(i, idx), m_elements[i].flags, initial, std::forward<Args>(args)...);
     48       return initial;
     49     }
     50 
     51     inline int globalFlags() const { return m_globalFlags; }
     52     inline std::size_t size() const { return m_elements.size(); }
     53 
     54     template<typename Tensor_, typename... IndexTypes>
     55     inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, typename Tensor_::Index firstIndex, IndexTypes... otherIndices) const
     56     {
     57       static_assert(sizeof...(otherIndices) + 1 == Tensor_::NumIndices, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
     58       return operator()(tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices>{{firstIndex, otherIndices...}});
     59     }
     60 
     61     template<typename Tensor_>
     62     inline internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup> operator()(Tensor_& tensor, std::array<typename Tensor_::Index, Tensor_::NumIndices> const& indices) const
     63     {
     64       return internal::tensor_symmetry_value_setter<Tensor_, DynamicSGroup>(tensor, *this, indices);
     65     }
     66   private:
     67     struct GroupElement {
     68       std::vector<int> representation;
     69       int flags;
     70       bool isId() const
     71       {
     72         for (std::size_t i = 0; i < representation.size(); i++)
     73           if (i != (size_t)representation[i])
     74             return false;
     75         return true;
     76       }
     77     };
     78     struct Generator {
     79       int one;
     80       int two;
     81       int flags;
     82       constexpr inline Generator(int one_, int two_, int flags_) : one(one_), two(two_), flags(flags_) {}
     83     };
     84 
     85     std::size_t m_numIndices;
     86     std::vector<GroupElement> m_elements;
     87     std::vector<Generator> m_generators;
     88     int m_globalFlags;
     89 
     90     template<typename Index, std::size_t N, int... n>
     91     inline std::array<Index, N> h_permute(std::size_t which, const std::array<Index, N>& idx, internal::numeric_list<int, n...>) const
     92     {
     93       return std::array<Index, N>{{ idx[n >= m_numIndices ? n : m_elements[which].representation[n]]... }};
     94     }
     95 
     96     template<typename Index>
     97     inline std::vector<Index> h_permute(std::size_t which, std::vector<Index> idx) const
     98     {
     99       std::vector<Index> result;
    100       result.reserve(idx.size());
    101       for (auto k : m_elements[which].representation)
    102         result.push_back(idx[k]);
    103       for (std::size_t i = m_numIndices; i < idx.size(); i++)
    104         result.push_back(idx[i]);
    105       return result;
    106     }
    107 
    108     inline GroupElement ge(Generator const& g) const
    109     {
    110       GroupElement result;
    111       result.representation.reserve(m_numIndices);
    112       result.flags = g.flags;
    113       for (std::size_t k = 0; k < m_numIndices; k++) {
    114         if (k == (std::size_t)g.one)
    115           result.representation.push_back(g.two);
    116         else if (k == (std::size_t)g.two)
    117           result.representation.push_back(g.one);
    118         else
    119           result.representation.push_back(int(k));
    120       }
    121       return result;
    122     }
    123 
    124     GroupElement mul(GroupElement, GroupElement) const;
    125     inline GroupElement mul(Generator g1, GroupElement g2) const
    126     {
    127       return mul(ge(g1), g2);
    128     }
    129 
    130     inline GroupElement mul(GroupElement g1, Generator g2) const
    131     {
    132       return mul(g1, ge(g2));
    133     }
    134 
    135     inline GroupElement mul(Generator g1, Generator g2) const
    136     {
    137       return mul(ge(g1), ge(g2));
    138     }
    139 
    140     inline int findElement(GroupElement e) const
    141     {
    142       for (auto ee : m_elements) {
    143         if (ee.representation == e.representation)
    144           return ee.flags ^ e.flags;
    145       }
    146       return -1;
    147     }
    148 
    149     void updateGlobalFlags(int flagDiffOfSameGenerator);
    150 };
    151 
    152 // dynamic symmetry group that auto-adds the template parameters in the constructor
    153 template<typename... Gen>
    154 class DynamicSGroupFromTemplateArgs : public DynamicSGroup
    155 {
    156   public:
    157     inline DynamicSGroupFromTemplateArgs() : DynamicSGroup()
    158     {
    159       add_all(internal::type_list<Gen...>());
    160     }
    161     inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs const& other) : DynamicSGroup(other) { }
    162     inline DynamicSGroupFromTemplateArgs(DynamicSGroupFromTemplateArgs&& other) : DynamicSGroup(other) { }
    163     inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(const DynamicSGroupFromTemplateArgs<Gen...>& o) { DynamicSGroup::operator=(o); return *this; }
    164     inline DynamicSGroupFromTemplateArgs<Gen...>& operator=(DynamicSGroupFromTemplateArgs<Gen...>&& o) { DynamicSGroup::operator=(o); return *this; }
    165 
    166   private:
    167     template<typename Gen1, typename... GenNext>
    168     inline void add_all(internal::type_list<Gen1, GenNext...>)
    169     {
    170       add(Gen1());
    171       add_all(internal::type_list<GenNext...>());
    172     }
    173 
    174     inline void add_all(internal::type_list<>)
    175     {
    176     }
    177 };
    178 
    179 inline DynamicSGroup::GroupElement DynamicSGroup::mul(GroupElement g1, GroupElement g2) const
    180 {
    181   eigen_internal_assert(g1.representation.size() == m_numIndices);
    182   eigen_internal_assert(g2.representation.size() == m_numIndices);
    183 
    184   GroupElement result;
    185   result.representation.reserve(m_numIndices);
    186   for (std::size_t i = 0; i < m_numIndices; i++) {
    187     int v = g2.representation[g1.representation[i]];
    188     eigen_assert(v >= 0);
    189     result.representation.push_back(v);
    190   }
    191   result.flags = g1.flags ^ g2.flags;
    192   return result;
    193 }
    194 
    195 inline void DynamicSGroup::add(int one, int two, int flags)
    196 {
    197   eigen_assert(one >= 0);
    198   eigen_assert(two >= 0);
    199   eigen_assert(one != two);
    200 
    201   if ((std::size_t)one >= m_numIndices || (std::size_t)two >= m_numIndices) {
    202     std::size_t newNumIndices = (one > two) ? one : two + 1;
    203     for (auto& gelem : m_elements) {
    204       gelem.representation.reserve(newNumIndices);
    205       for (std::size_t i = m_numIndices; i < newNumIndices; i++)
    206         gelem.representation.push_back(i);
    207     }
    208     m_numIndices = newNumIndices;
    209   }
    210 
    211   Generator g{one, two, flags};
    212   GroupElement e = ge(g);
    213 
    214   /* special case for first generator */
    215   if (m_elements.size() == 1) {
    216     while (!e.isId()) {
    217       m_elements.push_back(e);
    218       e = mul(e, g);
    219     }
    220 
    221     if (e.flags > 0)
    222       updateGlobalFlags(e.flags);
    223 
    224     // only add in case we didn't have identity
    225     if (m_elements.size() > 1)
    226       m_generators.push_back(g);
    227     return;
    228   }
    229 
    230   int p = findElement(e);
    231   if (p >= 0) {
    232     updateGlobalFlags(p);
    233     return;
    234   }
    235 
    236   std::size_t coset_order = m_elements.size();
    237   m_elements.push_back(e);
    238   for (std::size_t i = 1; i < coset_order; i++)
    239     m_elements.push_back(mul(m_elements[i], e));
    240   m_generators.push_back(g);
    241 
    242   std::size_t coset_rep = coset_order;
    243   do {
    244     for (auto g : m_generators) {
    245       e = mul(m_elements[coset_rep], g);
    246       p = findElement(e);
    247       if (p < 0) {
    248         // element not yet in group
    249         m_elements.push_back(e);
    250         for (std::size_t i = 1; i < coset_order; i++)
    251           m_elements.push_back(mul(m_elements[i], e));
    252       } else if (p > 0) {
    253         updateGlobalFlags(p);
    254       }
    255     }
    256     coset_rep += coset_order;
    257   } while (coset_rep < m_elements.size());
    258 }
    259 
    260 inline void DynamicSGroup::updateGlobalFlags(int flagDiffOfSameGenerator)
    261 {
    262     switch (flagDiffOfSameGenerator) {
    263       case 0:
    264       default:
    265         // nothing happened
    266         break;
    267       case NegationFlag:
    268         // every element is it's own negative => whole tensor is zero
    269         m_globalFlags |= GlobalZeroFlag;
    270         break;
    271       case ConjugationFlag:
    272         // every element is it's own conjugate => whole tensor is real
    273         m_globalFlags |= GlobalRealFlag;
    274         break;
    275       case (NegationFlag | ConjugationFlag):
    276         // every element is it's own negative conjugate => whole tensor is imaginary
    277         m_globalFlags |= GlobalImagFlag;
    278         break;
    279       /* NOTE:
    280        *   since GlobalZeroFlag == GlobalRealFlag | GlobalImagFlag, if one generator
    281        *   causes the tensor to be real and the next one to be imaginary, this will
    282        *   trivially give the correct result
    283        */
    284     }
    285 }
    286 
    287 } // end namespace Eigen
    288 
    289 #endif // EIGEN_CXX11_TENSORSYMMETRY_DYNAMICSYMMETRY_H
    290 
    291 /*
    292  * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
    293  */
    294