Home | History | Annotate | Download | only in ASTMatchers
      1 //===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Implements an algorithm to efficiently search for matches on AST nodes.
     11 //  Uses memoization to support recursive matches like HasDescendant.
     12 //
     13 //  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
     14 //  calling the Matches(...) method of each matcher we are running on each
     15 //  AST node. The matcher can recurse via the ASTMatchFinder interface.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #include "clang/ASTMatchers/ASTMatchFinder.h"
     20 #include "clang/AST/ASTConsumer.h"
     21 #include "clang/AST/ASTContext.h"
     22 #include "clang/AST/RecursiveASTVisitor.h"
     23 #include "llvm/ADT/DenseMap.h"
     24 #include "llvm/ADT/StringMap.h"
     25 #include "llvm/Support/Timer.h"
     26 #include <deque>
     27 #include <memory>
     28 #include <set>
     29 
     30 namespace clang {
     31 namespace ast_matchers {
     32 namespace internal {
     33 namespace {
     34 
     35 typedef MatchFinder::MatchCallback MatchCallback;
     36 
     37 // The maximum number of memoization entries to store.
     38 // 10k has been experimentally found to give a good trade-off
     39 // of performance vs. memory consumption by running matcher
     40 // that match on every statement over a very large codebase.
     41 //
     42 // FIXME: Do some performance optimization in general and
     43 // revisit this number; also, put up micro-benchmarks that we can
     44 // optimize this on.
     45 static const unsigned MaxMemoizationEntries = 10000;
     46 
     47 // We use memoization to avoid running the same matcher on the same
     48 // AST node twice.  This struct is the key for looking up match
     49 // result.  It consists of an ID of the MatcherInterface (for
     50 // identifying the matcher), a pointer to the AST node and the
     51 // bound nodes before the matcher was executed.
     52 //
     53 // We currently only memoize on nodes whose pointers identify the
     54 // nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
     55 // For \c QualType and \c TypeLoc it is possible to implement
     56 // generation of keys for each type.
     57 // FIXME: Benchmark whether memoization of non-pointer typed nodes
     58 // provides enough benefit for the additional amount of code.
     59 struct MatchKey {
     60   DynTypedMatcher::MatcherIDType MatcherID;
     61   ast_type_traits::DynTypedNode Node;
     62   BoundNodesTreeBuilder BoundNodes;
     63 
     64   bool operator<(const MatchKey &Other) const {
     65     return std::tie(MatcherID, Node, BoundNodes) <
     66            std::tie(Other.MatcherID, Other.Node, Other.BoundNodes);
     67   }
     68 };
     69 
// Used to store the result of a match and possibly bound nodes.
// NOTE(review): ResultOfMatch has no default initializer; the producers
// visible in this file assign it before the cached entry is read.
struct MemoizedMatchResult {
  // Whether the memoized (recursive) match succeeded.
  bool ResultOfMatch;
  // Bindings that are copied back into the caller's builder whenever this
  // cached result is used.
  BoundNodesTreeBuilder Nodes;
};
     75 
     76 // A RecursiveASTVisitor that traverses all children or all descendants of
     77 // a node.
     78 class MatchChildASTVisitor
     79     : public RecursiveASTVisitor<MatchChildASTVisitor> {
     80 public:
     81   typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
     82 
     83   // Creates an AST visitor that matches 'matcher' on all children or
     84   // descendants of a traversed node. max_depth is the maximum depth
     85   // to traverse: use 1 for matching the children and INT_MAX for
     86   // matching the descendants.
     87   MatchChildASTVisitor(const DynTypedMatcher *Matcher,
     88                        ASTMatchFinder *Finder,
     89                        BoundNodesTreeBuilder *Builder,
     90                        int MaxDepth,
     91                        ASTMatchFinder::TraversalKind Traversal,
     92                        ASTMatchFinder::BindKind Bind)
     93       : Matcher(Matcher),
     94         Finder(Finder),
     95         Builder(Builder),
     96         CurrentDepth(0),
     97         MaxDepth(MaxDepth),
     98         Traversal(Traversal),
     99         Bind(Bind),
    100         Matches(false) {}
    101 
    102   // Returns true if a match is found in the subtree rooted at the
    103   // given AST node. This is done via a set of mutually recursive
    104   // functions. Here's how the recursion is done (the  *wildcard can
    105   // actually be Decl, Stmt, or Type):
    106   //
    107   //   - Traverse(node) calls BaseTraverse(node) when it needs
    108   //     to visit the descendants of node.
    109   //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
    110   //     Traverse*(c) for each child c of 'node'.
    111   //   - Traverse*(c) in turn calls Traverse(c), completing the
    112   //     recursion.
    113   bool findMatch(const ast_type_traits::DynTypedNode &DynNode) {
    114     reset();
    115     if (const Decl *D = DynNode.get<Decl>())
    116       traverse(*D);
    117     else if (const Stmt *S = DynNode.get<Stmt>())
    118       traverse(*S);
    119     else if (const NestedNameSpecifier *NNS =
    120              DynNode.get<NestedNameSpecifier>())
    121       traverse(*NNS);
    122     else if (const NestedNameSpecifierLoc *NNSLoc =
    123              DynNode.get<NestedNameSpecifierLoc>())
    124       traverse(*NNSLoc);
    125     else if (const QualType *Q = DynNode.get<QualType>())
    126       traverse(*Q);
    127     else if (const TypeLoc *T = DynNode.get<TypeLoc>())
    128       traverse(*T);
    129     // FIXME: Add other base types after adding tests.
    130 
    131     // It's OK to always overwrite the bound nodes, as if there was
    132     // no match in this recursive branch, the result set is empty
    133     // anyway.
    134     *Builder = ResultBindings;
    135 
    136     return Matches;
    137   }
    138 
    139   // The following are overriding methods from the base visitor class.
    140   // They are public only to allow CRTP to work. They are *not *part
    141   // of the public API of this class.
    142   bool TraverseDecl(Decl *DeclNode) {
    143     ScopedIncrement ScopedDepth(&CurrentDepth);
    144     return (DeclNode == nullptr) || traverse(*DeclNode);
    145   }
    146   bool TraverseStmt(Stmt *StmtNode) {
    147     ScopedIncrement ScopedDepth(&CurrentDepth);
    148     const Stmt *StmtToTraverse = StmtNode;
    149     if (Traversal ==
    150         ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) {
    151       const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode);
    152       if (ExprNode) {
    153         StmtToTraverse = ExprNode->IgnoreParenImpCasts();
    154       }
    155     }
    156     return (StmtToTraverse == nullptr) || traverse(*StmtToTraverse);
    157   }
    158   // We assume that the QualType and the contained type are on the same
    159   // hierarchy level. Thus, we try to match either of them.
    160   bool TraverseType(QualType TypeNode) {
    161     if (TypeNode.isNull())
    162       return true;
    163     ScopedIncrement ScopedDepth(&CurrentDepth);
    164     // Match the Type.
    165     if (!match(*TypeNode))
    166       return false;
    167     // The QualType is matched inside traverse.
    168     return traverse(TypeNode);
    169   }
    170   // We assume that the TypeLoc, contained QualType and contained Type all are
    171   // on the same hierarchy level. Thus, we try to match all of them.
    172   bool TraverseTypeLoc(TypeLoc TypeLocNode) {
    173     if (TypeLocNode.isNull())
    174       return true;
    175     ScopedIncrement ScopedDepth(&CurrentDepth);
    176     // Match the Type.
    177     if (!match(*TypeLocNode.getType()))
    178       return false;
    179     // Match the QualType.
    180     if (!match(TypeLocNode.getType()))
    181       return false;
    182     // The TypeLoc is matched inside traverse.
    183     return traverse(TypeLocNode);
    184   }
    185   bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
    186     ScopedIncrement ScopedDepth(&CurrentDepth);
    187     return (NNS == nullptr) || traverse(*NNS);
    188   }
    189   bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
    190     if (!NNS)
    191       return true;
    192     ScopedIncrement ScopedDepth(&CurrentDepth);
    193     if (!match(*NNS.getNestedNameSpecifier()))
    194       return false;
    195     return traverse(NNS);
    196   }
    197 
    198   bool shouldVisitTemplateInstantiations() const { return true; }
    199   bool shouldVisitImplicitCode() const { return true; }
    200   // Disables data recursion. We intercept Traverse* methods in the RAV, which
    201   // are not triggered during data recursion.
    202   bool shouldUseDataRecursionFor(clang::Stmt *S) const { return false; }
    203 
    204 private:
    205   // Used for updating the depth during traversal.
    206   struct ScopedIncrement {
    207     explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
    208     ~ScopedIncrement() { --(*Depth); }
    209 
    210    private:
    211     int *Depth;
    212   };
    213 
    214   // Resets the state of this object.
    215   void reset() {
    216     Matches = false;
    217     CurrentDepth = 0;
    218   }
    219 
    220   // Forwards the call to the corresponding Traverse*() method in the
    221   // base visitor class.
    222   bool baseTraverse(const Decl &DeclNode) {
    223     return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
    224   }
    225   bool baseTraverse(const Stmt &StmtNode) {
    226     return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
    227   }
    228   bool baseTraverse(QualType TypeNode) {
    229     return VisitorBase::TraverseType(TypeNode);
    230   }
    231   bool baseTraverse(TypeLoc TypeLocNode) {
    232     return VisitorBase::TraverseTypeLoc(TypeLocNode);
    233   }
    234   bool baseTraverse(const NestedNameSpecifier &NNS) {
    235     return VisitorBase::TraverseNestedNameSpecifier(
    236         const_cast<NestedNameSpecifier*>(&NNS));
    237   }
    238   bool baseTraverse(NestedNameSpecifierLoc NNS) {
    239     return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
    240   }
    241 
    242   // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
    243   //   0 < CurrentDepth <= MaxDepth.
    244   //
    245   // Returns 'true' if traversal should continue after this function
    246   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
    247   template <typename T>
    248   bool match(const T &Node) {
    249     if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
    250       return true;
    251     }
    252     if (Bind != ASTMatchFinder::BK_All) {
    253       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
    254       if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node), Finder,
    255                            &RecursiveBuilder)) {
    256         Matches = true;
    257         ResultBindings.addMatch(RecursiveBuilder);
    258         return false; // Abort as soon as a match is found.
    259       }
    260     } else {
    261       BoundNodesTreeBuilder RecursiveBuilder(*Builder);
    262       if (Matcher->matches(ast_type_traits::DynTypedNode::create(Node), Finder,
    263                            &RecursiveBuilder)) {
    264         // After the first match the matcher succeeds.
    265         Matches = true;
    266         ResultBindings.addMatch(RecursiveBuilder);
    267       }
    268     }
    269     return true;
    270   }
    271 
    272   // Traverses the subtree rooted at 'Node'; returns true if the
    273   // traversal should continue after this function returns.
    274   template <typename T>
    275   bool traverse(const T &Node) {
    276     static_assert(IsBaseType<T>::value,
    277                   "traverse can only be instantiated with base type");
    278     if (!match(Node))
    279       return false;
    280     return baseTraverse(Node);
    281   }
    282 
    283   const DynTypedMatcher *const Matcher;
    284   ASTMatchFinder *const Finder;
    285   BoundNodesTreeBuilder *const Builder;
    286   BoundNodesTreeBuilder ResultBindings;
    287   int CurrentDepth;
    288   const int MaxDepth;
    289   const ASTMatchFinder::TraversalKind Traversal;
    290   const ASTMatchFinder::BindKind Bind;
    291   bool Matches;
    292 };
    293 
    294 // Controls the outermost traversal of the AST and allows to match multiple
    295 // matchers.
    296 class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
    297                         public ASTMatchFinder {
    298 public:
    299   MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
    300                   const MatchFinder::MatchFinderOptions &Options)
    301       : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {}
    302 
    303   ~MatchASTVisitor() override {
    304     if (Options.CheckProfiling) {
    305       Options.CheckProfiling->Records = std::move(TimeByBucket);
    306     }
    307   }
    308 
    309   void onStartOfTranslationUnit() {
    310     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    311     TimeBucketRegion Timer;
    312     for (MatchCallback *MC : Matchers->AllCallbacks) {
    313       if (EnableCheckProfiling)
    314         Timer.setBucket(&TimeByBucket[MC->getID()]);
    315       MC->onStartOfTranslationUnit();
    316     }
    317   }
    318 
    319   void onEndOfTranslationUnit() {
    320     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    321     TimeBucketRegion Timer;
    322     for (MatchCallback *MC : Matchers->AllCallbacks) {
    323       if (EnableCheckProfiling)
    324         Timer.setBucket(&TimeByBucket[MC->getID()]);
    325       MC->onEndOfTranslationUnit();
    326     }
    327   }
    328 
    329   void set_active_ast_context(ASTContext *NewActiveASTContext) {
    330     ActiveASTContext = NewActiveASTContext;
    331   }
    332 
    333   // The following Visit*() and Traverse*() functions "override"
    334   // methods in RecursiveASTVisitor.
    335 
    336   bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
    337     // When we see 'typedef A B', we add name 'B' to the set of names
    338     // A's canonical type maps to.  This is necessary for implementing
    339     // isDerivedFrom(x) properly, where x can be the name of the base
    340     // class or any of its aliases.
    341     //
    342     // In general, the is-alias-of (as defined by typedefs) relation
    343     // is tree-shaped, as you can typedef a type more than once.  For
    344     // example,
    345     //
    346     //   typedef A B;
    347     //   typedef A C;
    348     //   typedef C D;
    349     //   typedef C E;
    350     //
    351     // gives you
    352     //
    353     //   A
    354     //   |- B
    355     //   `- C
    356     //      |- D
    357     //      `- E
    358     //
    359     // It is wrong to assume that the relation is a chain.  A correct
    360     // implementation of isDerivedFrom() needs to recognize that B and
    361     // E are aliases, even though neither is a typedef of the other.
    362     // Therefore, we cannot simply walk through one typedef chain to
    363     // find out whether the type name matches.
    364     const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
    365     const Type *CanonicalType =  // root of the typedef tree
    366         ActiveASTContext->getCanonicalType(TypeNode);
    367     TypeAliases[CanonicalType].insert(DeclNode);
    368     return true;
    369   }
    370 
  // Outermost traversal hooks that "override" the corresponding
  // RecursiveASTVisitor methods. They are only declared here; their
  // definitions are out of line and not visible in this section.
  bool TraverseDecl(Decl *DeclNode);
  bool TraverseStmt(Stmt *StmtNode);
  bool TraverseType(QualType TypeNode);
  bool TraverseTypeLoc(TypeLoc TypeNode);
  bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
    377 
    378   // Matches children or descendants of 'Node' with 'BaseMatcher'.
    379   bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node,
    380                                   const DynTypedMatcher &Matcher,
    381                                   BoundNodesTreeBuilder *Builder, int MaxDepth,
    382                                   TraversalKind Traversal, BindKind Bind) {
    383     // For AST-nodes that don't have an identity, we can't memoize.
    384     if (!Node.getMemoizationData() || !Builder->isComparable())
    385       return matchesRecursively(Node, Matcher, Builder, MaxDepth, Traversal,
    386                                 Bind);
    387 
    388     MatchKey Key;
    389     Key.MatcherID = Matcher.getID();
    390     Key.Node = Node;
    391     // Note that we key on the bindings *before* the match.
    392     Key.BoundNodes = *Builder;
    393 
    394     MemoizationMap::iterator I = ResultCache.find(Key);
    395     if (I != ResultCache.end()) {
    396       *Builder = I->second.Nodes;
    397       return I->second.ResultOfMatch;
    398     }
    399 
    400     MemoizedMatchResult Result;
    401     Result.Nodes = *Builder;
    402     Result.ResultOfMatch = matchesRecursively(Node, Matcher, &Result.Nodes,
    403                                               MaxDepth, Traversal, Bind);
    404 
    405     MemoizedMatchResult &CachedResult = ResultCache[Key];
    406     CachedResult = std::move(Result);
    407 
    408     *Builder = CachedResult.Nodes;
    409     return CachedResult.ResultOfMatch;
    410   }
    411 
    412   // Matches children or descendants of 'Node' with 'BaseMatcher'.
    413   bool matchesRecursively(const ast_type_traits::DynTypedNode &Node,
    414                           const DynTypedMatcher &Matcher,
    415                           BoundNodesTreeBuilder *Builder, int MaxDepth,
    416                           TraversalKind Traversal, BindKind Bind) {
    417     MatchChildASTVisitor Visitor(
    418       &Matcher, this, Builder, MaxDepth, Traversal, Bind);
    419     return Visitor.findMatch(Node);
    420   }
    421 
  // Implements ASTMatchFinder::classIsDerivedFrom; defined out of line.
  // NOTE(review): presumably returns whether 'Declaration' derives from a
  // class matching 'Base', consulting TypeAliases (filled by
  // VisitTypedefNameDecl) so that bases named via typedefs are found —
  // confirm against the definition.
  bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
                          const Matcher<NamedDecl> &Base,
                          BoundNodesTreeBuilder *Builder) override;
    425 
    426   // Implements ASTMatchFinder::matchesChildOf.
    427   bool matchesChildOf(const ast_type_traits::DynTypedNode &Node,
    428                       const DynTypedMatcher &Matcher,
    429                       BoundNodesTreeBuilder *Builder,
    430                       TraversalKind Traversal,
    431                       BindKind Bind) override {
    432     if (ResultCache.size() > MaxMemoizationEntries)
    433       ResultCache.clear();
    434     return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal,
    435                                       Bind);
    436   }
    437   // Implements ASTMatchFinder::matchesDescendantOf.
    438   bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node,
    439                            const DynTypedMatcher &Matcher,
    440                            BoundNodesTreeBuilder *Builder,
    441                            BindKind Bind) override {
    442     if (ResultCache.size() > MaxMemoizationEntries)
    443       ResultCache.clear();
    444     return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX,
    445                                       TK_AsIs, Bind);
    446   }
    447   // Implements ASTMatchFinder::matchesAncestorOf.
    448   bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node,
    449                          const DynTypedMatcher &Matcher,
    450                          BoundNodesTreeBuilder *Builder,
    451                          AncestorMatchMode MatchMode) override {
    452     // Reset the cache outside of the recursive call to make sure we
    453     // don't invalidate any iterators.
    454     if (ResultCache.size() > MaxMemoizationEntries)
    455       ResultCache.clear();
    456     return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder,
    457                                                 MatchMode);
    458   }
    459 
    460   // Matches all registered matchers on the given node and calls the
    461   // result callback for every node that matches.
    462   void match(const ast_type_traits::DynTypedNode &Node) {
    463     // FIXME: Improve this with a switch or a visitor pattern.
    464     if (auto *N = Node.get<Decl>()) {
    465       match(*N);
    466     } else if (auto *N = Node.get<Stmt>()) {
    467       match(*N);
    468     } else if (auto *N = Node.get<Type>()) {
    469       match(*N);
    470     } else if (auto *N = Node.get<QualType>()) {
    471       match(*N);
    472     } else if (auto *N = Node.get<NestedNameSpecifier>()) {
    473       match(*N);
    474     } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
    475       match(*N);
    476     } else if (auto *N = Node.get<TypeLoc>()) {
    477       match(*N);
    478     }
    479   }
    480 
    481   template <typename T> void match(const T &Node) {
    482     matchDispatch(&Node);
    483   }
    484 
    485   // Implements ASTMatchFinder::getASTContext.
    486   ASTContext &getASTContext() const override { return *ActiveASTContext; }
    487 
    488   bool shouldVisitTemplateInstantiations() const { return true; }
    489   bool shouldVisitImplicitCode() const { return true; }
    490   // Disables data recursion. We intercept Traverse* methods in the RAV, which
    491   // are not triggered during data recursion.
    492   bool shouldUseDataRecursionFor(clang::Stmt *S) const { return false; }
    493 
    494 private:
    495   class TimeBucketRegion {
    496   public:
    497     TimeBucketRegion() : Bucket(nullptr) {}
    498     ~TimeBucketRegion() { setBucket(nullptr); }
    499 
    500     /// \brief Start timing for \p NewBucket.
    501     ///
    502     /// If there was a bucket already set, it will finish the timing for that
    503     /// other bucket.
    504     /// \p NewBucket will be timed until the next call to \c setBucket() or
    505     /// until the \c TimeBucketRegion is destroyed.
    506     /// If \p NewBucket is the same as the currently timed bucket, this call
    507     /// does nothing.
    508     void setBucket(llvm::TimeRecord *NewBucket) {
    509       if (Bucket != NewBucket) {
    510         auto Now = llvm::TimeRecord::getCurrentTime(true);
    511         if (Bucket)
    512           *Bucket += Now;
    513         if (NewBucket)
    514           *NewBucket -= Now;
    515         Bucket = NewBucket;
    516       }
    517     }
    518 
    519   private:
    520     llvm::TimeRecord *Bucket;
    521   };
    522 
    523   /// \brief Runs all the \p Matchers on \p Node.
    524   ///
    525   /// Used by \c matchDispatch() below.
    526   template <typename T, typename MC>
    527   void matchWithoutFilter(const T &Node, const MC &Matchers) {
    528     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    529     TimeBucketRegion Timer;
    530     for (const auto &MP : Matchers) {
    531       if (EnableCheckProfiling)
    532         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
    533       BoundNodesTreeBuilder Builder;
    534       if (MP.first.matches(Node, this, &Builder)) {
    535         MatchVisitor Visitor(ActiveASTContext, MP.second);
    536         Builder.visitMatches(&Visitor);
    537       }
    538     }
    539   }
    540 
    541   void matchWithFilter(const ast_type_traits::DynTypedNode &DynNode) {
    542     auto Kind = DynNode.getNodeKind();
    543     auto it = MatcherFiltersMap.find(Kind);
    544     const auto &Filter =
    545         it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
    546 
    547     if (Filter.empty())
    548       return;
    549 
    550     const bool EnableCheckProfiling = Options.CheckProfiling.hasValue();
    551     TimeBucketRegion Timer;
    552     auto &Matchers = this->Matchers->DeclOrStmt;
    553     for (unsigned short I : Filter) {
    554       auto &MP = Matchers[I];
    555       if (EnableCheckProfiling)
    556         Timer.setBucket(&TimeByBucket[MP.second->getID()]);
    557       BoundNodesTreeBuilder Builder;
    558       if (MP.first.matchesNoKindCheck(DynNode, this, &Builder)) {
    559         MatchVisitor Visitor(ActiveASTContext, MP.second);
    560         Builder.visitMatches(&Visitor);
    561       }
    562     }
    563   }
    564 
    565   const std::vector<unsigned short> &
    566   getFilterForKind(ast_type_traits::ASTNodeKind Kind) {
    567     auto &Filter = MatcherFiltersMap[Kind];
    568     auto &Matchers = this->Matchers->DeclOrStmt;
    569     assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
    570     for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
    571       if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
    572         Filter.push_back(I);
    573       }
    574     }
    575     return Filter;
    576   }
    577 
    578   /// @{
    579   /// \brief Overloads to pair the different node types to their matchers.
    580   void matchDispatch(const Decl *Node) {
    581     return matchWithFilter(ast_type_traits::DynTypedNode::create(*Node));
    582   }
    583   void matchDispatch(const Stmt *Node) {
    584     return matchWithFilter(ast_type_traits::DynTypedNode::create(*Node));
    585   }
    586 
    587   void matchDispatch(const Type *Node) {
    588     matchWithoutFilter(QualType(Node, 0), Matchers->Type);
    589   }
    590   void matchDispatch(const TypeLoc *Node) {
    591     matchWithoutFilter(*Node, Matchers->TypeLoc);
    592   }
    593   void matchDispatch(const QualType *Node) {
    594     matchWithoutFilter(*Node, Matchers->Type);
    595   }
    596   void matchDispatch(const NestedNameSpecifier *Node) {
    597     matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
    598   }
    599   void matchDispatch(const NestedNameSpecifierLoc *Node) {
    600     matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
    601   }
    602   void matchDispatch(const void *) { /* Do nothing. */ }
    603   /// @}
    604 
  // Returns whether an ancestor of \p Node matches \p Matcher. With
  // AMM_ParentOnly only the direct parent(s) of \p Node are tried; otherwise
  // the search continues upward through all ancestors.
  //
  // The order of matching (which can lead to different nodes being bound in
  // case there are multiple matches) is breadth first search.
  //
  // To allow memoization in the very common case of having deeply nested
  // expressions inside a template function, we first walk up the AST, memoizing
  // the result of the match along the way, as long as there is only a single
  // parent.
  //
  // Once there are multiple parents, the breadth first search order does not
  // allow simple memoization on the ancestors. Thus, we only memoize as long
  // as there is a single parent.
  //
  // Except for the early TranslationUnitDecl exit, *Builder is overwritten on
  // return with the Nodes cached for this (matcher, node, bindings) key.
  bool memoizedMatchesAncestorOfRecursively(
      const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher,
      BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) {
    // The walk stops at the translation unit: nothing above it can match.
    if (Node.get<TranslationUnitDecl>() ==
        ActiveASTContext->getTranslationUnitDecl())
      return false;
    assert(Node.getMemoizationData() &&
           "Invariant broken: only nodes that support memoization may be "
           "used in the parent map.");

    MatchKey Key;
    Key.MatcherID = Matcher.getID();
    Key.Node = Node;
    Key.BoundNodes = *Builder;

    // Note that we cannot use insert and reuse the iterator, as recursive
    // calls to match might invalidate the result cache iterators (e.g.
    // matchesChildOf() clears ResultCache once it grows past
    // MaxMemoizationEntries).
    MemoizationMap::iterator I = ResultCache.find(Key);
    if (I != ResultCache.end()) {
      *Builder = I->second.Nodes;
      return I->second.ResultOfMatch;
    }

    MemoizedMatchResult Result;
    Result.ResultOfMatch = false;
    Result.Nodes = *Builder;

    const auto &Parents = ActiveASTContext->getParents(Node);
    assert(!Parents.empty() && "Found node that is not in the parent map.");
    if (Parents.size() == 1) {
      // Only one parent - do recursive memoization.
      const ast_type_traits::DynTypedNode Parent = Parents[0];
      if (Matcher.matches(Parent, this, &Result.Nodes)) {
        Result.ResultOfMatch = true;
      } else if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
        // Reset the results to not include the bound nodes from the failed
        // match above.
        Result.Nodes = *Builder;
        Result.ResultOfMatch = memoizedMatchesAncestorOfRecursively(
            Parent, Matcher, &Result.Nodes, MatchMode);
        // Once we get back from the recursive call, the result will be the
        // same as the parent's result.
      }
    } else {
      // Multiple parents - BFS over the rest of the nodes.
      // NOTE(review): the initial parents are never inserted into Visited,
      // so one of them can be queued again if it is also reached as an
      // ancestor of another queued node; this only repeats work.
      llvm::DenseSet<const void *> Visited;
      std::deque<ast_type_traits::DynTypedNode> Queue(Parents.begin(),
                                                      Parents.end());
      while (!Queue.empty()) {
        // Each candidate starts from the caller's original bindings.
        Result.Nodes = *Builder;
        if (Matcher.matches(Queue.front(), this, &Result.Nodes)) {
          Result.ResultOfMatch = true;
          break;
        }
        if (MatchMode != ASTMatchFinder::AMM_ParentOnly) {
          for (const auto &Parent :
               ActiveASTContext->getParents(Queue.front())) {
            // Make sure we do not visit the same node twice.
            // Otherwise, we'll visit the common ancestors as often as there
            // are splits on the way down.
            if (Visited.insert(Parent.getMemoizationData()).second)
              Queue.push_back(Parent);
          }
        }
        Queue.pop_front();
      }
    }

    // Re-look up the slot after the recursion (see the iterator note above).
    MemoizedMatchResult &CachedResult = ResultCache[Key];
    CachedResult = std::move(Result);

    *Builder = CachedResult.Nodes;
    return CachedResult.ResultOfMatch;
  }
    692 
    693   // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
    694   // the aggregated bound nodes for each match.
    695   class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
    696   public:
    697     MatchVisitor(ASTContext* Context,
    698                  MatchFinder::MatchCallback* Callback)
    699       : Context(Context),
    700         Callback(Callback) {}
    701 
    702     void visitMatch(const BoundNodes& BoundNodesView) override {
    703       Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
    704     }
    705 
    706   private:
    707     ASTContext* Context;
    708     MatchFinder::MatchCallback* Callback;
    709   };
    710 
    711   // Returns true if 'TypeNode' has an alias that matches the given matcher.
    712   bool typeHasMatchingAlias(const Type *TypeNode,
    713                             const Matcher<NamedDecl> Matcher,
    714                             BoundNodesTreeBuilder *Builder) {
    715     const Type *const CanonicalType =
    716       ActiveASTContext->getCanonicalType(TypeNode);
    717     for (const TypedefNameDecl *Alias : TypeAliases.lookup(CanonicalType)) {
    718       BoundNodesTreeBuilder Result(*Builder);
    719       if (Matcher.matches(*Alias, this, &Result)) {
    720         *Builder = std::move(Result);
    721         return true;
    722       }
    723     }
    724     return false;
    725   }
    726 
  /// \brief Bucket to record map.
  ///
  /// Used to get the appropriate bucket for each matcher.
  /// Keyed by bucket name (a \c StringMap key). The code that fills it is not
  /// in this part of the file. NOTE(review): presumably only populated when
  /// \c Options requests timing; confirm.
  llvm::StringMap<llvm::TimeRecord> TimeByBucket;

  /// \brief The matchers (and their callbacks) to run.
  ///
  /// \c MatchFinder::match and \c MatchFinder::matchAST construct this visitor
  /// with the address of the finder's own \c Matchers member; not owned here.
  const MatchFinder::MatchersByType *Matchers;

  /// \brief Filtered list of matcher indices for each matcher kind.
  ///
  /// \c Decl and \c Stmt toplevel matchers usually apply to a specific node
  /// kind (and derived kinds) so it is a waste to try every matcher on every
  /// node.
  /// We precalculate a list of matchers that pass the toplevel restrict check.
  /// This also allows us to skip the restrict check at matching time. See
  /// the use of \c matchesNoKindCheck() above.
  /// NOTE(review): the indices presumably refer to positions in
  /// \c Matchers->DeclOrStmt; confirm against the code that fills this map.
  llvm::DenseMap<ast_type_traits::ASTNodeKind, std::vector<unsigned short>>
      MatcherFiltersMap;

  /// \brief Options of the MatchFinder that created this visitor, held by
  /// reference (the finder passes its own \c Options member).
  const MatchFinder::MatchFinderOptions &Options;
  /// \brief Context of the AST currently being matched. Canonical types for
  /// alias lookup are computed through it. Presumably set via
  /// \c set_active_ast_context(), which the finder calls before matching.
  ASTContext *ActiveASTContext;

  // Maps a canonical type to the TypedefNameDecls that alias it.
  // typeHasMatchingAlias() looks entries up by the canonical form of a type.
  // NOTE(review): populated outside this part of the file; confirm where.
  llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;

  // Maps (matcher, node) -> the match result for memoization.
  // Results are stored here via ResultCache[Key] once computed.
  typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
  MemoizationMap ResultCache;
    754 };
    755 
    756 static CXXRecordDecl *getAsCXXRecordDecl(const Type *TypeNode) {
    757   // Type::getAs<...>() drills through typedefs.
    758   if (TypeNode->getAs<DependentNameType>() != nullptr ||
    759       TypeNode->getAs<DependentTemplateSpecializationType>() != nullptr ||
    760       TypeNode->getAs<TemplateTypeParmType>() != nullptr)
    761     // Dependent names and template TypeNode parameters will be matched when
    762     // the template is instantiated.
    763     return nullptr;
    764   TemplateSpecializationType const *TemplateType =
    765       TypeNode->getAs<TemplateSpecializationType>();
    766   if (!TemplateType) {
    767     return TypeNode->getAsCXXRecordDecl();
    768   }
    769   if (TemplateType->getTemplateName().isDependent())
    770     // Dependent template specializations will be matched when the
    771     // template is instantiated.
    772     return nullptr;
    773 
    774   // For template specialization types which are specializing a template
    775   // declaration which is an explicit or partial specialization of another
    776   // template declaration, getAsCXXRecordDecl() returns the corresponding
    777   // ClassTemplateSpecializationDecl.
    778   //
    779   // For template specialization types which are specializing a template
    780   // declaration which is neither an explicit nor partial specialization of
    781   // another template declaration, getAsCXXRecordDecl() returns NULL and
    782   // we get the CXXRecordDecl of the templated declaration.
    783   CXXRecordDecl *SpecializationDecl = TemplateType->getAsCXXRecordDecl();
    784   if (SpecializationDecl) {
    785     return SpecializationDecl;
    786   }
    787   NamedDecl *Templated =
    788       TemplateType->getTemplateName().getAsTemplateDecl()->getTemplatedDecl();
    789   if (CXXRecordDecl *TemplatedRecord = dyn_cast<CXXRecordDecl>(Templated)) {
    790     return TemplatedRecord;
    791   }
    792   // Now it can still be that we have an alias template.
    793   TypeAliasDecl *AliasDecl = dyn_cast<TypeAliasDecl>(Templated);
    794   assert(AliasDecl);
    795   return getAsCXXRecordDecl(AliasDecl->getUnderlyingType().getTypePtr());
    796 }
    797 
    798 // Returns true if the given class is directly or indirectly derived
    799 // from a base type with the given name.  A class is not considered to be
    800 // derived from itself.
    801 bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
    802                                          const Matcher<NamedDecl> &Base,
    803                                          BoundNodesTreeBuilder *Builder) {
    804   if (!Declaration->hasDefinition())
    805     return false;
    806   for (const auto &It : Declaration->bases()) {
    807     const Type *TypeNode = It.getType().getTypePtr();
    808 
    809     if (typeHasMatchingAlias(TypeNode, Base, Builder))
    810       return true;
    811 
    812     CXXRecordDecl *ClassDecl = getAsCXXRecordDecl(TypeNode);
    813     if (!ClassDecl)
    814       continue;
    815     if (ClassDecl == Declaration) {
    816       // This can happen for recursive template definitions; if the
    817       // current declaration did not match, we can safely return false.
    818       return false;
    819     }
    820     BoundNodesTreeBuilder Result(*Builder);
    821     if (Base.matches(*ClassDecl, this, &Result)) {
    822       *Builder = std::move(Result);
    823       return true;
    824     }
    825     if (classIsDerivedFrom(ClassDecl, Base, Builder))
    826       return true;
    827   }
    828   return false;
    829 }
    830 
    831 bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
    832   if (!DeclNode) {
    833     return true;
    834   }
    835   match(*DeclNode);
    836   return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
    837 }
    838 
    839 bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) {
    840   if (!StmtNode) {
    841     return true;
    842   }
    843   match(*StmtNode);
    844   return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode);
    845 }
    846 
    847 bool MatchASTVisitor::TraverseType(QualType TypeNode) {
    848   match(TypeNode);
    849   return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
    850 }
    851 
    852 bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
    853   // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
    854   // We still want to find those types via matchers, so we match them here. Note
    855   // that the TypeLocs are structurally a shadow-hierarchy to the expressed
    856   // type, so we visit all involved parts of a compound type when matching on
    857   // each TypeLoc.
    858   match(TypeLocNode);
    859   match(TypeLocNode.getType());
    860   return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
    861 }
    862 
    863 bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
    864   match(*NNS);
    865   return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
    866 }
    867 
    868 bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
    869     NestedNameSpecifierLoc NNS) {
    870   match(NNS);
    871   // We only match the nested name specifier here (as opposed to traversing it)
    872   // because the traversal is already done in the parallel "Loc"-hierarchy.
    873   if (NNS.hasQualifier())
    874     match(*NNS.getNestedNameSpecifier());
    875   return
    876       RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
    877 }
    878 
    879 class MatchASTConsumer : public ASTConsumer {
    880 public:
    881   MatchASTConsumer(MatchFinder *Finder,
    882                    MatchFinder::ParsingDoneTestCallback *ParsingDone)
    883       : Finder(Finder), ParsingDone(ParsingDone) {}
    884 
    885 private:
    886   void HandleTranslationUnit(ASTContext &Context) override {
    887     if (ParsingDone != nullptr) {
    888       ParsingDone->run();
    889     }
    890     Finder->matchAST(Context);
    891   }
    892 
    893   MatchFinder *Finder;
    894   MatchFinder::ParsingDoneTestCallback *ParsingDone;
    895 };
    896 
    897 } // end namespace
    898 } // end namespace internal
    899 
    900 MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
    901                                       ASTContext *Context)
    902   : Nodes(Nodes), Context(Context),
    903     SourceManager(&Context->getSourceManager()) {}
    904 
    905 MatchFinder::MatchCallback::~MatchCallback() {}
    906 MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
    907 
    908 MatchFinder::MatchFinder(MatchFinderOptions Options)
    909     : Options(std::move(Options)), ParsingDone(nullptr) {}
    910 
    911 MatchFinder::~MatchFinder() {}
    912 
    913 void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
    914                              MatchCallback *Action) {
    915   Matchers.DeclOrStmt.push_back(std::make_pair(NodeMatch, Action));
    916   Matchers.AllCallbacks.push_back(Action);
    917 }
    918 
    919 void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
    920                              MatchCallback *Action) {
    921   Matchers.Type.push_back(std::make_pair(NodeMatch, Action));
    922   Matchers.AllCallbacks.push_back(Action);
    923 }
    924 
    925 void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
    926                              MatchCallback *Action) {
    927   Matchers.DeclOrStmt.push_back(std::make_pair(NodeMatch, Action));
    928   Matchers.AllCallbacks.push_back(Action);
    929 }
    930 
    931 void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
    932                              MatchCallback *Action) {
    933   Matchers.NestedNameSpecifier.push_back(std::make_pair(NodeMatch, Action));
    934   Matchers.AllCallbacks.push_back(Action);
    935 }
    936 
    937 void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
    938                              MatchCallback *Action) {
    939   Matchers.NestedNameSpecifierLoc.push_back(std::make_pair(NodeMatch, Action));
    940   Matchers.AllCallbacks.push_back(Action);
    941 }
    942 
    943 void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
    944                              MatchCallback *Action) {
    945   Matchers.TypeLoc.push_back(std::make_pair(NodeMatch, Action));
    946   Matchers.AllCallbacks.push_back(Action);
    947 }
    948 
    949 bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
    950                                     MatchCallback *Action) {
    951   if (NodeMatch.canConvertTo<Decl>()) {
    952     addMatcher(NodeMatch.convertTo<Decl>(), Action);
    953     return true;
    954   } else if (NodeMatch.canConvertTo<QualType>()) {
    955     addMatcher(NodeMatch.convertTo<QualType>(), Action);
    956     return true;
    957   } else if (NodeMatch.canConvertTo<Stmt>()) {
    958     addMatcher(NodeMatch.convertTo<Stmt>(), Action);
    959     return true;
    960   } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
    961     addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
    962     return true;
    963   } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
    964     addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
    965     return true;
    966   } else if (NodeMatch.canConvertTo<TypeLoc>()) {
    967     addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
    968     return true;
    969   }
    970   return false;
    971 }
    972 
    973 std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
    974   return llvm::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
    975 }
    976 
    977 void MatchFinder::match(const clang::ast_type_traits::DynTypedNode &Node,
    978                         ASTContext &Context) {
    979   internal::MatchASTVisitor Visitor(&Matchers, Options);
    980   Visitor.set_active_ast_context(&Context);
    981   Visitor.match(Node);
    982 }
    983 
    984 void MatchFinder::matchAST(ASTContext &Context) {
    985   internal::MatchASTVisitor Visitor(&Matchers, Options);
    986   Visitor.set_active_ast_context(&Context);
    987   Visitor.onStartOfTranslationUnit();
    988   Visitor.TraverseDecl(Context.getTranslationUnitDecl());
    989   Visitor.onEndOfTranslationUnit();
    990 }
    991 
    992 void MatchFinder::registerTestCallbackAfterParsing(
    993     MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
    994   ParsingDone = NewParsingDone;
    995 }
    996 
    997 StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
    998 
    999 } // end namespace ast_matchers
   1000 } // end namespace clang
   1001