Home | History | Annotate | Download | only in Analysis
      1 //===- ScalarEvolutionsTest.cpp - ScalarEvolution unit tests --------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
     11 #include "llvm/Analysis/AssumptionCache.h"
     12 #include "llvm/Analysis/LoopInfo.h"
     13 #include "llvm/Analysis/TargetLibraryInfo.h"
     14 #include "llvm/ADT/SmallVector.h"
     15 #include "llvm/Analysis/LoopInfo.h"
     16 #include "llvm/IR/Constants.h"
     17 #include "llvm/IR/Dominators.h"
     18 #include "llvm/IR/GlobalVariable.h"
     19 #include "llvm/IR/LLVMContext.h"
     20 #include "llvm/IR/Module.h"
     21 #include "llvm/IR/LegacyPassManager.h"
     22 #include "gtest/gtest.h"
     23 
     24 namespace llvm {
     25 namespace {
     26 
     27 // We use this fixture to ensure that we clean up ScalarEvolution before
     28 // deleting the PassManager.
     29 class ScalarEvolutionsTest : public testing::Test {
     30 protected:
     31   LLVMContext Context;
     32   Module M;
     33   TargetLibraryInfoImpl TLII;
     34   TargetLibraryInfo TLI;
     35 
     36   std::unique_ptr<AssumptionCache> AC;
     37   std::unique_ptr<DominatorTree> DT;
     38   std::unique_ptr<LoopInfo> LI;
     39 
     40   ScalarEvolutionsTest() : M("", Context), TLII(), TLI(TLII) {}
     41 
     42   ScalarEvolution buildSE(Function &F) {
     43     AC.reset(new AssumptionCache(F));
     44     DT.reset(new DominatorTree(F));
     45     LI.reset(new LoopInfo(*DT));
     46     return ScalarEvolution(F, TLI, *AC, *DT, *LI);
     47   }
     48 };
     49 
     50 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
     51   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context),
     52                                               std::vector<Type *>(), false);
     53   Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
     54   BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
     55   ReturnInst::Create(Context, nullptr, BB);
     56 
     57   Type *Ty = Type::getInt1Ty(Context);
     58   Constant *Init = Constant::getNullValue(Ty);
     59   Value *V0 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V0");
     60   Value *V1 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V1");
     61   Value *V2 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V2");
     62 
     63   ScalarEvolution SE = buildSE(*F);
     64 
     65   const SCEV *S0 = SE.getSCEV(V0);
     66   const SCEV *S1 = SE.getSCEV(V1);
     67   const SCEV *S2 = SE.getSCEV(V2);
     68 
     69   const SCEV *P0 = SE.getAddExpr(S0, S0);
     70   const SCEV *P1 = SE.getAddExpr(S1, S1);
     71   const SCEV *P2 = SE.getAddExpr(S2, S2);
     72 
     73   const SCEVMulExpr *M0 = cast<SCEVMulExpr>(P0);
     74   const SCEVMulExpr *M1 = cast<SCEVMulExpr>(P1);
     75   const SCEVMulExpr *M2 = cast<SCEVMulExpr>(P2);
     76 
     77   EXPECT_EQ(cast<SCEVConstant>(M0->getOperand(0))->getValue()->getZExtValue(),
     78             2u);
     79   EXPECT_EQ(cast<SCEVConstant>(M1->getOperand(0))->getValue()->getZExtValue(),
     80             2u);
     81   EXPECT_EQ(cast<SCEVConstant>(M2->getOperand(0))->getValue()->getZExtValue(),
     82             2u);
     83 
     84   // Before the RAUWs, these are all pointing to separate values.
     85   EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
     86   EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V1);
     87   EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V2);
     88 
     89   // Do some RAUWs.
     90   V2->replaceAllUsesWith(V1);
     91   V1->replaceAllUsesWith(V0);
     92 
     93   // After the RAUWs, these should all be pointing to V0.
     94   EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
     95   EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V0);
     96   EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V0);
     97 }
     98 
     99 TEST_F(ScalarEvolutionsTest, SCEVMultiplyAddRecs) {
    100   Type *Ty = Type::getInt32Ty(Context);
    101   SmallVector<Type *, 10> Types;
    102   Types.append(10, Ty);
    103   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
    104   Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
    105   BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
    106   ReturnInst::Create(Context, nullptr, BB);
    107 
    108   ScalarEvolution SE = buildSE(*F);
    109 
    110   // It's possible to produce an empty loop through the default constructor,
    111   // but you can't add any blocks to it without a LoopInfo pass.
    112   Loop L;
    113   const_cast<std::vector<BasicBlock*>&>(L.getBlocks()).push_back(BB);
    114 
    115   Function::arg_iterator AI = F->arg_begin();
    116   SmallVector<const SCEV *, 5> A;
    117   A.push_back(SE.getSCEV(&*AI++));
    118   A.push_back(SE.getSCEV(&*AI++));
    119   A.push_back(SE.getSCEV(&*AI++));
    120   A.push_back(SE.getSCEV(&*AI++));
    121   A.push_back(SE.getSCEV(&*AI++));
    122   const SCEV *A_rec = SE.getAddRecExpr(A, &L, SCEV::FlagAnyWrap);
    123 
    124   SmallVector<const SCEV *, 5> B;
    125   B.push_back(SE.getSCEV(&*AI++));
    126   B.push_back(SE.getSCEV(&*AI++));
    127   B.push_back(SE.getSCEV(&*AI++));
    128   B.push_back(SE.getSCEV(&*AI++));
    129   B.push_back(SE.getSCEV(&*AI++));
    130   const SCEV *B_rec = SE.getAddRecExpr(B, &L, SCEV::FlagAnyWrap);
    131 
    132   /* Spot check that we perform this transformation:
    133      {A0,+,A1,+,A2,+,A3,+,A4} * {B0,+,B1,+,B2,+,B3,+,B4} =
    134      {A0*B0,+,
    135       A1*B0 + A0*B1 + A1*B1,+,
    136       A2*B0 + 2A1*B1 + A0*B2 + 2A2*B1 + 2A1*B2 + A2*B2,+,
    137       A3*B0 + 3A2*B1 + 3A1*B2 + A0*B3 + 3A3*B1 + 6A2*B2 + 3A1*B3 + 3A3*B2 +
    138         3A2*B3 + A3*B3,+,
    139       A4*B0 + 4A3*B1 + 6A2*B2 + 4A1*B3 + A0*B4 + 4A4*B1 + 12A3*B2 + 12A2*B3 +
    140         4A1*B4 + 6A4*B2 + 12A3*B3 + 6A2*B4 + 4A4*B3 + 4A3*B4 + A4*B4,+,
    141       5A4*B1 + 10A3*B2 + 10A2*B3 + 5A1*B4 + 20A4*B2 + 30A3*B3 + 20A2*B4 +
    142         30A4*B3 + 30A3*B4 + 20A4*B4,+,
    143       15A4*B2 + 20A3*B3 + 15A2*B4 + 60A4*B3 + 60A3*B4 + 90A4*B4,+,
    144       35A4*B3 + 35A3*B4 + 140A4*B4,+,
    145       70A4*B4}
    146   */
    147 
    148   const SCEVAddRecExpr *Product =
    149       dyn_cast<SCEVAddRecExpr>(SE.getMulExpr(A_rec, B_rec));
    150   ASSERT_TRUE(Product);
    151   ASSERT_EQ(Product->getNumOperands(), 9u);
    152 
    153   SmallVector<const SCEV *, 16> Sum;
    154   Sum.push_back(SE.getMulExpr(A[0], B[0]));
    155   EXPECT_EQ(Product->getOperand(0), SE.getAddExpr(Sum));
    156   Sum.clear();
    157 
    158   // SCEV produces different an equal but different expression for these.
    159   // Re-enable when PR11052 is fixed.
    160 #if 0
    161   Sum.push_back(SE.getMulExpr(A[1], B[0]));
    162   Sum.push_back(SE.getMulExpr(A[0], B[1]));
    163   Sum.push_back(SE.getMulExpr(A[1], B[1]));
    164   EXPECT_EQ(Product->getOperand(1), SE.getAddExpr(Sum));
    165   Sum.clear();
    166 
    167   Sum.push_back(SE.getMulExpr(A[2], B[0]));
    168   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[1]));
    169   Sum.push_back(SE.getMulExpr(A[0], B[2]));
    170   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[2], B[1]));
    171   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[2]));
    172   Sum.push_back(SE.getMulExpr(A[2], B[2]));
    173   EXPECT_EQ(Product->getOperand(2), SE.getAddExpr(Sum));
    174   Sum.clear();
    175 
    176   Sum.push_back(SE.getMulExpr(A[3], B[0]));
    177   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[1]));
    178   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[2]));
    179   Sum.push_back(SE.getMulExpr(A[0], B[3]));
    180   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[1]));
    181   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
    182   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[3]));
    183   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[2]));
    184   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[3]));
    185   Sum.push_back(SE.getMulExpr(A[3], B[3]));
    186   EXPECT_EQ(Product->getOperand(3), SE.getAddExpr(Sum));
    187   Sum.clear();
    188 
    189   Sum.push_back(SE.getMulExpr(A[4], B[0]));
    190   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[1]));
    191   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
    192   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[3]));
    193   Sum.push_back(SE.getMulExpr(A[0], B[4]));
    194   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[1]));
    195   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[2]));
    196   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[2], B[3]));
    197   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[4]));
    198   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[4], B[2]));
    199   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[3]));
    200   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[4]));
    201   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[3]));
    202   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[4]));
    203   Sum.push_back(SE.getMulExpr(A[4], B[4]));
    204   EXPECT_EQ(Product->getOperand(4), SE.getAddExpr(Sum));
    205   Sum.clear();
    206 
    207   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[4], B[1]));
    208   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[3], B[2]));
    209   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[2], B[3]));
    210   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[1], B[4]));
    211   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[2]));
    212   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[3]));
    213   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[2], B[4]));
    214   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[4], B[3]));
    215   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[4]));
    216   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[4]));
    217   EXPECT_EQ(Product->getOperand(5), SE.getAddExpr(Sum));
    218   Sum.clear();
    219 
    220   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[4], B[2]));
    221   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[3], B[3]));
    222   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[2], B[4]));
    223   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[4], B[3]));
    224   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[3], B[4]));
    225   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 90), A[4], B[4]));
    226   EXPECT_EQ(Product->getOperand(6), SE.getAddExpr(Sum));
    227   Sum.clear();
    228 
    229   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[4], B[3]));
    230   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[3], B[4]));
    231   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 140), A[4], B[4]));
    232   EXPECT_EQ(Product->getOperand(7), SE.getAddExpr(Sum));
    233   Sum.clear();
    234 #endif
    235 
    236   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 70), A[4], B[4]));
    237   EXPECT_EQ(Product->getOperand(8), SE.getAddExpr(Sum));
    238 }
    239 
    240 }  // end anonymous namespace
    241 }  // end namespace llvm
    242