Home | History | Annotate | Download | only in Analysis
      1 //===- ScalarEvolutionsTest.cpp - ScalarEvolution unit tests --------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 
     10 #include <llvm/Analysis/ScalarEvolutionExpressions.h>
     11 #include <llvm/Analysis/LoopInfo.h>
     12 #include <llvm/GlobalVariable.h>
     13 #include <llvm/Constants.h>
     14 #include <llvm/LLVMContext.h>
     15 #include <llvm/Module.h>
     16 #include <llvm/PassManager.h>
     17 #include <llvm/ADT/SmallVector.h>
     18 #include "gtest/gtest.h"
     19 
     20 namespace llvm {
     21 namespace {
     22 
// Test fixture shared by the ScalarEvolution tests below.
//
// We use this fixture to ensure that we clean up ScalarEvolution before
// deleting the PassManager: the destructor body calls SE.releaseMemory()
// before any data member (including PM) is destroyed.
//
// Member declaration order matters: Context is declared, and therefore
// constructed, before M because M's initializer takes Context.
//
// NOTE(review): SE is heap-allocated in the constructor and is presumably
// owned by PM only after a test calls PM.add(&SE) (both tests in this file
// do). A test that never adds it would leak the object. Confirm against the
// legacy PassManager ownership rules.
class ScalarEvolutionsTest : public testing::Test {
protected:
  ScalarEvolutionsTest() : M("", Context), SE(*new ScalarEvolution) {}
  ~ScalarEvolutionsTest() {
    // Manually clean up, since we allocated new SCEV objects after the
    // pass was finished.
    SE.releaseMemory();
  }
  LLVMContext Context; // Must be declared before M (used by M's initializer).
  Module M;            // Empty module named ""; each test populates it.
  PassManager PM;      // Tests add SE to this and run it over M.
  ScalarEvolution &SE; // Refers to the object allocated in the constructor.
};
     38 
     39 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
     40   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context),
     41                                               std::vector<Type *>(), false);
     42   Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
     43   BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
     44   ReturnInst::Create(Context, 0, BB);
     45 
     46   Type *Ty = Type::getInt1Ty(Context);
     47   Constant *Init = Constant::getNullValue(Ty);
     48   Value *V0 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V0");
     49   Value *V1 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V1");
     50   Value *V2 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V2");
     51 
     52   // Create a ScalarEvolution and "run" it so that it gets initialized.
     53   PM.add(&SE);
     54   PM.run(M);
     55 
     56   const SCEV *S0 = SE.getSCEV(V0);
     57   const SCEV *S1 = SE.getSCEV(V1);
     58   const SCEV *S2 = SE.getSCEV(V2);
     59 
     60   const SCEV *P0 = SE.getAddExpr(S0, S0);
     61   const SCEV *P1 = SE.getAddExpr(S1, S1);
     62   const SCEV *P2 = SE.getAddExpr(S2, S2);
     63 
     64   const SCEVMulExpr *M0 = cast<SCEVMulExpr>(P0);
     65   const SCEVMulExpr *M1 = cast<SCEVMulExpr>(P1);
     66   const SCEVMulExpr *M2 = cast<SCEVMulExpr>(P2);
     67 
     68   EXPECT_EQ(cast<SCEVConstant>(M0->getOperand(0))->getValue()->getZExtValue(),
     69             2u);
     70   EXPECT_EQ(cast<SCEVConstant>(M1->getOperand(0))->getValue()->getZExtValue(),
     71             2u);
     72   EXPECT_EQ(cast<SCEVConstant>(M2->getOperand(0))->getValue()->getZExtValue(),
     73             2u);
     74 
     75   // Before the RAUWs, these are all pointing to separate values.
     76   EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
     77   EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V1);
     78   EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V2);
     79 
     80   // Do some RAUWs.
     81   V2->replaceAllUsesWith(V1);
     82   V1->replaceAllUsesWith(V0);
     83 
     84   // After the RAUWs, these should all be pointing to V0.
     85   EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
     86   EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V0);
     87   EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V0);
     88 }
     89 
     90 TEST_F(ScalarEvolutionsTest, SCEVMultiplyAddRecs) {
     91   Type *Ty = Type::getInt32Ty(Context);
     92   SmallVector<Type *, 10> Types;
     93   Types.append(10, Ty);
     94   FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
     95   Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
     96   BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
     97   ReturnInst::Create(Context, 0, BB);
     98 
     99   // Create a ScalarEvolution and "run" it so that it gets initialized.
    100   PM.add(&SE);
    101   PM.run(M);
    102 
    103   // It's possible to produce an empty loop through the default constructor,
    104   // but you can't add any blocks to it without a LoopInfo pass.
    105   Loop L;
    106   const_cast<std::vector<BasicBlock*>&>(L.getBlocks()).push_back(BB);
    107 
    108   Function::arg_iterator AI = F->arg_begin();
    109   SmallVector<const SCEV *, 5> A;
    110   A.push_back(SE.getSCEV(&*AI++));
    111   A.push_back(SE.getSCEV(&*AI++));
    112   A.push_back(SE.getSCEV(&*AI++));
    113   A.push_back(SE.getSCEV(&*AI++));
    114   A.push_back(SE.getSCEV(&*AI++));
    115   const SCEV *A_rec = SE.getAddRecExpr(A, &L, SCEV::FlagAnyWrap);
    116 
    117   SmallVector<const SCEV *, 5> B;
    118   B.push_back(SE.getSCEV(&*AI++));
    119   B.push_back(SE.getSCEV(&*AI++));
    120   B.push_back(SE.getSCEV(&*AI++));
    121   B.push_back(SE.getSCEV(&*AI++));
    122   B.push_back(SE.getSCEV(&*AI++));
    123   const SCEV *B_rec = SE.getAddRecExpr(B, &L, SCEV::FlagAnyWrap);
    124 
    125   /* Spot check that we perform this transformation:
    126      {A0,+,A1,+,A2,+,A3,+,A4} * {B0,+,B1,+,B2,+,B3,+,B4} =
    127      {A0*B0,+,
    128       A1*B0 + A0*B1 + A1*B1,+,
    129       A2*B0 + 2A1*B1 + A0*B2 + 2A2*B1 + 2A1*B2 + A2*B2,+,
    130       A3*B0 + 3A2*B1 + 3A1*B2 + A0*B3 + 3A3*B1 + 6A2*B2 + 3A1*B3 + 3A3*B2 +
    131         3A2*B3 + A3*B3,+,
    132       A4*B0 + 4A3*B1 + 6A2*B2 + 4A1*B3 + A0*B4 + 4A4*B1 + 12A3*B2 + 12A2*B3 +
    133         4A1*B4 + 6A4*B2 + 12A3*B3 + 6A2*B4 + 4A4*B3 + 4A3*B4 + A4*B4,+,
    134       5A4*B1 + 10A3*B2 + 10A2*B3 + 5A1*B4 + 20A4*B2 + 30A3*B3 + 20A2*B4 +
    135         30A4*B3 + 30A3*B4 + 20A4*B4,+,
    136       15A4*B2 + 20A3*B3 + 15A2*B4 + 60A4*B3 + 60A3*B4 + 90A4*B4,+,
    137       35A4*B3 + 35A3*B4 + 140A4*B4,+,
    138       70A4*B4}
    139   */
    140 
    141   const SCEVAddRecExpr *Product =
    142       dyn_cast<SCEVAddRecExpr>(SE.getMulExpr(A_rec, B_rec));
    143   ASSERT_TRUE(Product);
    144   ASSERT_EQ(Product->getNumOperands(), 9u);
    145 
    146   SmallVector<const SCEV *, 16> Sum;
    147   Sum.push_back(SE.getMulExpr(A[0], B[0]));
    148   EXPECT_EQ(Product->getOperand(0), SE.getAddExpr(Sum));
    149   Sum.clear();
    150 
    151   // SCEV produces different an equal but different expression for these.
    152   // Re-enable when PR11052 is fixed.
    153 #if 0
    154   Sum.push_back(SE.getMulExpr(A[1], B[0]));
    155   Sum.push_back(SE.getMulExpr(A[0], B[1]));
    156   Sum.push_back(SE.getMulExpr(A[1], B[1]));
    157   EXPECT_EQ(Product->getOperand(1), SE.getAddExpr(Sum));
    158   Sum.clear();
    159 
    160   Sum.push_back(SE.getMulExpr(A[2], B[0]));
    161   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[1]));
    162   Sum.push_back(SE.getMulExpr(A[0], B[2]));
    163   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[2], B[1]));
    164   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[2]));
    165   Sum.push_back(SE.getMulExpr(A[2], B[2]));
    166   EXPECT_EQ(Product->getOperand(2), SE.getAddExpr(Sum));
    167   Sum.clear();
    168 
    169   Sum.push_back(SE.getMulExpr(A[3], B[0]));
    170   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[1]));
    171   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[2]));
    172   Sum.push_back(SE.getMulExpr(A[0], B[3]));
    173   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[1]));
    174   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
    175   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[3]));
    176   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[2]));
    177   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[3]));
    178   Sum.push_back(SE.getMulExpr(A[3], B[3]));
    179   EXPECT_EQ(Product->getOperand(3), SE.getAddExpr(Sum));
    180   Sum.clear();
    181 
    182   Sum.push_back(SE.getMulExpr(A[4], B[0]));
    183   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[1]));
    184   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
    185   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[3]));
    186   Sum.push_back(SE.getMulExpr(A[0], B[4]));
    187   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[1]));
    188   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[2]));
    189   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[2], B[3]));
    190   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[4]));
    191   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[4], B[2]));
    192   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[3]));
    193   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[4]));
    194   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[3]));
    195   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[4]));
    196   Sum.push_back(SE.getMulExpr(A[4], B[4]));
    197   EXPECT_EQ(Product->getOperand(4), SE.getAddExpr(Sum));
    198   Sum.clear();
    199 
    200   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[4], B[1]));
    201   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[3], B[2]));
    202   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[2], B[3]));
    203   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[1], B[4]));
    204   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[2]));
    205   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[3]));
    206   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[2], B[4]));
    207   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[4], B[3]));
    208   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[4]));
    209   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[4]));
    210   EXPECT_EQ(Product->getOperand(5), SE.getAddExpr(Sum));
    211   Sum.clear();
    212 
    213   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[4], B[2]));
    214   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[3], B[3]));
    215   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[2], B[4]));
    216   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[4], B[3]));
    217   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[3], B[4]));
    218   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 90), A[4], B[4]));
    219   EXPECT_EQ(Product->getOperand(6), SE.getAddExpr(Sum));
    220   Sum.clear();
    221 
    222   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[4], B[3]));
    223   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[3], B[4]));
    224   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 140), A[4], B[4]));
    225   EXPECT_EQ(Product->getOperand(7), SE.getAddExpr(Sum));
    226   Sum.clear();
    227 #endif
    228 
    229   Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 70), A[4], B[4]));
    230   EXPECT_EQ(Product->getOperand(8), SE.getAddExpr(Sum));
    231 }
    232 
    233 }  // end anonymous namespace
    234 }  // end namespace llvm
    235