//===- ScalarEvolutionsTest.cpp - ScalarEvolution unit tests --------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include <llvm/Analysis/ScalarEvolutionExpressions.h>
#include <llvm/Analysis/LoopInfo.h>
#include <llvm/GlobalVariable.h>
#include <llvm/Constants.h>
#include <llvm/LLVMContext.h>
#include <llvm/Module.h>
#include <llvm/PassManager.h>
#include <llvm/ADT/SmallVector.h>
#include "gtest/gtest.h"

namespace llvm {
namespace {

// We use this fixture to ensure that we clean up ScalarEvolution before
// deleting the PassManager.
//
// Each test builds IR into M, registers SE with PM via PM.add(&SE), and runs
// PM so that SE is initialized before the test queries it.
// NOTE(review): SE is heap-allocated here and never deleted by the fixture;
// presumably PM takes ownership once PM.add(&SE) is called. A test that never
// calls PM.add(&SE) would appear to leak it -- confirm.
class ScalarEvolutionsTest : public testing::Test {
protected:
  ScalarEvolutionsTest() : M("", Context), SE(*new ScalarEvolution) {}
  ~ScalarEvolutionsTest() {
    // Manually clean up, since we allocated new SCEV objects after the
    // pass was finished.
    SE.releaseMemory();
  }
  // Declaration order matters: Context must be constructed before M, which is
  // created from it in the constructor's initializer list.
  LLVMContext Context;
  Module M;
  PassManager PM;
  ScalarEvolution &SE;
};

// Checks that a SCEVUnknown follows its underlying Value through
// replaceAllUsesWith, including a chained RAUW (V2 -> V1, then V1 -> V0).
TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
  // Give the module a trivial function "f" (entry block + ret void).
  // Presumably this is needed so that running the pass manager actually
  // initializes SE -- confirm.
  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context),
                                              std::vector<Type *>(), false);
  Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
  BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
  ReturnInst::Create(Context, 0, BB);

  // Three distinct i1 globals, all initialized to false. The cast<SCEVUnknown>
  // checks below expect SE to model each of them as an opaque SCEVUnknown.
  Type *Ty = Type::getInt1Ty(Context);
  Constant *Init = Constant::getNullValue(Ty);
  Value *V0 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V0");
  Value *V1 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V1");
  Value *V2 = new GlobalVariable(M, Ty, false, GlobalValue::ExternalLinkage, Init, "V2");

  // Create a ScalarEvolution and "run" it so that it gets initialized.
  PM.add(&SE);
  PM.run(M);

  const SCEV *S0 = SE.getSCEV(V0);
  const SCEV *S1 = SE.getSCEV(V1);
  const SCEV *S2 = SE.getSCEV(V2);

  // S + S is expected to fold to the multiply (2 * S); the casts and the
  // constant-operand checks below pin that canonical form.
  const SCEV *P0 = SE.getAddExpr(S0, S0);
  const SCEV *P1 = SE.getAddExpr(S1, S1);
  const SCEV *P2 = SE.getAddExpr(S2, S2);

  const SCEVMulExpr *M0 = cast<SCEVMulExpr>(P0);
  const SCEVMulExpr *M1 = cast<SCEVMulExpr>(P1);
  const SCEVMulExpr *M2 = cast<SCEVMulExpr>(P2);

  EXPECT_EQ(cast<SCEVConstant>(M0->getOperand(0))->getValue()->getZExtValue(),
            2u);
  EXPECT_EQ(cast<SCEVConstant>(M1->getOperand(0))->getValue()->getZExtValue(),
            2u);
  EXPECT_EQ(cast<SCEVConstant>(M2->getOperand(0))->getValue()->getZExtValue(),
            2u);

  // Before the RAUWs, these are all pointing to separate values.
  EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
  EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V1);
  EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V2);

  // Do some RAUWs. The order is deliberate: V2's uses are first redirected to
  // V1, and then V1 (now including the former V2 uses) is redirected to V0,
  // exercising a chained replacement.
  V2->replaceAllUsesWith(V1);
  V1->replaceAllUsesWith(V0);

  // After the RAUWs, these should all be pointing to V0.
  EXPECT_EQ(cast<SCEVUnknown>(M0->getOperand(1))->getValue(), V0);
  EXPECT_EQ(cast<SCEVUnknown>(M1->getOperand(1))->getValue(), V0);
  EXPECT_EQ(cast<SCEVUnknown>(M2->getOperand(1))->getValue(), V0);
}

// Checks that multiplying two 5-operand add recurrences over the same loop
// yields a single 9-operand add recurrence with the expected coefficients.
TEST_F(ScalarEvolutionsTest, SCEVMultiplyAddRecs) {
  // Function "f" takes ten i32 arguments: the first five become the operands
  // of A_rec and the next five the operands of B_rec.
  Type *Ty = Type::getInt32Ty(Context);
  SmallVector<Type *, 10> Types;
  Types.append(10, Ty);
  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
  Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
  BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
  ReturnInst::Create(Context, 0, BB);

  // Create a ScalarEvolution and "run" it so that it gets initialized.
  PM.add(&SE);
  PM.run(M);

  // It's possible to produce an empty loop through the default constructor,
  // but you can't add any blocks to it without a LoopInfo pass. Hence the
  // const_cast below, which pushes BB directly into the loop's block list.
  // NOTE(review): L is stack-allocated and destroyed at the end of this test
  // body, before the fixture destructor calls SE.releaseMemory(), even though
  // the add recurrences built below refer to &L. Presumably releaseMemory()
  // does not dereference the loop -- confirm.
  Loop L;
  const_cast<std::vector<BasicBlock*>&>(L.getBlocks()).push_back(BB);

  // Each getSCEV(&*AI++) consumes the next argument in order, so A takes
  // arguments 0-4 and B takes arguments 5-9.
  Function::arg_iterator AI = F->arg_begin();
  SmallVector<const SCEV *, 5> A;
  A.push_back(SE.getSCEV(&*AI++));
  A.push_back(SE.getSCEV(&*AI++));
  A.push_back(SE.getSCEV(&*AI++));
  A.push_back(SE.getSCEV(&*AI++));
  A.push_back(SE.getSCEV(&*AI++));
  const SCEV *A_rec = SE.getAddRecExpr(A, &L, SCEV::FlagAnyWrap);

  SmallVector<const SCEV *, 5> B;
  B.push_back(SE.getSCEV(&*AI++));
  B.push_back(SE.getSCEV(&*AI++));
  B.push_back(SE.getSCEV(&*AI++));
  B.push_back(SE.getSCEV(&*AI++));
  B.push_back(SE.getSCEV(&*AI++));
  const SCEV *B_rec = SE.getAddRecExpr(B, &L, SCEV::FlagAnyWrap);

  /* Spot check that we perform this transformation:
     {A0,+,A1,+,A2,+,A3,+,A4} * {B0,+,B1,+,B2,+,B3,+,B4} =
     {A0*B0,+,
      A1*B0 + A0*B1 + A1*B1,+,
      A2*B0 + 2A1*B1 + A0*B2 + 2A2*B1 + 2A1*B2 + A2*B2,+,
      A3*B0 + 3A2*B1 + 3A1*B2 + A0*B3 + 3A3*B1 + 6A2*B2 + 3A1*B3 + 3A3*B2 +
        3A2*B3 + A3*B3,+,
      A4*B0 + 4A3*B1 + 6A2*B2 + 4A1*B3 + A0*B4 + 4A4*B1 + 12A3*B2 + 12A2*B3 +
        4A1*B4 + 6A4*B2 + 12A3*B3 + 6A2*B4 + 4A4*B3 + 4A3*B4 + A4*B4,+,
      5A4*B1 + 10A3*B2 + 10A2*B3 + 5A1*B4 + 20A4*B2 + 30A3*B3 + 20A2*B4 +
        30A4*B3 + 30A3*B4 + 20A4*B4,+,
      15A4*B2 + 20A3*B3 + 15A2*B4 + 60A4*B3 + 60A3*B4 + 90A4*B4,+,
      35A4*B3 + 35A3*B4 + 140A4*B4,+,
      70A4*B4}
  */

  const SCEVAddRecExpr *Product =
      dyn_cast<SCEVAddRecExpr>(SE.getMulExpr(A_rec, B_rec));
  ASSERT_TRUE(Product);
  ASSERT_EQ(Product->getNumOperands(), 9u);

  // Sum collects the terms of one expected coefficient. It is compared against
  // the corresponding operand of Product and then cleared for reuse.
  SmallVector<const SCEV *, 16> Sum;
  Sum.push_back(SE.getMulExpr(A[0], B[0]));
  EXPECT_EQ(Product->getOperand(0), SE.getAddExpr(Sum));
  Sum.clear();

  // For these coefficients SCEV produces an expression that is mathematically
  // equal to, but structurally different from, the one built below, so the
  // pointer comparison would fail. Re-enable when PR11052 is fixed. Until
  // then, only the first (0) and last (8) coefficients are checked.
#if 0
  Sum.push_back(SE.getMulExpr(A[1], B[0]));
  Sum.push_back(SE.getMulExpr(A[0], B[1]));
  Sum.push_back(SE.getMulExpr(A[1], B[1]));
  EXPECT_EQ(Product->getOperand(1), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(A[2], B[0]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[1]));
  Sum.push_back(SE.getMulExpr(A[0], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[2], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 2), A[1], B[2]));
  Sum.push_back(SE.getMulExpr(A[2], B[2]));
  EXPECT_EQ(Product->getOperand(2), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(A[3], B[0]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[2]));
  Sum.push_back(SE.getMulExpr(A[0], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[1], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[3], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 3), A[2], B[3]));
  Sum.push_back(SE.getMulExpr(A[3], B[3]));
  EXPECT_EQ(Product->getOperand(3), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(A[4], B[0]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[3]));
  Sum.push_back(SE.getMulExpr(A[0], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[2], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[1], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[4], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 12), A[3], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 6), A[2], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[4], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 4), A[3], B[4]));
  Sum.push_back(SE.getMulExpr(A[4], B[4]));
  EXPECT_EQ(Product->getOperand(4), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[4], B[1]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[3], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 10), A[2], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 5), A[1], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[2], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[4], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 30), A[3], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[4], B[4]));
  EXPECT_EQ(Product->getOperand(5), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[4], B[2]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 20), A[3], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 15), A[2], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[4], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 60), A[3], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 90), A[4], B[4]));
  EXPECT_EQ(Product->getOperand(6), SE.getAddExpr(Sum));
  Sum.clear();

  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[4], B[3]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 35), A[3], B[4]));
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 140), A[4], B[4]));
  EXPECT_EQ(Product->getOperand(7), SE.getAddExpr(Sum));
  Sum.clear();
#endif

  // Highest-order coefficient: only the A4*B4 cross term survives, scaled by
  // 70 (see the expansion in the block comment above).
  Sum.push_back(SE.getMulExpr(SE.getConstant(Ty, 70), A[4], B[4]));
  EXPECT_EQ(Product->getOperand(8), SE.getAddExpr(Sum));
}

}  // end anonymous namespace
}  // end namespace llvm