Home | History | Annotate | Download | only in AST
      1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Provides MatchVerifier, a base class to implement gtest matchers that
     11 //  verify things that can be matched on the AST.
     12 //
     13 //  Also implements matchers based on MatchVerifier:
     14 //  LocationVerifier and RangeVerifier to verify whether a matched node has
     15 //  the expected source location or source range.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #ifndef LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
     20 #define LLVM_CLANG_UNITTESTS_AST_MATCHVERIFIER_H
     21 
     22 #include "clang/AST/ASTContext.h"
     23 #include "clang/ASTMatchers/ASTMatchFinder.h"
     24 #include "clang/ASTMatchers/ASTMatchers.h"
     25 #include "clang/Tooling/Tooling.h"
     26 #include "gtest/gtest.h"
     27 
     28 namespace clang {
     29 namespace ast_matchers {
     30 
     31 enum Language {
     32     Lang_C,
     33     Lang_C89,
     34     Lang_CXX,
     35     Lang_CXX11,
     36     Lang_OpenCL,
     37     Lang_OBJCXX
     38 };
     39 
     40 /// \brief Base class for verifying some property of nodes found by a matcher.
     41 template <typename NodeType>
     42 class MatchVerifier : public MatchFinder::MatchCallback {
     43 public:
     44   template <typename MatcherType>
     45   testing::AssertionResult match(const std::string &Code,
     46                                  const MatcherType &AMatcher) {
     47     std::vector<std::string> Args;
     48     return match(Code, AMatcher, Args, Lang_CXX);
     49   }
     50 
     51   template <typename MatcherType>
     52   testing::AssertionResult match(const std::string &Code,
     53                                  const MatcherType &AMatcher,
     54                                  Language L) {
     55     std::vector<std::string> Args;
     56     return match(Code, AMatcher, Args, L);
     57   }
     58 
     59   template <typename MatcherType>
     60   testing::AssertionResult match(const std::string &Code,
     61                                  const MatcherType &AMatcher,
     62                                  std::vector<std::string>& Args,
     63                                  Language L);
     64 
     65 protected:
     66   void run(const MatchFinder::MatchResult &Result) override;
     67   virtual void verify(const MatchFinder::MatchResult &Result,
     68                       const NodeType &Node) {}
     69 
     70   void setFailure(const Twine &Result) {
     71     Verified = false;
     72     VerifyResult = Result.str();
     73   }
     74 
     75   void setSuccess() {
     76     Verified = true;
     77   }
     78 
     79 private:
     80   bool Verified;
     81   std::string VerifyResult;
     82 };
     83 
     84 /// \brief Runs a matcher over some code, and returns the result of the
     85 /// verifier for the matched node.
     86 template <typename NodeType> template <typename MatcherType>
     87 testing::AssertionResult MatchVerifier<NodeType>::match(
     88     const std::string &Code, const MatcherType &AMatcher,
     89     std::vector<std::string>& Args, Language L) {
     90   MatchFinder Finder;
     91   Finder.addMatcher(AMatcher.bind(""), this);
     92   std::unique_ptr<tooling::FrontendActionFactory> Factory(
     93       tooling::newFrontendActionFactory(&Finder));
     94 
     95   StringRef FileName;
     96   switch (L) {
     97   case Lang_C:
     98     Args.push_back("-std=c99");
     99     FileName = "input.c";
    100     break;
    101   case Lang_C89:
    102     Args.push_back("-std=c89");
    103     FileName = "input.c";
    104     break;
    105   case Lang_CXX:
    106     Args.push_back("-std=c++98");
    107     FileName = "input.cc";
    108     break;
    109   case Lang_CXX11:
    110     Args.push_back("-std=c++11");
    111     FileName = "input.cc";
    112     break;
    113   case Lang_OpenCL:
    114     FileName = "input.cl";
    115     break;
    116   case Lang_OBJCXX:
    117     FileName = "input.mm";
    118     break;
    119   }
    120 
    121   // Default to failure in case callback is never called
    122   setFailure("Could not find match");
    123   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
    124     return testing::AssertionFailure() << "Parsing error";
    125   if (!Verified)
    126     return testing::AssertionFailure() << VerifyResult;
    127   return testing::AssertionSuccess();
    128 }
    129 
    130 template <typename NodeType>
    131 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
    132   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
    133   if (!Node) {
    134     setFailure("Matched node has wrong type");
    135   } else {
    136     // Callback has been called, default to success.
    137     setSuccess();
    138     verify(Result, *Node);
    139   }
    140 }
    141 
    142 template <>
    143 inline void MatchVerifier<ast_type_traits::DynTypedNode>::run(
    144     const MatchFinder::MatchResult &Result) {
    145   BoundNodes::IDToNodeMap M = Result.Nodes.getMap();
    146   BoundNodes::IDToNodeMap::const_iterator I = M.find("");
    147   if (I == M.end()) {
    148     setFailure("Node was not bound");
    149   } else {
    150     // Callback has been called, default to success.
    151     setSuccess();
    152     verify(Result, I->second);
    153   }
    154 }
    155 
    156 /// \brief Verify whether a node has the correct source location.
    157 ///
    158 /// By default, Node.getSourceLocation() is checked. This can be changed
    159 /// by overriding getLocation().
    160 template <typename NodeType>
    161 class LocationVerifier : public MatchVerifier<NodeType> {
    162 public:
    163   void expectLocation(unsigned Line, unsigned Column) {
    164     ExpectLine = Line;
    165     ExpectColumn = Column;
    166   }
    167 
    168 protected:
    169   void verify(const MatchFinder::MatchResult &Result,
    170               const NodeType &Node) override {
    171     SourceLocation Loc = getLocation(Node);
    172     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
    173     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
    174     if (Line != ExpectLine || Column != ExpectColumn) {
    175       std::string MsgStr;
    176       llvm::raw_string_ostream Msg(MsgStr);
    177       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
    178           << ">, found <";
    179       Loc.print(Msg, *Result.SourceManager);
    180       Msg << '>';
    181       this->setFailure(Msg.str());
    182     }
    183   }
    184 
    185   virtual SourceLocation getLocation(const NodeType &Node) {
    186     return Node.getLocation();
    187   }
    188 
    189 private:
    190   unsigned ExpectLine, ExpectColumn;
    191 };
    192 
    193 /// \brief Verify whether a node has the correct source range.
    194 ///
    195 /// By default, Node.getSourceRange() is checked. This can be changed
    196 /// by overriding getRange().
    197 template <typename NodeType>
    198 class RangeVerifier : public MatchVerifier<NodeType> {
    199 public:
    200   void expectRange(unsigned BeginLine, unsigned BeginColumn,
    201                    unsigned EndLine, unsigned EndColumn) {
    202     ExpectBeginLine = BeginLine;
    203     ExpectBeginColumn = BeginColumn;
    204     ExpectEndLine = EndLine;
    205     ExpectEndColumn = EndColumn;
    206   }
    207 
    208 protected:
    209   void verify(const MatchFinder::MatchResult &Result,
    210               const NodeType &Node) override {
    211     SourceRange R = getRange(Node);
    212     SourceLocation Begin = R.getBegin();
    213     SourceLocation End = R.getEnd();
    214     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
    215     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
    216     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
    217     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
    218     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
    219         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
    220       std::string MsgStr;
    221       llvm::raw_string_ostream Msg(MsgStr);
    222       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
    223           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
    224       Begin.print(Msg, *Result.SourceManager);
    225       Msg << '-';
    226       End.print(Msg, *Result.SourceManager);
    227       Msg << '>';
    228       this->setFailure(Msg.str());
    229     }
    230   }
    231 
    232   virtual SourceRange getRange(const NodeType &Node) {
    233     return Node.getSourceRange();
    234   }
    235 
    236 private:
    237   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
    238 };
    239 
    240 /// \brief Verify whether a node's dump contains a given substring.
    241 class DumpVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    242 public:
    243   void expectSubstring(const std::string &Str) {
    244     ExpectSubstring = Str;
    245   }
    246 
    247 protected:
    248   void verify(const MatchFinder::MatchResult &Result,
    249               const ast_type_traits::DynTypedNode &Node) override {
    250     std::string DumpStr;
    251     llvm::raw_string_ostream Dump(DumpStr);
    252     Node.dump(Dump, *Result.SourceManager);
    253 
    254     if (Dump.str().find(ExpectSubstring) == std::string::npos) {
    255       std::string MsgStr;
    256       llvm::raw_string_ostream Msg(MsgStr);
    257       Msg << "Expected dump substring <" << ExpectSubstring << ">, found <"
    258           << Dump.str() << '>';
    259       this->setFailure(Msg.str());
    260     }
    261   }
    262 
    263 private:
    264   std::string ExpectSubstring;
    265 };
    266 
    267 /// \brief Verify whether a node's pretty print matches a given string.
    268 class PrintVerifier : public MatchVerifier<ast_type_traits::DynTypedNode> {
    269 public:
    270   void expectString(const std::string &Str) {
    271     ExpectString = Str;
    272   }
    273 
    274 protected:
    275   void verify(const MatchFinder::MatchResult &Result,
    276               const ast_type_traits::DynTypedNode &Node) override {
    277     std::string PrintStr;
    278     llvm::raw_string_ostream Print(PrintStr);
    279     Node.print(Print, Result.Context->getPrintingPolicy());
    280 
    281     if (Print.str() != ExpectString) {
    282       std::string MsgStr;
    283       llvm::raw_string_ostream Msg(MsgStr);
    284       Msg << "Expected pretty print <" << ExpectString << ">, found <"
    285           << Print.str() << '>';
    286       this->setFailure(Msg.str());
    287     }
    288   }
    289 
    290 private:
    291   std::string ExpectString;
    292 };
    293 
    294 } // end namespace ast_matchers
    295 } // end namespace clang
    296 
    297 #endif
    298