Home | History | Annotate | Download | only in ADT
      1 //===- llvm/ADT/Trie.h ---- Generic trie structure --------------*- C++ -*-===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
// This file defines a generic trie structure. Strings can be added to the trie
// after creation but never removed, and the payload stored in a node can be
// changed in place.
     12 //
     13 //===----------------------------------------------------------------------===//
     14 
     15 #ifndef LLVM_ADT_TRIE_H
     16 #define LLVM_ADT_TRIE_H
     17 
#include "llvm/ADT/GraphTraits.h"
#include "llvm/Support/DOTGraphTraits.h"

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
     23 
     24 namespace llvm {
     25 
     26 // FIXME:
     27 // - Labels are usually small, maybe it's better to use SmallString
     28 // - Should we use char* during construction?
     29 // - Should we templatize Empty with traits-like interface?
     30 
     31 template<class Payload>
     32 class Trie {
     33   friend class GraphTraits<Trie<Payload> >;
     34   friend class DOTGraphTraits<Trie<Payload> >;
     35 public:
     36   class Node {
     37     friend class Trie;
     38 
     39   public:
     40     typedef std::vector<Node*> NodeVectorType;
     41     typedef typename NodeVectorType::iterator iterator;
     42     typedef typename NodeVectorType::const_iterator const_iterator;
     43 
     44   private:
     45     enum QueryResult {
     46       Same           = -3,
     47       StringIsPrefix = -2,
     48       LabelIsPrefix  = -1,
     49       DontMatch      = 0,
     50       HaveCommonPart
     51     };
     52 
     53     struct NodeCmp {
     54       bool operator() (Node* N1, Node* N2) {
     55         return (N1->Label[0] < N2->Label[0]);
     56       }
     57       bool operator() (Node* N, char Id) {
     58         return (N->Label[0] < Id);
     59       }
     60     };
     61 
     62     std::string Label;
     63     Payload Data;
     64     NodeVectorType Children;
     65 
     66     // Do not implement
     67     Node(const Node&);
     68     Node& operator=(const Node&);
     69 
     70     inline void addEdge(Node* N) {
     71       if (Children.empty())
     72         Children.push_back(N);
     73       else {
     74         iterator I = std::lower_bound(Children.begin(), Children.end(),
     75                                       N, NodeCmp());
     76         // FIXME: no dups are allowed
     77         Children.insert(I, N);
     78       }
     79     }
     80 
     81     inline void setEdge(Node* N) {
     82       char Id = N->Label[0];
     83       iterator I = std::lower_bound(Children.begin(), Children.end(),
     84                                      Id, NodeCmp());
     85       assert(I != Children.end() && "Node does not exists!");
     86       *I = N;
     87     }
     88 
     89     QueryResult query(const std::string& s) const {
     90       unsigned i, l;
     91       unsigned l1 = s.length();
     92       unsigned l2 = Label.length();
     93 
     94       // Find the length of common part
     95       l = std::min(l1, l2);
     96       i = 0;
     97       while ((i < l) && (s[i] == Label[i]))
     98         ++i;
     99 
    100       if (i == l) { // One is prefix of another, find who is who
    101         if (l1 == l2)
    102           return Same;
    103         else if (i == l1)
    104           return StringIsPrefix;
    105         else
    106           return LabelIsPrefix;
    107       } else // s and Label have common (possible empty) part, return its length
    108         return (QueryResult)i;
    109     }
    110 
    111   public:
    112     inline explicit Node(const Payload& data, const std::string& label = ""):
    113         Label(label), Data(data) { }
    114 
    115     inline const Payload& data() const { return Data; }
    116     inline void setData(const Payload& data) { Data = data; }
    117 
    118     inline const std::string& label() const { return Label; }
    119 
    120 #if 0
    121     inline void dump() {
    122       llvm::cerr << "Node: " << this << "\n"
    123                 << "Label: " << Label << "\n"
    124                 << "Children:\n";
    125 
    126       for (iterator I = Children.begin(), E = Children.end(); I != E; ++I)
    127         llvm::cerr << (*I)->Label << "\n";
    128     }
    129 #endif
    130 
    131     inline Node* getEdge(char Id) {
    132       Node* fNode = NULL;
    133       iterator I = std::lower_bound(Children.begin(), Children.end(),
    134                                           Id, NodeCmp());
    135       if (I != Children.end() && (*I)->Label[0] == Id)
    136         fNode = *I;
    137 
    138       return fNode;
    139     }
    140 
    141     inline iterator       begin()       { return Children.begin(); }
    142     inline const_iterator begin() const { return Children.begin(); }
    143     inline iterator       end  ()       { return Children.end();   }
    144     inline const_iterator end  () const { return Children.end();   }
    145 
    146     inline size_t         size () const { return Children.size();  }
    147     inline bool           empty() const { return Children.empty(); }
    148     inline const Node*   &front() const { return Children.front(); }
    149     inline       Node*   &front()       { return Children.front(); }
    150     inline const Node*   &back()  const { return Children.back();  }
    151     inline       Node*   &back()        { return Children.back();  }
    152 
    153   };
    154 
    155 private:
    156   std::vector<Node*> Nodes;
    157   Payload Empty;
    158 
    159   inline Node* addNode(const Payload& data, const std::string label = "") {
    160     Node* N = new Node(data, label);
    161     Nodes.push_back(N);
    162     return N;
    163   }
    164 
    165   inline Node* splitEdge(Node* N, char Id, size_t index) {
    166     Node* eNode = N->getEdge(Id);
    167     assert(eNode && "Node doesn't exist");
    168 
    169     const std::string &l = eNode->Label;
    170     assert(index > 0 && index < l.length() && "Trying to split too far!");
    171     std::string l1 = l.substr(0, index);
    172     std::string l2 = l.substr(index);
    173 
    174     Node* nNode = addNode(Empty, l1);
    175     N->setEdge(nNode);
    176 
    177     eNode->Label = l2;
    178     nNode->addEdge(eNode);
    179 
    180     return nNode;
    181   }
    182 
    183   // Do not implement
    184   Trie(const Trie&);
    185   Trie& operator=(const Trie&);
    186 
    187 public:
    188   inline explicit Trie(const Payload& empty):Empty(empty) {
    189     addNode(Empty);
    190   }
    191   inline ~Trie() {
    192     for (unsigned i = 0, e = Nodes.size(); i != e; ++i)
    193       delete Nodes[i];
    194   }
    195 
    196   inline Node* getRoot() const { return Nodes[0]; }
    197 
    198   bool addString(const std::string& s, const Payload& data);
    199   const Payload& lookup(const std::string& s) const;
    200 
    201 };
    202 
    203 // Define this out-of-line to dissuade the C++ compiler from inlining it.
    204 template<class Payload>
    205 bool Trie<Payload>::addString(const std::string& s, const Payload& data) {
    206   Node* cNode = getRoot();
    207   Node* tNode = NULL;
    208   std::string s1(s);
    209 
    210   while (tNode == NULL) {
    211     char Id = s1[0];
    212     if (Node* nNode = cNode->getEdge(Id)) {
    213       typename Node::QueryResult r = nNode->query(s1);
    214 
    215       switch (r) {
    216       case Node::Same:
    217       case Node::StringIsPrefix:
    218         // Currently we don't allow to have two strings in the trie one
    219         // being a prefix of another. This should be fixed.
    220         assert(0 && "FIXME!");
    221         return false;
    222       case Node::DontMatch:
    223         assert(0 && "Impossible!");
    224         return false;
    225       case Node::LabelIsPrefix:
    226         s1 = s1.substr(nNode->label().length());
    227         cNode = nNode;
    228         break;
    229       default:
    230         nNode = splitEdge(cNode, Id, r);
    231         tNode = addNode(data, s1.substr(r));
    232         nNode->addEdge(tNode);
    233       }
    234     } else {
    235       tNode = addNode(data, s1);
    236       cNode->addEdge(tNode);
    237     }
    238   }
    239 
    240   return true;
    241 }
    242 
    243 template<class Payload>
    244 const Payload& Trie<Payload>::lookup(const std::string& s) const {
    245   Node* cNode = getRoot();
    246   Node* tNode = NULL;
    247   std::string s1(s);
    248 
    249   while (tNode == NULL) {
    250     char Id = s1[0];
    251     if (Node* nNode = cNode->getEdge(Id)) {
    252       typename Node::QueryResult r = nNode->query(s1);
    253 
    254       switch (r) {
    255       case Node::Same:
    256         tNode = nNode;
    257         break;
    258       case Node::StringIsPrefix:
    259         return Empty;
    260       case Node::DontMatch:
    261         assert(0 && "Impossible!");
    262         return Empty;
    263       case Node::LabelIsPrefix:
    264         s1 = s1.substr(nNode->label().length());
    265         cNode = nNode;
    266         break;
    267       default:
    268         return Empty;
    269       }
    270     } else
    271       return Empty;
    272   }
    273 
    274   return tNode->data();
    275 }
    276 
    277 template<class Payload>
    278 struct GraphTraits<Trie<Payload> > {
    279   typedef Trie<Payload> TrieType;
    280   typedef typename TrieType::Node NodeType;
    281   typedef typename NodeType::iterator ChildIteratorType;
    282 
    283   static inline NodeType *getEntryNode(const TrieType& T) {
    284     return T.getRoot();
    285   }
    286 
    287   static inline ChildIteratorType child_begin(NodeType *N) {
    288     return N->begin();
    289   }
    290   static inline ChildIteratorType child_end(NodeType *N) { return N->end(); }
    291 
    292   typedef typename std::vector<NodeType*>::const_iterator nodes_iterator;
    293 
    294   static inline nodes_iterator nodes_begin(const TrieType& G) {
    295     return G.Nodes.begin();
    296   }
    297   static inline nodes_iterator nodes_end(const TrieType& G) {
    298     return G.Nodes.end();
    299   }
    300 
    301 };
    302 
    303 template<class Payload>
    304 struct DOTGraphTraits<Trie<Payload> > : public DefaultDOTGraphTraits {
    305   typedef typename Trie<Payload>::Node NodeType;
    306   typedef typename GraphTraits<Trie<Payload> >::ChildIteratorType EdgeIter;
    307 
    308   static std::string getGraphName(const Trie<Payload>& T) {
    309     return "Trie";
    310   }
    311 
    312   static std::string getNodeLabel(NodeType* Node, const Trie<Payload>& T) {
    313     if (T.getRoot() == Node)
    314       return "<Root>";
    315     else
    316       return Node->label();
    317   }
    318 
    319   static std::string getEdgeSourceLabel(NodeType* Node, EdgeIter I) {
    320     NodeType* N = *I;
    321     return N->label().substr(0, 1);
    322   }
    323 
    324   static std::string getNodeAttributes(const NodeType* Node,
    325                                        const Trie<Payload>& T) {
    326     if (Node->data() != T.Empty)
    327       return "color=blue";
    328 
    329     return "";
    330   }
    331 
    332 };
    333 
    334 } // end of llvm namespace
    335 
    336 #endif // LLVM_ADT_TRIE_H
    337