Home | History | Annotate | Download | only in ASTMatchers
//===- unittests/ASTMatchers/ASTMatchersTest.h - Matcher tests helpers ----===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #ifndef LLVM_CLANG_UNITTESTS_ASTMATCHERS_ASTMATCHERSTEST_H
     11 #define LLVM_CLANG_UNITTESTS_ASTMATCHERS_ASTMATCHERSTEST_H
     12 
     13 #include "clang/ASTMatchers/ASTMatchFinder.h"
     14 #include "clang/Frontend/ASTUnit.h"
     15 #include "clang/Tooling/Tooling.h"
     16 #include "gtest/gtest.h"
     17 
     18 namespace clang {
     19 namespace ast_matchers {
     20 
     21 using clang::tooling::buildASTFromCodeWithArgs;
     22 using clang::tooling::newFrontendActionFactory;
     23 using clang::tooling::runToolOnCodeWithArgs;
     24 using clang::tooling::FrontendActionFactory;
     25 using clang::tooling::FileContentMappings;
     26 
     27 class BoundNodesCallback {
     28 public:
     29   virtual ~BoundNodesCallback() {}
     30   virtual bool run(const BoundNodes *BoundNodes) = 0;
     31   virtual bool run(const BoundNodes *BoundNodes, ASTContext *Context) = 0;
     32   virtual void onEndOfTranslationUnit() {}
     33 };
     34 
     35 // If 'FindResultVerifier' is not NULL, sets *Verified to the result of
     36 // running 'FindResultVerifier' with the bound nodes as argument.
     37 // If 'FindResultVerifier' is NULL, sets *Verified to true when Run is called.
     38 class VerifyMatch : public MatchFinder::MatchCallback {
     39 public:
     40   VerifyMatch(BoundNodesCallback *FindResultVerifier, bool *Verified)
     41       : Verified(Verified), FindResultReviewer(FindResultVerifier) {}
     42 
     43   void run(const MatchFinder::MatchResult &Result) override {
     44     if (FindResultReviewer != nullptr) {
     45       *Verified |= FindResultReviewer->run(&Result.Nodes, Result.Context);
     46     } else {
     47       *Verified = true;
     48     }
     49   }
     50 
     51   void onEndOfTranslationUnit() override {
     52     if (FindResultReviewer)
     53       FindResultReviewer->onEndOfTranslationUnit();
     54   }
     55 
     56 private:
     57   bool *const Verified;
     58   BoundNodesCallback *const FindResultReviewer;
     59 };
     60 
     61 template <typename T>
     62 testing::AssertionResult matchesConditionally(
     63     const std::string &Code, const T &AMatcher, bool ExpectMatch,
     64     llvm::StringRef CompileArg,
     65     const FileContentMappings &VirtualMappedFiles = FileContentMappings(),
     66     const std::string &Filename = "input.cc") {
     67   bool Found = false, DynamicFound = false;
     68   MatchFinder Finder;
     69   VerifyMatch VerifyFound(nullptr, &Found);
     70   Finder.addMatcher(AMatcher, &VerifyFound);
     71   VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound);
     72   if (!Finder.addDynamicMatcher(AMatcher, &VerifyDynamicFound))
     73     return testing::AssertionFailure() << "Could not add dynamic matcher";
     74   std::unique_ptr<FrontendActionFactory> Factory(
     75       newFrontendActionFactory(&Finder));
     76   // Some tests use typeof, which is a gnu extension.
     77   std::vector<std::string> Args;
     78   Args.push_back(CompileArg);
     79   // Some tests need rtti/exceptions on
     80   Args.push_back("-frtti");
     81   Args.push_back("-fexceptions");
     82   if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, Filename,
     83                              VirtualMappedFiles)) {
     84     return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
     85   }
     86   if (Found != DynamicFound) {
     87     return testing::AssertionFailure() << "Dynamic match result ("
     88                                        << DynamicFound
     89                                        << ") does not match static result ("
     90                                        << Found << ")";
     91   }
     92   if (!Found && ExpectMatch) {
     93     return testing::AssertionFailure()
     94       << "Could not find match in \"" << Code << "\"";
     95   } else if (Found && !ExpectMatch) {
     96     return testing::AssertionFailure()
     97       << "Found unexpected match in \"" << Code << "\"";
     98   }
     99   return testing::AssertionSuccess();
    100 }
    101 
    102 template <typename T>
    103 testing::AssertionResult matches(const std::string &Code, const T &AMatcher) {
    104   return matchesConditionally(Code, AMatcher, true, "-std=c++11");
    105 }
    106 
    107 template <typename T>
    108 testing::AssertionResult notMatches(const std::string &Code,
    109                                     const T &AMatcher) {
    110   return matchesConditionally(Code, AMatcher, false, "-std=c++11");
    111 }
    112 
    113 template <typename T>
    114 testing::AssertionResult matchesObjC(const std::string &Code,
    115                                      const T &AMatcher) {
    116   return matchesConditionally(
    117     Code, AMatcher, true,
    118     "", FileContentMappings(), "input.m");
    119 }
    120 
    121 template <typename T>
    122 testing::AssertionResult notMatchesObjC(const std::string &Code,
    123                                      const T &AMatcher) {
    124   return matchesConditionally(
    125     Code, AMatcher, false,
    126     "", FileContentMappings(), "input.m");
    127 }
    128 
    129 
    130 // Function based on matchesConditionally with "-x cuda" argument added and
    131 // small CUDA header prepended to the code string.
    132 template <typename T>
    133 testing::AssertionResult matchesConditionallyWithCuda(
    134     const std::string &Code, const T &AMatcher, bool ExpectMatch,
    135     llvm::StringRef CompileArg) {
    136   const std::string CudaHeader =
    137       "typedef unsigned int size_t;\n"
    138       "#define __constant__ __attribute__((constant))\n"
    139       "#define __device__ __attribute__((device))\n"
    140       "#define __global__ __attribute__((global))\n"
    141       "#define __host__ __attribute__((host))\n"
    142       "#define __shared__ __attribute__((shared))\n"
    143       "struct dim3 {"
    144       "  unsigned x, y, z;"
    145       "  __host__ __device__ dim3(unsigned x, unsigned y = 1, unsigned z = 1)"
    146       "      : x(x), y(y), z(z) {}"
    147       "};"
    148       "typedef struct cudaStream *cudaStream_t;"
    149       "int cudaConfigureCall(dim3 gridSize, dim3 blockSize,"
    150       "                      size_t sharedSize = 0,"
    151       "                      cudaStream_t stream = 0);";
    152 
    153   bool Found = false, DynamicFound = false;
    154   MatchFinder Finder;
    155   VerifyMatch VerifyFound(nullptr, &Found);
    156   Finder.addMatcher(AMatcher, &VerifyFound);
    157   VerifyMatch VerifyDynamicFound(nullptr, &DynamicFound);
    158   if (!Finder.addDynamicMatcher(AMatcher, &VerifyDynamicFound))
    159     return testing::AssertionFailure() << "Could not add dynamic matcher";
    160   std::unique_ptr<FrontendActionFactory> Factory(
    161       newFrontendActionFactory(&Finder));
    162   // Some tests use typeof, which is a gnu extension.
    163   std::vector<std::string> Args;
    164   Args.push_back("-xcuda");
    165   Args.push_back("-fno-ms-extensions");
    166   Args.push_back(CompileArg);
    167   if (!runToolOnCodeWithArgs(Factory->create(),
    168                              CudaHeader + Code, Args)) {
    169     return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
    170   }
    171   if (Found != DynamicFound) {
    172     return testing::AssertionFailure() << "Dynamic match result ("
    173                                        << DynamicFound
    174                                        << ") does not match static result ("
    175                                        << Found << ")";
    176   }
    177   if (!Found && ExpectMatch) {
    178     return testing::AssertionFailure()
    179       << "Could not find match in \"" << Code << "\"";
    180   } else if (Found && !ExpectMatch) {
    181     return testing::AssertionFailure()
    182       << "Found unexpected match in \"" << Code << "\"";
    183   }
    184   return testing::AssertionSuccess();
    185 }
    186 
    187 template <typename T>
    188 testing::AssertionResult matchesWithCuda(const std::string &Code,
    189                                          const T &AMatcher) {
    190   return matchesConditionallyWithCuda(Code, AMatcher, true, "-std=c++11");
    191 }
    192 
    193 template <typename T>
    194 testing::AssertionResult notMatchesWithCuda(const std::string &Code,
    195                                     const T &AMatcher) {
    196   return matchesConditionallyWithCuda(Code, AMatcher, false, "-std=c++11");
    197 }
    198 
    199 template <typename T>
    200 testing::AssertionResult
    201 matchAndVerifyResultConditionally(const std::string &Code, const T &AMatcher,
    202                                   BoundNodesCallback *FindResultVerifier,
    203                                   bool ExpectResult) {
    204   std::unique_ptr<BoundNodesCallback> ScopedVerifier(FindResultVerifier);
    205   bool VerifiedResult = false;
    206   MatchFinder Finder;
    207   VerifyMatch VerifyVerifiedResult(FindResultVerifier, &VerifiedResult);
    208   Finder.addMatcher(AMatcher, &VerifyVerifiedResult);
    209   std::unique_ptr<FrontendActionFactory> Factory(
    210       newFrontendActionFactory(&Finder));
    211   // Some tests use typeof, which is a gnu extension.
    212   std::vector<std::string> Args(1, "-std=gnu++98");
    213   if (!runToolOnCodeWithArgs(Factory->create(), Code, Args)) {
    214     return testing::AssertionFailure() << "Parsing error in \"" << Code << "\"";
    215   }
    216   if (!VerifiedResult && ExpectResult) {
    217     return testing::AssertionFailure()
    218       << "Could not verify result in \"" << Code << "\"";
    219   } else if (VerifiedResult && !ExpectResult) {
    220     return testing::AssertionFailure()
    221       << "Verified unexpected result in \"" << Code << "\"";
    222   }
    223 
    224   VerifiedResult = false;
    225   std::unique_ptr<ASTUnit> AST(buildASTFromCodeWithArgs(Code, Args));
    226   if (!AST.get())
    227     return testing::AssertionFailure() << "Parsing error in \"" << Code
    228                                        << "\" while building AST";
    229   Finder.matchAST(AST->getASTContext());
    230   if (!VerifiedResult && ExpectResult) {
    231     return testing::AssertionFailure()
    232       << "Could not verify result in \"" << Code << "\" with AST";
    233   } else if (VerifiedResult && !ExpectResult) {
    234     return testing::AssertionFailure()
    235       << "Verified unexpected result in \"" << Code << "\" with AST";
    236   }
    237 
    238   return testing::AssertionSuccess();
    239 }
    240 
    241 // FIXME: Find better names for these functions (or document what they
    242 // do more precisely).
    243 template <typename T>
    244 testing::AssertionResult
    245 matchAndVerifyResultTrue(const std::string &Code, const T &AMatcher,
    246                          BoundNodesCallback *FindResultVerifier) {
    247   return matchAndVerifyResultConditionally(
    248       Code, AMatcher, FindResultVerifier, true);
    249 }
    250 
    251 template <typename T>
    252 testing::AssertionResult
    253 matchAndVerifyResultFalse(const std::string &Code, const T &AMatcher,
    254                           BoundNodesCallback *FindResultVerifier) {
    255   return matchAndVerifyResultConditionally(
    256       Code, AMatcher, FindResultVerifier, false);
    257 }
    258 
    259 } // end namespace ast_matchers
    260 } // end namespace clang
    261 
#endif  // LLVM_CLANG_UNITTESTS_ASTMATCHERS_ASTMATCHERSTEST_H
    263