Home | History | Annotate | Download | only in ADT
      1 //===-- llvm/ADT/EquivalenceClasses.h - Generic Equiv. Classes --*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // Generic implementation of equivalence classes through the use of Tarjan's
     11 // efficient union-find algorithm.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #ifndef LLVM_ADT_EQUIVALENCECLASSES_H
     16 #define LLVM_ADT_EQUIVALENCECLASSES_H
     17 
     18 #include <cassert>
     19 #include <cstddef>
     20 #include <cstdint>
     21 #include <iterator>
     22 #include <set>
     23 
     24 namespace llvm {
     25 
     26 /// EquivalenceClasses - This represents a collection of equivalence classes and
     27 /// supports three efficient operations: insert an element into a class of its
     28 /// own, union two classes, and find the class for a given element.  In
     29 /// addition to these modification methods, it is possible to iterate over all
     30 /// of the equivalence classes and all of the elements in a class.
     31 ///
     32 /// This implementation is an efficient implementation that only stores one copy
     33 /// of the element being indexed per entry in the set, and allows any arbitrary
     34 /// type to be indexed (as long as it can be ordered with operator<).
     35 ///
     36 /// Here is a simple example using integers:
     37 ///
     38 /// \code
     39 ///  EquivalenceClasses<int> EC;
     40 ///  EC.unionSets(1, 2);                // insert 1, 2 into the same set
     41 ///  EC.insert(4); EC.insert(5);        // insert 4, 5 into own sets
     42 ///  EC.unionSets(5, 1);                // merge the set for 1 with 5's set.
     43 ///
     44 ///  for (EquivalenceClasses<int>::iterator I = EC.begin(), E = EC.end();
     45 ///       I != E; ++I) {           // Iterate over all of the equivalence sets.
     46 ///    if (!I->isLeader()) continue;   // Ignore non-leader sets.
     47 ///    for (EquivalenceClasses<int>::member_iterator MI = EC.member_begin(I);
     48 ///         MI != EC.member_end(); ++MI)   // Loop over members in this set.
     49 ///      cerr << *MI << " ";  // Print member.
     50 ///    cerr << "\n";   // Finish set.
     51 ///  }
     52 /// \endcode
     53 ///
     54 /// This example prints:
     55 ///   4
     56 ///   5 1 2
     57 ///
     58 template <class ElemTy>
     59 class EquivalenceClasses {
     60   /// ECValue - The EquivalenceClasses data structure is just a set of these.
     61   /// Each of these represents a relation for a value.  First it stores the
     62   /// value itself, which provides the ordering that the set queries.  Next, it
     63   /// provides a "next pointer", which is used to enumerate all of the elements
     64   /// in the unioned set.  Finally, it defines either a "end of list pointer" or
     65   /// "leader pointer" depending on whether the value itself is a leader.  A
     66   /// "leader pointer" points to the node that is the leader for this element,
     67   /// if the node is not a leader.  A "end of list pointer" points to the last
     68   /// node in the list of members of this list.  Whether or not a node is a
     69   /// leader is determined by a bit stolen from one of the pointers.
     70   class ECValue {
     71     friend class EquivalenceClasses;
     72     mutable const ECValue *Leader, *Next;
     73     ElemTy Data;
     74 
     75     // ECValue ctor - Start out with EndOfList pointing to this node, Next is
     76     // Null, isLeader = true.
     77     ECValue(const ElemTy &Elt)
     78       : Leader(this), Next((ECValue*)(intptr_t)1), Data(Elt) {}
     79 
     80     const ECValue *getLeader() const {
     81       if (isLeader()) return this;
     82       if (Leader->isLeader()) return Leader;
     83       // Path compression.
     84       return Leader = Leader->getLeader();
     85     }
     86 
     87     const ECValue *getEndOfList() const {
     88       assert(isLeader() && "Cannot get the end of a list for a non-leader!");
     89       return Leader;
     90     }
     91 
     92     void setNext(const ECValue *NewNext) const {
     93       assert(getNext() == nullptr && "Already has a next pointer!");
     94       Next = (const ECValue*)((intptr_t)NewNext | (intptr_t)isLeader());
     95     }
     96 
     97   public:
     98     ECValue(const ECValue &RHS) : Leader(this), Next((ECValue*)(intptr_t)1),
     99                                   Data(RHS.Data) {
    100       // Only support copying of singleton nodes.
    101       assert(RHS.isLeader() && RHS.getNext() == nullptr && "Not a singleton!");
    102     }
    103 
    104     bool operator<(const ECValue &UFN) const { return Data < UFN.Data; }
    105 
    106     bool isLeader() const { return (intptr_t)Next & 1; }
    107     const ElemTy &getData() const { return Data; }
    108 
    109     const ECValue *getNext() const {
    110       return (ECValue*)((intptr_t)Next & ~(intptr_t)1);
    111     }
    112 
    113     template<typename T>
    114     bool operator<(const T &Val) const { return Data < Val; }
    115   };
    116 
    117   /// TheMapping - This implicitly provides a mapping from ElemTy values to the
    118   /// ECValues, it just keeps the key as part of the value.
    119   std::set<ECValue> TheMapping;
    120 
    121 public:
    122   EquivalenceClasses() = default;
    123   EquivalenceClasses(const EquivalenceClasses &RHS) {
    124     operator=(RHS);
    125   }
    126 
    127   const EquivalenceClasses &operator=(const EquivalenceClasses &RHS) {
    128     TheMapping.clear();
    129     for (iterator I = RHS.begin(), E = RHS.end(); I != E; ++I)
    130       if (I->isLeader()) {
    131         member_iterator MI = RHS.member_begin(I);
    132         member_iterator LeaderIt = member_begin(insert(*MI));
    133         for (++MI; MI != member_end(); ++MI)
    134           unionSets(LeaderIt, member_begin(insert(*MI)));
    135       }
    136     return *this;
    137   }
    138 
    139   //===--------------------------------------------------------------------===//
    140   // Inspection methods
    141   //
    142 
    143   /// iterator* - Provides a way to iterate over all values in the set.
    144   typedef typename std::set<ECValue>::const_iterator iterator;
    145   iterator begin() const { return TheMapping.begin(); }
    146   iterator end() const { return TheMapping.end(); }
    147 
    148   bool empty() const { return TheMapping.empty(); }
    149 
    150   /// member_* Iterate over the members of an equivalence class.
    151   ///
    152   class member_iterator;
    153   member_iterator member_begin(iterator I) const {
    154     // Only leaders provide anything to iterate over.
    155     return member_iterator(I->isLeader() ? &*I : nullptr);
    156   }
    157   member_iterator member_end() const {
    158     return member_iterator(nullptr);
    159   }
    160 
    161   /// findValue - Return an iterator to the specified value.  If it does not
    162   /// exist, end() is returned.
    163   iterator findValue(const ElemTy &V) const {
    164     return TheMapping.find(V);
    165   }
    166 
    167   /// getLeaderValue - Return the leader for the specified value that is in the
    168   /// set.  It is an error to call this method for a value that is not yet in
    169   /// the set.  For that, call getOrInsertLeaderValue(V).
    170   const ElemTy &getLeaderValue(const ElemTy &V) const {
    171     member_iterator MI = findLeader(V);
    172     assert(MI != member_end() && "Value is not in the set!");
    173     return *MI;
    174   }
    175 
    176   /// getOrInsertLeaderValue - Return the leader for the specified value that is
    177   /// in the set.  If the member is not in the set, it is inserted, then
    178   /// returned.
    179   const ElemTy &getOrInsertLeaderValue(const ElemTy &V) {
    180     member_iterator MI = findLeader(insert(V));
    181     assert(MI != member_end() && "Value is not in the set!");
    182     return *MI;
    183   }
    184 
    185   /// getNumClasses - Return the number of equivalence classes in this set.
    186   /// Note that this is a linear time operation.
    187   unsigned getNumClasses() const {
    188     unsigned NC = 0;
    189     for (iterator I = begin(), E = end(); I != E; ++I)
    190       if (I->isLeader()) ++NC;
    191     return NC;
    192   }
    193 
    194   //===--------------------------------------------------------------------===//
    195   // Mutation methods
    196 
    197   /// insert - Insert a new value into the union/find set, ignoring the request
    198   /// if the value already exists.
    199   iterator insert(const ElemTy &Data) {
    200     return TheMapping.insert(ECValue(Data)).first;
    201   }
    202 
    203   /// findLeader - Given a value in the set, return a member iterator for the
    204   /// equivalence class it is in.  This does the path-compression part that
    205   /// makes union-find "union findy".  This returns an end iterator if the value
    206   /// is not in the equivalence class.
    207   ///
    208   member_iterator findLeader(iterator I) const {
    209     if (I == TheMapping.end()) return member_end();
    210     return member_iterator(I->getLeader());
    211   }
    212   member_iterator findLeader(const ElemTy &V) const {
    213     return findLeader(TheMapping.find(V));
    214   }
    215 
    216   /// union - Merge the two equivalence sets for the specified values, inserting
    217   /// them if they do not already exist in the equivalence set.
    218   member_iterator unionSets(const ElemTy &V1, const ElemTy &V2) {
    219     iterator V1I = insert(V1), V2I = insert(V2);
    220     return unionSets(findLeader(V1I), findLeader(V2I));
    221   }
    222   member_iterator unionSets(member_iterator L1, member_iterator L2) {
    223     assert(L1 != member_end() && L2 != member_end() && "Illegal inputs!");
    224     if (L1 == L2) return L1;   // Unifying the same two sets, noop.
    225 
    226     // Otherwise, this is a real union operation.  Set the end of the L1 list to
    227     // point to the L2 leader node.
    228     const ECValue &L1LV = *L1.Node, &L2LV = *L2.Node;
    229     L1LV.getEndOfList()->setNext(&L2LV);
    230 
    231     // Update L1LV's end of list pointer.
    232     L1LV.Leader = L2LV.getEndOfList();
    233 
    234     // Clear L2's leader flag:
    235     L2LV.Next = L2LV.getNext();
    236 
    237     // L2's leader is now L1.
    238     L2LV.Leader = &L1LV;
    239     return L1;
    240   }
    241 
    242   class member_iterator : public std::iterator<std::forward_iterator_tag,
    243                                                const ElemTy, ptrdiff_t> {
    244     typedef std::iterator<std::forward_iterator_tag,
    245                           const ElemTy, ptrdiff_t> super;
    246     const ECValue *Node;
    247     friend class EquivalenceClasses;
    248 
    249   public:
    250     typedef size_t size_type;
    251     typedef typename super::pointer pointer;
    252     typedef typename super::reference reference;
    253 
    254     explicit member_iterator() = default;
    255     explicit member_iterator(const ECValue *N) : Node(N) {}
    256 
    257     reference operator*() const {
    258       assert(Node != nullptr && "Dereferencing end()!");
    259       return Node->getData();
    260     }
    261     pointer operator->() const { return &operator*(); }
    262 
    263     member_iterator &operator++() {
    264       assert(Node != nullptr && "++'d off the end of the list!");
    265       Node = Node->getNext();
    266       return *this;
    267     }
    268 
    269     member_iterator operator++(int) {    // postincrement operators.
    270       member_iterator tmp = *this;
    271       ++*this;
    272       return tmp;
    273     }
    274 
    275     bool operator==(const member_iterator &RHS) const {
    276       return Node == RHS.Node;
    277     }
    278     bool operator!=(const member_iterator &RHS) const {
    279       return Node != RHS.Node;
    280     }
    281   };
    282 };
    283 
    284 } // end namespace llvm
    285 
    286 #endif // LLVM_ADT_EQUIVALENCECLASSES_H
    287