1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Provides MatchVerifier, a base class to implement gtest matchers that 11 // verify things that can be matched on the AST. 12 // 13 // Also implements matchers based on MatchVerifier: 14 // LocationVerifier and RangeVerifier to verify whether a matched node has 15 // the expected source location or source range. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "clang/AST/ASTContext.h" 20 #include "clang/ASTMatchers/ASTMatchFinder.h" 21 #include "clang/ASTMatchers/ASTMatchers.h" 22 #include "clang/Tooling/Tooling.h" 23 #include "gtest/gtest.h" 24 25 namespace clang { 26 namespace ast_matchers { 27 28 enum Language { Lang_C, Lang_C89, Lang_CXX, Lang_OpenCL }; 29 30 /// \brief Base class for verifying some property of nodes found by a matcher. 31 template <typename NodeType> 32 class MatchVerifier : public MatchFinder::MatchCallback { 33 public: 34 template <typename MatcherType> 35 testing::AssertionResult match(const std::string &Code, 36 const MatcherType &AMatcher) { 37 return match(Code, AMatcher, Lang_CXX); 38 } 39 40 template <typename MatcherType> 41 testing::AssertionResult match(const std::string &Code, 42 const MatcherType &AMatcher, Language L); 43 44 protected: 45 virtual void run(const MatchFinder::MatchResult &Result); 46 virtual void verify(const MatchFinder::MatchResult &Result, 47 const NodeType &Node) {} 48 49 void setFailure(const Twine &Result) { 50 Verified = false; 51 VerifyResult = Result.str(); 52 } 53 54 void setSuccess() { 55 Verified = true; 56 } 57 58 private: 59 bool Verified; 60 std::string VerifyResult; 61 }; 62 63 /// \brief Runs a matcher over some code, and returns the result of the 64 /// verifier for the matched node. 65 template <typename NodeType> template <typename MatcherType> 66 testing::AssertionResult MatchVerifier<NodeType>::match( 67 const std::string &Code, const MatcherType &AMatcher, Language L) { 68 MatchFinder Finder; 69 Finder.addMatcher(AMatcher.bind(""), this); 70 OwningPtr<tooling::FrontendActionFactory> Factory( 71 tooling::newFrontendActionFactory(&Finder)); 72 73 std::vector<std::string> Args; 74 StringRef FileName; 75 switch (L) { 76 case Lang_C: 77 Args.push_back("-std=c99"); 78 FileName = "input.c"; 79 break; 80 case Lang_C89: 81 Args.push_back("-std=c89"); 82 FileName = "input.c"; 83 break; 84 case Lang_CXX: 85 Args.push_back("-std=c++98"); 86 FileName = "input.cc"; 87 break; 88 case Lang_OpenCL: 89 FileName = "input.cl"; 90 } 91 92 // Default to failure in case callback is never called 93 setFailure("Could not find match"); 94 if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName)) 95 return testing::AssertionFailure() << "Parsing error"; 96 if (!Verified) 97 return testing::AssertionFailure() << VerifyResult; 98 return testing::AssertionSuccess(); 99 } 100 101 template <typename NodeType> 102 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) { 103 const NodeType *Node = Result.Nodes.getNodeAs<NodeType>(""); 104 if (!Node) { 105 setFailure("Matched node has wrong type"); 106 } else { 107 // Callback has been called, default to success. 108 setSuccess(); 109 verify(Result, *Node); 110 } 111 } 112 113 /// \brief Verify whether a node has the correct source location. 114 /// 115 /// By default, Node.getSourceLocation() is checked. This can be changed 116 /// by overriding getLocation(). 117 template <typename NodeType> 118 class LocationVerifier : public MatchVerifier<NodeType> { 119 public: 120 void expectLocation(unsigned Line, unsigned Column) { 121 ExpectLine = Line; 122 ExpectColumn = Column; 123 } 124 125 protected: 126 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) { 127 SourceLocation Loc = getLocation(Node); 128 unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc); 129 unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc); 130 if (Line != ExpectLine || Column != ExpectColumn) { 131 std::string MsgStr; 132 llvm::raw_string_ostream Msg(MsgStr); 133 Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn 134 << ">, found <"; 135 Loc.print(Msg, *Result.SourceManager); 136 Msg << '>'; 137 this->setFailure(Msg.str()); 138 } 139 } 140 141 virtual SourceLocation getLocation(const NodeType &Node) { 142 return Node.getLocation(); 143 } 144 145 private: 146 unsigned ExpectLine, ExpectColumn; 147 }; 148 149 /// \brief Verify whether a node has the correct source range. 150 /// 151 /// By default, Node.getSourceRange() is checked. This can be changed 152 /// by overriding getRange(). 153 template <typename NodeType> 154 class RangeVerifier : public MatchVerifier<NodeType> { 155 public: 156 void expectRange(unsigned BeginLine, unsigned BeginColumn, 157 unsigned EndLine, unsigned EndColumn) { 158 ExpectBeginLine = BeginLine; 159 ExpectBeginColumn = BeginColumn; 160 ExpectEndLine = EndLine; 161 ExpectEndColumn = EndColumn; 162 } 163 164 protected: 165 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) { 166 SourceRange R = getRange(Node); 167 SourceLocation Begin = R.getBegin(); 168 SourceLocation End = R.getEnd(); 169 unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin); 170 unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin); 171 unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End); 172 unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End); 173 if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn || 174 EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) { 175 std::string MsgStr; 176 llvm::raw_string_ostream Msg(MsgStr); 177 Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn 178 << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <"; 179 Begin.print(Msg, *Result.SourceManager); 180 Msg << '-'; 181 End.print(Msg, *Result.SourceManager); 182 Msg << '>'; 183 this->setFailure(Msg.str()); 184 } 185 } 186 187 virtual SourceRange getRange(const NodeType &Node) { 188 return Node.getSourceRange(); 189 } 190 191 private: 192 unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn; 193 }; 194 195 } // end namespace ast_matchers 196 } // end namespace clang 197