Home | History | Annotate | Download | only in AST
      1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Provides MatchVerifier, a base class to implement gtest matchers that
     11 //  verify things that can be matched on the AST.
     12 //
     13 //  Also implements matchers based on MatchVerifier: LocationVerifier and
     14 //  RangeVerifier check a matched node's source location or source range,
     15 //  DumpVerifier checks its AST dump and PrintVerifier its pretty-printed form.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #include "clang/AST/ASTContext.h"
     20 #include "clang/ASTMatchers/ASTMatchFinder.h"
     21 #include "clang/ASTMatchers/ASTMatchers.h"
     22 #include "clang/Tooling/Tooling.h"
     23 #include "gtest/gtest.h"
     24 
     25 namespace clang {
     26 namespace ast_matchers {
     27 
     28 enum Language { Lang_C, Lang_C89, Lang_CXX, Lang_CXX11, Lang_OpenCL };
     29 
     30 /// \brief Base class for verifying some property of nodes found by a matcher.
     31 template <typename NodeType>
     32 class MatchVerifier : public MatchFinder::MatchCallback {
     33 public:
     34   template <typename MatcherType>
     35   testing::AssertionResult match(const std::string &Code,
     36                                  const MatcherType &AMatcher) {
     37     std::vector<std::string> Args;
     38     return match(Code, AMatcher, Args, Lang_CXX);
     39   }
     40 
     41   template <typename MatcherType>
     42   testing::AssertionResult match(const std::string &Code,
     43                                  const MatcherType &AMatcher,
     44                                  Language L) {
     45     std::vector<std::string> Args;
     46     return match(Code, AMatcher, Args, L);
     47   }
     48 
     49   template <typename MatcherType>
     50   testing::AssertionResult match(const std::string &Code,
     51                                  const MatcherType &AMatcher,
     52                                  std::vector<std::string>& Args,
     53                                  Language L);
     54 
     55 protected:
     56   virtual void run(const MatchFinder::MatchResult &Result);
     57   virtual void verify(const MatchFinder::MatchResult &Result,
     58                       const NodeType &Node) {}
     59 
     60   void setFailure(const Twine &Result) {
     61     Verified = false;
     62     VerifyResult = Result.str();
     63   }
     64 
     65   void setSuccess() {
     66     Verified = true;
     67   }
     68 
     69 private:
     70   bool Verified;
     71   std::string VerifyResult;
     72 };
     73 
     74 /// \brief Runs a matcher over some code, and returns the result of the
     75 /// verifier for the matched node.
     76 template <typename NodeType> template <typename MatcherType>
     77 testing::AssertionResult MatchVerifier<NodeType>::match(
     78     const std::string &Code, const MatcherType &AMatcher,
     79     std::vector<std::string>& Args, Language L) {
     80   MatchFinder Finder;
     81   Finder.addMatcher(AMatcher.bind(""), this);
     82   std::unique_ptr<tooling::FrontendActionFactory> Factory(
     83       tooling::newFrontendActionFactory(&Finder));
     84 
     85   StringRef FileName;
     86   switch (L) {
     87   case Lang_C:
     88     Args.push_back("-std=c99");
     89     FileName = "input.c";
     90     break;
     91   case Lang_C89:
     92     Args.push_back("-std=c89");
     93     FileName = "input.c";
     94     break;
     95   case Lang_CXX:
     96     Args.push_back("-std=c++98");
     97     FileName = "input.cc";
     98     break;
     99   case Lang_CXX11:
    100     Args.push_back("-std=c++11");
    101     FileName = "input.cc";
    102     break;
    103   case Lang_OpenCL:
    104     FileName = "input.cl";
    105   }
    106 
    107   // Default to failure in case callback is never called
    108   setFailure("Could not find match");
    109   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
    110     return testing::AssertionFailure() << "Parsing error";
    111   if (!Verified)
    112     return testing::AssertionFailure() << VerifyResult;
    113   return testing::AssertionSuccess();
    114 }
    115 
    116 template <typename NodeType>
    117 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
    118   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
    119   if (!Node) {
    120     setFailure("Matched node has wrong type");
    121   } else {
    122     // Callback has been called, default to success.
    123     setSuccess();
    124     verify(Result, *Node);
    125   }
    126 }
    127 
    128 template <>
    129 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
    130     const MatchFinder::MatchResult &Result) {
    131   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
    132   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
    133   if (I == M.end()) {
    134     setFailure("Node was not bound");
    135   } else {
    136     // Callback has been called, default to success.
    137     setSuccess();
    138     verify(Result, I->second);
    139   }
    140 }
    141 
    142 /// \brief Verify whether a node has the correct source location.
    143 ///
    144 /// By default, Node.getSourceLocation() is checked. This can be changed
    145 /// by overriding getLocation().
    146 template <typename NodeType>
    147 class LocationVerifier : public MatchVerifier<NodeType> {
    148 public:
    149   void expectLocation(unsigned Line, unsigned Column) {
    150     ExpectLine = Line;
    151     ExpectColumn = Column;
    152   }
    153 
    154 protected:
    155   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
    156     SourceLocation Loc = getLocation(Node);
    157     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
    158     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
    159     if (Line != ExpectLine || Column != ExpectColumn) {
    160       std::string MsgStr;
    161       llvm::raw_string_ostream Msg(MsgStr);
    162       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
    163           << ">, found <";
    164       Loc.print(Msg, *Result.SourceManager);
    165       Msg << '>';
    166       this->setFailure(Msg.str());
    167     }
    168   }
    169 
    170   virtual SourceLocation getLocation(const NodeType &Node) {
    171     return Node.getLocation();
    172   }
    173 
    174 private:
    175   unsigned ExpectLine, ExpectColumn;
    176 };
    177 
    178 /// \brief Verify whether a node has the correct source range.
    179 ///
    180 /// By default, Node.getSourceRange() is checked. This can be changed
    181 /// by overriding getRange().
    182 template <typename NodeType>
    183 class RangeVerifier : public MatchVerifier<NodeType> {
    184 public:
    185   void expectRange(unsigned BeginLine, unsigned BeginColumn,
    186                    unsigned EndLine, unsigned EndColumn) {
    187     ExpectBeginLine = BeginLine;
    188     ExpectBeginColumn = BeginColumn;
    189     ExpectEndLine = EndLine;
    190     ExpectEndColumn = EndColumn;
    191   }
    192 
    193 protected:
    194   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
    195     SourceRange R = getRange(Node);
    196     SourceLocation Begin = R.getBegin();
    197     SourceLocation End = R.getEnd();
    198     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
    199     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
    200     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
    201     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
    202     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
    203         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
    204       std::string MsgStr;
    205       llvm::raw_string_ostream Msg(MsgStr);
    206       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
    207           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
    208       Begin.print(Msg, *Result.SourceManager);
    209       Msg << '-';
    210       End.print(Msg, *Result.SourceManager);
    211       Msg << '>';
    212       this->setFailure(Msg.str());
    213     }
    214   }
    215 
    216   virtual SourceRange getRange(const NodeType &Node) {
    217     return Node.getSourceRange();
    218   }
    219 
    220 private:
    221   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
    222 };
    223 
    224 /// \brief Verify whether a node's dump contains a given substring.
    225 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    226 public:
    227   void expectSubstring(const std::string &Str) {
    228     ExpectSubstring = Str;
    229   }
    230 
    231 protected:
    232   void verify(const MatchFinder::MatchResult &Result,
    233               const ast_type_traits::DynTypedNode &Node) {
    234     std::string DumpStr;
    235     llvm::raw_string_ostream Dump(DumpStr);
    236     Node.dump(Dump, *Result.SourceManager);
    237 
    238     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
    239       std::string MsgStr;
    240       llvm::raw_string_ostream Msg(MsgStr);
    241       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
    242           << Dump.str() << '>';
    243       this->setFailure(Msg.str());
    244     }
    245   }
    246 
    247 private:
    248   std::string ExpectSubstring;
    249 };
    250 
    251 /// \brief Verify whether a node's pretty print matches a given string.
    252 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    253 public:
    254   void expectString(const std::string &Str) {
    255     ExpectString = Str;
    256   }
    257 
    258 protected:
    259   void verify(const MatchFinder::MatchResult &Result,
    260               const ast_type_traits::DynTypedNode &Node) {
    261     std::string PrintStr;
    262     llvm::raw_string_ostream Print(PrintStr);
    263     Node.print(Print, Result.Context->getPrintingPolicy());
    264 
    265     if (Print.str() != ExpectString) {
    266       std::string MsgStr;
    267       llvm::raw_string_ostream Msg(MsgStr);
    268       Msg << "Expected pretty print <" << ExpectString << ">, found <"
    269           << Print.str() << '>';
    270       this->setFailure(Msg.str());
    271     }
    272   }
    273 
    274 private:
    275   std::string ExpectString;
    276 };
    277 
    278 } // end namespace ast_matchers
    279 } // end namespace clang
    280