Home | History | Annotate | Download | only in AST
      1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Provides MatchVerifier, a base class to implement gtest matchers that
     11 //  verify things that can be matched on the AST.
     12 //
     13 //  Also implements matchers based on MatchVerifier:
     14 //  LocationVerifier and RangeVerifier to verify whether a matched node has
     15 //  the expected source location or source range.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
     20 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
     21 
     22 #include "clang/AST/ASTContext.h"
     23 #include "clang/ASTMatchers/ASTMatchFinder.h"
     24 #include "clang/ASTMatchers/ASTMatchers.h"
     25 #include "clang/Tooling/Tooling.h"
     26 #include "gtest/gtest.h"
     27 
     28 namespace clang {
     29 namespace ast_matchers {
     30 
     31 enum Language {
     32     Lang_C,
     33     Lang_C89,
     34     Lang_CXX,
     35     Lang_CXX11,
     36     Lang_OpenCL,
     37     Lang_OBJCXX
     38 };
     39 
     40 /// \brief Base class for verifying some property of nodes found by a matcher.
     41 template <typename NodeType>
     42 class MatchVerifier : public MatchFinder::MatchCallback {
     43 public:
     44   template <typename MatcherType>
     45   testing::AssertionResult match(const std::string &Code,
     46                                  const MatcherType &AMatcher) {
     47     std::vector<std::string> Args;
     48     return match(Code, AMatcher, Args, Lang_CXX);
     49   }
     50 
     51   template <typename MatcherType>
     52   testing::AssertionResult match(const std::string &Code,
     53                                  const MatcherType &AMatcher,
     54                                  Language L) {
     55     std::vector<std::string> Args;
     56     return match(Code, AMatcher, Args, L);
     57   }
     58 
     59   template <typename MatcherType>
     60   testing::AssertionResult match(const std::string &Code,
     61                                  const MatcherType &AMatcher,
     62                                  std::vector<std::string>& Args,
     63                                  Language L);
     64 
     65   template <typename MatcherType>
     66   testing::AssertionResult match(const Decl *D, const MatcherType &AMatcher);
     67 
     68 protected:
     69   void run(const MatchFinder::MatchResult &Result) override;
     70   virtual void verify(const MatchFinder::MatchResult &Result,
     71                       const NodeType &Node) {}
     72 
     73   void setFailure(const Twine &Result) {
     74     Verified = false;
     75     VerifyResult = Result.str();
     76   }
     77 
     78   void setSuccess() {
     79     Verified = true;
     80   }
     81 
     82 private:
     83   bool Verified;
     84   std::string VerifyResult;
     85 };
     86 
     87 /// \brief Runs a matcher over some code, and returns the result of the
     88 /// verifier for the matched node.
     89 template <typename NodeType> template <typename MatcherType>
     90 testing::AssertionResult MatchVerifier<NodeType>::match(
     91     const std::string &Code, const MatcherType &AMatcher,
     92     std::vector<std::string>& Args, Language L) {
     93   MatchFinder Finder;
     94   Finder.addMatcher(AMatcher.bind(""), this);
     95   std::unique_ptr<tooling::FrontendActionFactory> Factory(
     96       tooling::newFrontendActionFactory(&Finder));
     97 
     98   StringRef FileName;
     99   switch (L) {
    100   case Lang_C:
    101     Args.push_back("-std=c99");
    102     FileName = "input.c";
    103     break;
    104   case Lang_C89:
    105     Args.push_back("-std=c89");
    106     FileName = "input.c";
    107     break;
    108   case Lang_CXX:
    109     Args.push_back("-std=c++98");
    110     FileName = "input.cc";
    111     break;
    112   case Lang_CXX11:
    113     Args.push_back("-std=c++11");
    114     FileName = "input.cc";
    115     break;
    116   case Lang_OpenCL:
    117     FileName = "input.cl";
    118     break;
    119   case Lang_OBJCXX:
    120     FileName = "input.mm";
    121     break;
    122   }
    123 
    124   // Default to failure in case callback is never called
    125   setFailure("Could not find match");
    126   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
    127     return testing::AssertionFailure() << "Parsing error";
    128   if (!Verified)
    129     return testing::AssertionFailure() << VerifyResult;
    130   return testing::AssertionSuccess();
    131 }
    132 
    133 /// \brief Runs a matcher over some AST, and returns the result of the
    134 /// verifier for the matched node.
    135 template <typename NodeType> template <typename MatcherType>
    136 testing::AssertionResult MatchVerifier<NodeType>::match(
    137     const Decl *D, const MatcherType &AMatcher) {
    138   MatchFinder Finder;
    139   Finder.addMatcher(AMatcher.bind(""), this);
    140 
    141   setFailure("Could not find match");
    142   Finder.match(*D, D->getASTContext());
    143 
    144   if (!Verified)
    145     return testing::AssertionFailure() << VerifyResult;
    146   return testing::AssertionSuccess();
    147 }
    148 
    149 template <typename NodeType>
    150 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
    151   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
    152   if (!Node) {
    153     setFailure("Matched node has wrong type");
    154   } else {
    155     // Callback has been called, default to success.
    156     setSuccess();
    157     verify(Result, *Node);
    158   }
    159 }
    160 
    161 template <>
    162 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
    163     const MatchFinder::MatchResult &Result) {
    164   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
    165   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
    166   if (I == M.end()) {
    167     setFailure("Node was not bound");
    168   } else {
    169     // Callback has been called, default to success.
    170     setSuccess();
    171     verify(Result, I->second);
    172   }
    173 }
    174 
    175 /// \brief Verify whether a node has the correct source location.
    176 ///
    177 /// By default, Node.getSourceLocation() is checked. This can be changed
    178 /// by overriding getLocation().
    179 template <typename NodeType>
    180 class LocationVerifier : public MatchVerifier<NodeType> {
    181 public:
    182   void expectLocation(unsigned Line, unsigned Column) {
    183     ExpectLine = Line;
    184     ExpectColumn = Column;
    185   }
    186 
    187 protected:
    188   void verify(const MatchFinder::MatchResult &Result,
    189               const NodeType &Node) override {
    190     SourceLocation Loc = getLocation(Node);
    191     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
    192     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
    193     if (Line != ExpectLine || Column != ExpectColumn) {
    194       std::string MsgStr;
    195       llvm::raw_string_ostream Msg(MsgStr);
    196       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
    197           << ">, found <";
    198       Loc.print(Msg, *Result.SourceManager);
    199       Msg << '>';
    200       this->setFailure(Msg.str());
    201     }
    202   }
    203 
    204   virtual SourceLocation getLocation(const NodeType &Node) {
    205     return Node.getLocation();
    206   }
    207 
    208 private:
    209   unsigned ExpectLine, ExpectColumn;
    210 };
    211 
    212 /// \brief Verify whether a node has the correct source range.
    213 ///
    214 /// By default, Node.getSourceRange() is checked. This can be changed
    215 /// by overriding getRange().
    216 template <typename NodeType>
    217 class RangeVerifier : public MatchVerifier<NodeType> {
    218 public:
    219   void expectRange(unsigned BeginLine, unsigned BeginColumn,
    220                    unsigned EndLine, unsigned EndColumn) {
    221     ExpectBeginLine = BeginLine;
    222     ExpectBeginColumn = BeginColumn;
    223     ExpectEndLine = EndLine;
    224     ExpectEndColumn = EndColumn;
    225   }
    226 
    227 protected:
    228   void verify(const MatchFinder::MatchResult &Result,
    229               const NodeType &Node) override {
    230     SourceRange R = getRange(Node);
    231     SourceLocation Begin = R.getBegin();
    232     SourceLocation End = R.getEnd();
    233     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
    234     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
    235     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
    236     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
    237     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
    238         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
    239       std::string MsgStr;
    240       llvm::raw_string_ostream Msg(MsgStr);
    241       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
    242           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
    243       Begin.print(Msg, *Result.SourceManager);
    244       Msg << '-';
    245       End.print(Msg, *Result.SourceManager);
    246       Msg << '>';
    247       this->setFailure(Msg.str());
    248     }
    249   }
    250 
    251   virtual SourceRange getRange(const NodeType &Node) {
    252     return Node.getSourceRange();
    253   }
    254 
    255 private:
    256   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
    257 };
    258 
    259 /// \brief Verify whether a node's dump contains a given substring.
    260 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    261 public:
    262   void expectSubstring(const std::string &Str) {
    263     ExpectSubstring = Str;
    264   }
    265 
    266 protected:
    267   void verify(const MatchFinder::MatchResult &Result,
    268               const ast_type_traits::DynTypedNode &Node) override {
    269     std::string DumpStr;
    270     llvm::raw_string_ostream Dump(DumpStr);
    271     Node.dump(Dump, *Result.SourceManager);
    272 
    273     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
    274       std::string MsgStr;
    275       llvm::raw_string_ostream Msg(MsgStr);
    276       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
    277           << Dump.str() << '>';
    278       this->setFailure(Msg.str());
    279     }
    280   }
    281 
    282 private:
    283   std::string ExpectSubstring;
    284 };
    285 
    286 /// \brief Verify whether a node's pretty print matches a given string.
    287 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    288 public:
    289   void expectString(const std::string &Str) {
    290     ExpectString = Str;
    291   }
    292 
    293 protected:
    294   void verify(const MatchFinder::MatchResult &Result,
    295               const ast_type_traits::DynTypedNode &Node) override {
    296     std::string PrintStr;
    297     llvm::raw_string_ostream Print(PrintStr);
    298     Node.print(Print, Result.Context->getPrintingPolicy());
    299 
    300     if (Print.str() != ExpectString) {
    301       std::string MsgStr;
    302       llvm::raw_string_ostream Msg(MsgStr);
    303       Msg << "Expected pretty print <" << ExpectString << ">, found <"
    304           << Print.str() << '>';
    305       this->setFailure(Msg.str());
    306     }
    307   }
    308 
    309 private:
    310   std::string ExpectString;
    311 };
    312 
    313 } // end namespace ast_matchers
    314 } // end namespace clang
    315 
    316 #endif
    317