1 //===- llvm/Analysis/MaximumSpanningTree.h - Interface ----------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This module provides means for calculating a maximum spanning tree for a 11 // given set of weighted edges. The type parameter T is the type of a node. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #ifndef LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H 16 #define LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H 17 18 #include "llvm/BasicBlock.h" 19 #include "llvm/ADT/EquivalenceClasses.h" 20 #include <vector> 21 #include <algorithm> 22 23 namespace llvm { 24 25 /// MaximumSpanningTree - A MST implementation. 26 /// The type parameter T determines the type of the nodes of the graph. 27 template <typename T> 28 class MaximumSpanningTree { 29 30 // A comparing class for comparing weighted edges. 31 template <typename CT> 32 struct EdgeWeightCompare { 33 bool operator()(typename MaximumSpanningTree<CT>::EdgeWeight X, 34 typename MaximumSpanningTree<CT>::EdgeWeight Y) const { 35 if (X.second > Y.second) return true; 36 if (X.second < Y.second) return false; 37 if (const BasicBlock *BBX = dyn_cast<BasicBlock>(X.first.first)) { 38 if (const BasicBlock *BBY = dyn_cast<BasicBlock>(Y.first.first)) { 39 if (BBX->size() > BBY->size()) return true; 40 if (BBX->size() < BBY->size()) return false; 41 } 42 } 43 if (const BasicBlock *BBX = dyn_cast<BasicBlock>(X.first.second)) { 44 if (const BasicBlock *BBY = dyn_cast<BasicBlock>(Y.first.second)) { 45 if (BBX->size() > BBY->size()) return true; 46 if (BBX->size() < BBY->size()) return false; 47 } 48 } 49 return false; 50 } 51 }; 52 53 public: 54 typedef std::pair<const T*, const T*> Edge; 55 typedef std::pair<Edge, double> EdgeWeight; 56 typedef std::vector<EdgeWeight> EdgeWeights; 57 protected: 58 typedef std::vector<Edge> MaxSpanTree; 59 60 MaxSpanTree MST; 61 62 public: 63 static char ID; // Class identification, replacement for typeinfo 64 65 /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a 66 /// spanning tree. 67 MaximumSpanningTree(EdgeWeights &EdgeVector) { 68 69 std::stable_sort(EdgeVector.begin(), EdgeVector.end(), EdgeWeightCompare<T>()); 70 71 // Create spanning tree, Forest contains a special data structure 72 // that makes checking if two nodes are already in a common (sub-)tree 73 // fast and cheap. 74 EquivalenceClasses<const T*> Forest; 75 for (typename EdgeWeights::iterator EWi = EdgeVector.begin(), 76 EWe = EdgeVector.end(); EWi != EWe; ++EWi) { 77 Edge e = (*EWi).first; 78 79 Forest.insert(e.first); 80 Forest.insert(e.second); 81 } 82 83 // Iterate over the sorted edges, biggest first. 84 for (typename EdgeWeights::iterator EWi = EdgeVector.begin(), 85 EWe = EdgeVector.end(); EWi != EWe; ++EWi) { 86 Edge e = (*EWi).first; 87 88 if (Forest.findLeader(e.first) != Forest.findLeader(e.second)) { 89 Forest.unionSets(e.first, e.second); 90 // So we know now that the edge is not already in a subtree, so we push 91 // the edge to the MST. 92 MST.push_back(e); 93 } 94 } 95 } 96 97 typename MaxSpanTree::iterator begin() { 98 return MST.begin(); 99 } 100 101 typename MaxSpanTree::iterator end() { 102 return MST.end(); 103 } 104 }; 105 106 } // End llvm namespace 107 108 #endif 109