1 //===- unittest/AST/MatchVerifier.h - AST unit test support ---------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // Provides MatchVerifier, a base class to implement gtest matchers that 11 // verify things that can be matched on the AST. 12 // 13 // Also implements matchers based on MatchVerifier: 14 // LocationVerifier and RangeVerifier to verify whether a matched node has 15 // the expected source location or source range. 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "clang/AST/ASTContext.h" 20 #include "clang/ASTMatchers/ASTMatchFinder.h" 21 #include "clang/ASTMatchers/ASTMatchers.h" 22 #include "clang/Tooling/Tooling.h" 23 #include "gtest/gtest.h" 24 25 namespace clang { 26 namespace ast_matchers { 27 28 enum Language { Lang_C, Lang_C89, Lang_CXX, Lang_CXX11, Lang_OpenCL }; 29 30 /// \brief Base class for verifying some property of nodes found by a matcher. 31 template <typename NodeType> 32 class MatchVerifier : public MatchFinder::MatchCallback { 33 public: 34 template <typename MatcherType> 35 testing::AssertionResult match(const std::string &Code, 36 const MatcherType &AMatcher) { 37 std::vector<std::string> Args; 38 return match(Code, AMatcher, Args, Lang_CXX); 39 } 40 41 template <typename MatcherType> 42 testing::AssertionResult match(const std::string &Code, 43 const MatcherType &AMatcher, 44 Language L) { 45 std::vector<std::string> Args; 46 return match(Code, AMatcher, Args, L); 47 } 48 49 template <typename MatcherType> 50 testing::AssertionResult match(const std::string &Code, 51 const MatcherType &AMatcher, 52 std::vector<std::string>& Args, 53 Language L); 54 55 protected: 56 virtual void run(const MatchFinder::MatchResult &Result); 57 virtual void verify(const MatchFinder::MatchResult &Result, 58 const NodeType &Node) {} 59 60 void setFailure(const Twine &Result) { 61 Verified = false; 62 VerifyResult = Result.str(); 63 } 64 65 void setSuccess() { 66 Verified = true; 67 } 68 69 private: 70 bool Verified; 71 std::string VerifyResult; 72 }; 73 74 /// \brief Runs a matcher over some code, and returns the result of the 75 /// verifier for the matched node. 76 template <typename NodeType> template <typename MatcherType> 77 testing::AssertionResult MatchVerifier<NodeType>::match( 78 const std::string &Code, const MatcherType &AMatcher, 79 std::vector<std::string>& Args, Language L) { 80 MatchFinder Finder; 81 Finder.addMatcher(AMatcher.bind(""), this); 82 OwningPtr<tooling::FrontendActionFactory> Factory( 83 tooling::newFrontendActionFactory(&Finder)); 84 85 StringRef FileName; 86 switch (L) { 87 case Lang_C: 88 Args.push_back("-std=c99"); 89 FileName = "input.c"; 90 break; 91 case Lang_C89: 92 Args.push_back("-std=c89"); 93 FileName = "input.c"; 94 break; 95 case Lang_CXX: 96 Args.push_back("-std=c++98"); 97 FileName = "input.cc"; 98 break; 99 case Lang_CXX11: 100 Args.push_back("-std=c++11"); 101 FileName = "input.cc"; 102 break; 103 case Lang_OpenCL: 104 FileName = "input.cl"; 105 } 106 107 // Default to failure in case callback is never called 108 setFailure("Could not find match"); 109 if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName)) 110 return testing::AssertionFailure() << "Parsing error"; 111 if (!Verified) 112 return testing::AssertionFailure() << VerifyResult; 113 return testing::AssertionSuccess(); 114 } 115 116 template <typename NodeType> 117 void MatchVerifier<NodeType>::run(const MatchFinder::MatchResult &Result) { 118 const NodeType *Node = Result.Nodes.getNodeAs<NodeType>(""); 119 if (!Node) { 120 setFailure("Matched node has wrong type"); 121 } else { 122 // Callback has been called, default to success. 123 setSuccess(); 124 verify(Result, *Node); 125 } 126 } 127 128 /// \brief Verify whether a node has the correct source location. 129 /// 130 /// By default, Node.getSourceLocation() is checked. This can be changed 131 /// by overriding getLocation(). 132 template <typename NodeType> 133 class LocationVerifier : public MatchVerifier<NodeType> { 134 public: 135 void expectLocation(unsigned Line, unsigned Column) { 136 ExpectLine = Line; 137 ExpectColumn = Column; 138 } 139 140 protected: 141 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) { 142 SourceLocation Loc = getLocation(Node); 143 unsigned Line = Result.SourceManager->getSpellingLineNumber(Loc); 144 unsigned Column = Result.SourceManager->getSpellingColumnNumber(Loc); 145 if (Line != ExpectLine || Column != ExpectColumn) { 146 std::string MsgStr; 147 llvm::raw_string_ostream Msg(MsgStr); 148 Msg << "Expected location <" << ExpectLine << ":" << ExpectColumn 149 << ">, found <"; 150 Loc.print(Msg, *Result.SourceManager); 151 Msg << '>'; 152 this->setFailure(Msg.str()); 153 } 154 } 155 156 virtual SourceLocation getLocation(const NodeType &Node) { 157 return Node.getLocation(); 158 } 159 160 private: 161 unsigned ExpectLine, ExpectColumn; 162 }; 163 164 /// \brief Verify whether a node has the correct source range. 165 /// 166 /// By default, Node.getSourceRange() is checked. This can be changed 167 /// by overriding getRange(). 168 template <typename NodeType> 169 class RangeVerifier : public MatchVerifier<NodeType> { 170 public: 171 void expectRange(unsigned BeginLine, unsigned BeginColumn, 172 unsigned EndLine, unsigned EndColumn) { 173 ExpectBeginLine = BeginLine; 174 ExpectBeginColumn = BeginColumn; 175 ExpectEndLine = EndLine; 176 ExpectEndColumn = EndColumn; 177 } 178 179 protected: 180 void verify(const MatchFinder::MatchResult &Result, const NodeType &Node) { 181 SourceRange R = getRange(Node); 182 SourceLocation Begin = R.getBegin(); 183 SourceLocation End = R.getEnd(); 184 unsigned BeginLine = Result.SourceManager->getSpellingLineNumber(Begin); 185 unsigned BeginColumn = Result.SourceManager->getSpellingColumnNumber(Begin); 186 unsigned EndLine = Result.SourceManager->getSpellingLineNumber(End); 187 unsigned EndColumn = Result.SourceManager->getSpellingColumnNumber(End); 188 if (BeginLine != ExpectBeginLine || BeginColumn != ExpectBeginColumn || 189 EndLine != ExpectEndLine || EndColumn != ExpectEndColumn) { 190 std::string MsgStr; 191 llvm::raw_string_ostream Msg(MsgStr); 192 Msg << "Expected range <" << ExpectBeginLine << ":" << ExpectBeginColumn 193 << '-' << ExpectEndLine << ":" << ExpectEndColumn << ">, found <"; 194 Begin.print(Msg, *Result.SourceManager); 195 Msg << '-'; 196 End.print(Msg, *Result.SourceManager); 197 Msg << '>'; 198 this->setFailure(Msg.str()); 199 } 200 } 201 202 virtual SourceRange getRange(const NodeType &Node) { 203 return Node.getSourceRange(); 204 } 205 206 private: 207 unsigned ExpectBeginLine, ExpectBeginColumn, ExpectEndLine, ExpectEndColumn; 208 }; 209 210 } // end namespace ast_matchers 211 } // end namespace clang 212