Home | History | Annotate | Download | only in AST
      1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 //  Provides MatchVerifier, a base class to implement gtest matchers that
     11 //  verify things that can be matched on the AST.
     12 //
     13 //  Also implements matchers based on MatchVerifier:
     14 //  LocationVerifier and RangeVerifier to verify whether a matched node has
     15 //  the expected source location or source range.
     16 //
     17 //===----------------------------------------------------------------------===//
     18 
     19 #include "clang/AST/ASTContext.h"
     20 #include "clang/ASTMatchers/ASTMatchFinder.h"
     21 #include "clang/ASTMatchers/ASTMatchers.h"
     22 #include "clang/Tooling/Tooling.h"
     23 #include "gtest/gtest.h"
     24 
     25 namespace clang {
     26 namespace ast_matchers {
     27 
     28 enum Language { Lang_C, Lang_C89, Lang_CXX, Lang_CXX11, Lang_OpenCL };
     29 
     30 /// \brief Base class for verifying some property of nodes found by a matcher.
     31 template <typename NodeType>
     32 class MatchVerifier : public MatchFinder::MatchCallback {
     33 public:
     34   template <typename MatcherType>
     35   testing::AssertionResult match(const std::string &Code,
     36                                  const MatcherType &AMatcher) {
     37     std::vector<std::string> Args;
     38     return match(Code, AMatcher, Args, Lang_CXX);
     39   }
     40 
     41   template <typename MatcherType>
     42   testing::AssertionResult match(const std::string &Code,
     43                                  const MatcherType &AMatcher,
     44                                  Language L) {
     45     std::vector<std::string> Args;
     46     return match(Code, AMatcher, Args, L);
     47   }
     48 
     49   template <typename MatcherType>
     50   testing::AssertionResult match(const std::string &Code,
     51                                  const MatcherType &AMatcher,
     52                                  std::vector<std::string>& Args,
     53                                  Language L);
     54 
     55 protected:
     56   virtual void run(const MatchFinder::MatchResult &Result);
     57   virtual void verify(const MatchFinder::MatchResult &Result,
     58                       const NodeType &Node) {}
     59 
     60   void setFailure(const Twine &Result) {
     61     Verified = false;
     62     VerifyResult = Result.str();
     63   }
     64 
     65   void setSuccess() {
     66     Verified = true;
     67   }
     68 
     69 private:
     70   bool Verified;
     71   std::string VerifyResult;
     72 };
     73 
     74 /// \brief Runs a matcher over some code, and returns the result of the
     75 /// verifier for the matched node.
     76 template <typename NodeType> template <typename MatcherType>
     77 testing::AssertionResult MatchVerifier<NodeType>::match(
     78     const std::string &Code, const MatcherType &AMatcher,
     79     std::vector<std::string>& Args, Language L) {
     80   MatchFinder Finder;
     81   Finder.addMatcher(AMatcher.bind(""), this);
     82   OwningPtr<tooling::FrontendActionFactory> Factory(
     83       tooling::newFrontendActionFactory(&Finder));
     84 
     85   StringRef FileName;
     86   switch (L) {
     87   case Lang_C:
     88     Args.push_back("-std=c99");
     89     FileName = "input.c";
     90     break;
     91   case Lang_C89:
     92     Args.push_back("-std=c89");
     93     FileName = "input.c";
     94     break;
     95   case Lang_CXX:
     96     Args.push_back("-std=c++98");
     97     FileName = "input.cc";
     98     break;
     99   case Lang_CXX11:
    100     Args.push_back("-std=c++11");
    101     FileName = "input.cc";
    102     break;
    103   case Lang_OpenCL:
    104     FileName = "input.cl";
    105   }
    106 
    107   // Default to failure in case callback is never called
    108   setFailure("Could not find match");
    109   if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName))
    110     return testing::AssertionFailure() << "Parsing error";
    111   if (!Verified)
    112     return testing::AssertionFailure() << VerifyResult;
    113   return testing::AssertionSuccess();
    114 }
    115 
    116 template <typename NodeType>
    117 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) {
    118   const NodeType *Node = Result.Nodes.getNodeAs<NodeType>("");
    119   if (!Node) {
    120     setFailure("Matched node has wrong type");
    121   } else {
    122     // Callback has been called, default to success.
    123     setSuccess();
    124     verify(Result, *Node);
    125   }
    126 }
    127 
    128 /// \brief Verify whether a node has the correct source location.
    129 ///
    130 /// By default, Node.getSourceLocation() is checked. This can be changed
    131 /// by overriding getLocation().
    132 template <typename NodeType>
    133 class LocationVerifier : public MatchVerifier<NodeType> {
    134 public:
    135   void expectLocation(unsigned Line, unsigned Column) {
    136     ExpectLine = Line;
    137     ExpectColumn = Column;
    138   }
    139 
    140 protected:
    141   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
    142     SourceLocation Loc = getLocation(Node);
    143     unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc);
    144     unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc);
    145     if (Line != ExpectLine || Column != ExpectColumn) {
    146       std::string MsgStr;
    147       llvm::raw_string_ostream Msg(MsgStr);
    148       Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn
    149           << ">, found <";
    150       Loc.print(Msg, *Result.SourceManager);
    151       Msg << '>';
    152       this->setFailure(Msg.str());
    153     }
    154   }
    155 
    156   virtual SourceLocation getLocation(const NodeType &Node) {
    157     return Node.getLocation();
    158   }
    159 
    160 private:
    161   unsigned ExpectLine, ExpectColumn;
    162 };
    163 
    164 /// \brief Verify whether a node has the correct source range.
    165 ///
    166 /// By default, Node.getSourceRange() is checked. This can be changed
    167 /// by overriding getRange().
    168 template <typename NodeType>
    169 class RangeVerifier : public MatchVerifier<NodeType> {
    170 public:
    171   void expectRange(unsigned BeginLine, unsigned BeginColumn,
    172                    unsigned EndLine, unsigned EndColumn) {
    173     ExpectBeginLine = BeginLine;
    174     ExpectBeginColumn = BeginColumn;
    175     ExpectEndLine = EndLine;
    176     ExpectEndColumn = EndColumn;
    177   }
    178 
    179 protected:
    180   void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) {
    181     SourceRange R = getRange(Node);
    182     SourceLocation Begin = R.getBegin();
    183     SourceLocation End = R.getEnd();
    184     unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin);
    185     unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin);
    186     unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End);
    187     unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End);
    188     if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn ||
    189         EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) {
    190       std::string MsgStr;
    191       llvm::raw_string_ostream Msg(MsgStr);
    192       Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn
    193           << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <";
    194       Begin.print(Msg, *Result.SourceManager);
    195       Msg << '-';
    196       End.print(Msg, *Result.SourceManager);
    197       Msg << '>';
    198       this->setFailure(Msg.str());
    199     }
    200   }
    201 
    202   virtual SourceRange getRange(const NodeType &Node) {
    203     return Node.getSourceRange();
    204   }
    205 
    206 private:
    207   unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn;
    208 };
    209 
    210 } // end namespace ast_matchers
    211 } // end namespace clang
    212