Home | History | Annotate | Download | only in ADT
      1 //===- llvm/ADT/Trie.h ---- Generic trie structure --------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This class defines a generic trie structure. The trie structure
     11 // is immutable after creation, but the payload contained within it is not.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #ifndef LLVM_ADT_TRIE_H
     16 #define LLVM_ADT_TRIE_H
     17 
#include "llvm/ADT/GraphTraits.h"
#include "llvm/Support/DOTGraphTraits.h"

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
     23 
     24 namespace llvm {
     25 
     26 // FIXME:
     27 // - Labels are usually small, maybe it's better to use SmallString
     28 // - Should we use char* during construction?
     29 // - Should we templatize Empty with traits-like interface?
     30 
     31 template<class Payload>
     32 class Trie {
     33   friend class GraphTraits<Trie<Payload> >;
     34   friend class DOTGraphTraits<Trie<Payload> >;
     35 public:
     36   class Node {
     37     friend class Trie;
     38 
     39   public:
     40     typedef std::vector<Node*> NodeVectorType;
     41     typedef typename NodeVectorType::iterator iterator;
     42     typedef typename NodeVectorType::const_iterator const_iterator;
     43 
     44   private:
     45     enum QueryResult {
     46       Same           = -3,
     47       StringIsPrefix = -2,
     48       LabelIsPrefix  = -1,
     49       DontMatch      = 0,
     50       HaveCommonPart
     51     };
     52 
     53     struct NodeCmp {
     54       bool operator() (Node* N1, Node* N2) {
     55         return (N1->Label[0] < N2->Label[0]);
     56       }
     57       bool operator() (Node* N, char Id) {
     58         return (N->Label[0] < Id);
     59       }
     60     };
     61 
     62     std::string Label;
     63     Payload Data;
     64     NodeVectorType Children;
     65 
     66     // Do not implement
     67     Node(const Node&);
     68     Node& operator=(const Node&);
     69 
     70     inline void addEdge(Node* N) {
     71       if (Children.empty())
     72         Children.push_back(N);
     73       else {
     74         iterator I = std::lower_bound(Children.begin(), Children.end(),
     75                                       N, NodeCmp());
     76         // FIXME: no dups are allowed
     77         Children.insert(I, N);
     78       }
     79     }
     80 
     81     inline void setEdge(Node* N) {
     82       char Id = N->Label[0];
     83       iterator I = std::lower_bound(Children.begin(), Children.end(),
     84                                      Id, NodeCmp());
     85       assert(I != Children.end() && "Node does not exists!");
     86       *I = N;
     87     }
     88 
     89     QueryResult query(const std::string& s) const {
     90       unsigned i, l;
     91       unsigned l1 = s.length();
     92       unsigned l2 = Label.length();
     93 
     94       // Find the length of common part
     95       l = std::min(l1, l2);
     96       i = 0;
     97       while ((i < l) && (s[i] == Label[i]))
     98         ++i;
     99 
    100       if (i == l) { // One is prefix of another, find who is who
    101         if (l1 == l2)
    102           return Same;
    103         else if (i == l1)
    104           return StringIsPrefix;
    105         else
    106           return LabelIsPrefix;
    107       } else // s and Label have common (possible empty) part, return its length
    108         return (QueryResult)i;
    109     }
    110 
    111   public:
    112     inline explicit Node(const Payload& data, const std::string& label = ""):
    113         Label(label), Data(data) { }
    114 
    115     inline const Payload& data() const { return Data; }
    116     inline void setData(const Payload& data) { Data = data; }
    117 
    118     inline const std::string& label() const { return Label; }
    119 
    120 #if 0
    121     inline void dump() {
    122       llvm::cerr << "Node: " << this << "\n"
    123                 << "Label: " << Label << "\n"
    124                 << "Children:\n";
    125 
    126       for (iterator I = Children.begin(), E = Children.end(); I != E; ++I)
    127         llvm::cerr << (*I)->Label << "\n";
    128     }
    129 #endif
    130 
    131     inline Node* getEdge(char Id) {
    132       Node* fNode = NULL;
    133       iterator I = std::lower_bound(Children.begin(), Children.end(),
    134                                           Id, NodeCmp());
    135       if (I != Children.end() && (*I)->Label[0] == Id)
    136         fNode = *I;
    137 
    138       return fNode;
    139     }
    140 
    141     inline iterator       begin()       { return Children.begin(); }
    142     inline const_iterator begin() const { return Children.begin(); }
    143     inline iterator       end  ()       { return Children.end();   }
    144     inline const_iterator end  () const { return Children.end();   }
    145 
    146     inline size_t         size () const { return Children.size();  }
    147     inline bool           empty() const { return Children.empty(); }
    148     inline const Node*   &front() const { return Children.front(); }
    149     inline       Node*   &front()       { return Children.front(); }
    150     inline const Node*   &back()  const { return Children.back();  }
    151     inline       Node*   &back()        { return Children.back();  }
    152 
    153   };
    154 
    155 private:
    156   std::vector<Node*> Nodes;
    157   Payload Empty;
    158 
    159   inline Node* addNode(const Payload& data, const std::string label = "") {
    160     Node* N = new Node(data, label);
    161     Nodes.push_back(N);
    162     return N;
    163   }
    164 
    165   inline Node* splitEdge(Node* N, char Id, size_t index) {
    166     Node* eNode = N->getEdge(Id);
    167     assert(eNode && "Node doesn't exist");
    168 
    169     const std::string &l = eNode->Label;
    170     assert(index > 0 && index < l.length() && "Trying to split too far!");
    171     std::string l1 = l.substr(0, index);
    172     std::string l2 = l.substr(index);
    173 
    174     Node* nNode = addNode(Empty, l1);
    175     N->setEdge(nNode);
    176 
    177     eNode->Label = l2;
    178     nNode->addEdge(eNode);
    179 
    180     return nNode;
    181   }
    182 
    183   // Do not implement
    184   Trie(const Trie&);
    185   Trie& operator=(const Trie&);
    186 
    187 public:
    188   inline explicit Trie(const Payload& empty):Empty(empty) {
    189     addNode(Empty);
    190   }
    191   inline ~Trie() {
    192     for (unsigned i = 0, e = Nodes.size(); i != e; ++i)
    193       delete Nodes[i];
    194   }
    195 
    196   inline Node* getRoot() const { return Nodes[0]; }
    197 
    198   bool addString(const std::string& s, const Payload& data);
    199   const Payload& lookup(const std::string& s) const;
    200 
    201 };
    202 
    203 // Define this out-of-line to dissuade the C++ compiler from inlining it.
    204 template<class Payload>
    205 bool Trie<Payload>::addString(const std::string& s, const Payload& data) {
    206   Node* cNode = getRoot();
    207   Node* tNode = NULL;
    208   std::string s1(s);
    209 
    210   while (tNode == NULL) {
    211     char Id = s1[0];
    212     if (Node* nNode = cNode->getEdge(Id)) {
    213       typename Node::QueryResult r = nNode->query(s1);
    214 
    215       switch (r) {
    216       case Node::Same:
    217       case Node::StringIsPrefix:
    218         // Currently we don't allow to have two strings in the trie one
    219         // being a prefix of another. This should be fixed.
    220         assert(0 && "FIXME!");
    221         return false;
    222       case Node::DontMatch:
    223         llvm_unreachable("Impossible!");
    224       case Node::LabelIsPrefix:
    225         s1 = s1.substr(nNode->label().length());
    226         cNode = nNode;
    227         break;
    228       default:
    229         nNode = splitEdge(cNode, Id, r);
    230         tNode = addNode(data, s1.substr(r));
    231         nNode->addEdge(tNode);
    232       }
    233     } else {
    234       tNode = addNode(data, s1);
    235       cNode->addEdge(tNode);
    236     }
    237   }
    238 
    239   return true;
    240 }
    241 
    242 template<class Payload>
    243 const Payload& Trie<Payload>::lookup(const std::string& s) const {
    244   Node* cNode = getRoot();
    245   Node* tNode = NULL;
    246   std::string s1(s);
    247 
    248   while (tNode == NULL) {
    249     char Id = s1[0];
    250     if (Node* nNode = cNode->getEdge(Id)) {
    251       typename Node::QueryResult r = nNode->query(s1);
    252 
    253       switch (r) {
    254       case Node::Same:
    255         tNode = nNode;
    256         break;
    257       case Node::StringIsPrefix:
    258         return Empty;
    259       case Node::DontMatch:
    260         llvm_unreachable("Impossible!");
    261       case Node::LabelIsPrefix:
    262         s1 = s1.substr(nNode->label().length());
    263         cNode = nNode;
    264         break;
    265       default:
    266         return Empty;
    267       }
    268     } else
    269       return Empty;
    270   }
    271 
    272   return tNode->data();
    273 }
    274 
    275 template<class Payload>
    276 struct GraphTraits<Trie<Payload> > {
    277   typedef Trie<Payload> TrieType;
    278   typedef typename TrieType::Node NodeType;
    279   typedef typename NodeType::iterator ChildIteratorType;
    280 
    281   static inline NodeType *getEntryNode(const TrieType& T) {
    282     return T.getRoot();
    283   }
    284 
    285   static inline ChildIteratorType child_begin(NodeType *N) {
    286     return N->begin();
    287   }
    288   static inline ChildIteratorType child_end(NodeType *N) { return N->end(); }
    289 
    290   typedef typename std::vector<NodeType*>::const_iterator nodes_iterator;
    291 
    292   static inline nodes_iterator nodes_begin(const TrieType& G) {
    293     return G.Nodes.begin();
    294   }
    295   static inline nodes_iterator nodes_end(const TrieType& G) {
    296     return G.Nodes.end();
    297   }
    298 
    299 };
    300 
    301 template<class Payload>
    302 struct DOTGraphTraits<Trie<Payload> > : public DefaultDOTGraphTraits {
    303   typedef typename Trie<Payload>::Node NodeType;
    304   typedef typename GraphTraits<Trie<Payload> >::ChildIteratorType EdgeIter;
    305 
    306   static std::string getGraphName(const Trie<Payload>& T) {
    307     return "Trie";
    308   }
    309 
    310   static std::string getNodeLabel(NodeType* Node, const Trie<Payload>& T) {
    311     if (T.getRoot() == Node)
    312       return "<Root>";
    313     else
    314       return Node->label();
    315   }
    316 
    317   static std::string getEdgeSourceLabel(NodeType* Node, EdgeIter I) {
    318     NodeType* N = *I;
    319     return N->label().substr(0, 1);
    320   }
    321 
    322   static std::string getNodeAttributes(const NodeType* Node,
    323                                        const Trie<Payload>& T) {
    324     if (Node->data() != T.Empty)
    325       return "color=blue";
    326 
    327     return "";
    328   }
    329 
    330 };
    331 
    332 } // end of llvm namespace
    333 
    334 #endif // LLVM_ADT_TRIE_H
    335