//===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass reassociates n-ary add expressions and eliminates the redundancy
// exposed by the reassociation.
//
// A motivating example:
//
//   void foo(int a, int b) {
//     bar(a + b);
//     bar((a + 2) + b);
//   }
//
// An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify
// the above code to
//
//   int t = a + b;
//   bar(t);
//   bar(t + 2);
//
// However, the Reassociate pass is unable to do that because it processes each
// instruction individually and believes (a + 2) + b is the best form according
// to its rank system.
//
// To address this limitation, NaryReassociate reassociates an expression in a
// form that reuses existing instructions. As a result, NaryReassociate can
// reassociate (a + 2) + b in the example to (a + b) + 2 because it detects that
// (a + b) is computed before.
//
// NaryReassociate works as follows. For every instruction in the form of (a +
// b) + c, it checks whether a + c or b + c is already computed by a dominating
// instruction. If so, it then reassociates (a + b) + c into (a + c) + b or (b +
// c) + a and removes the redundancy accordingly. To efficiently look up whether
// an expression is computed before, we store each instruction seen and its SCEV
// into an SCEV-to-instruction map.
//
// Although the algorithm pattern-matches only ternary additions, it
// automatically handles many >3-ary expressions by walking through the function
// in the depth-first order. For example, given
//
//   (a + c) + d
//   ((a + b) + c) + d
//
// NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites
// ((a + c) + b) + d into ((a + c) + d) + b.
//
// Finally, the above dominator-based algorithm may need to be run for multiple
// iterations before emitting optimal code. One source of this need is that we
// only split an operand when it is used only once. The above algorithm can
// eliminate an instruction and decrease the usage count of its operands. As a
// result, an instruction that previously had multiple uses may become a
// single-use instruction and thus eligible for split consideration. For
// example,
//
//   ac = a + c
//   ab = a + b
//   abc = ab + c
//   ab2 = ab + b
//   ab2c = ab2 + c
//
// In the first iteration, we cannot reassociate abc to ac+b because ab is used
// twice. However, we can reassociate ab2c to abc+b in the first iteration. As a
// result, ab2 becomes dead and ab will be used only once in the second
// iteration.
//
// Limitations and TODO items:
//
// 1) We only consider n-ary adds and muls for now. This should be extended
//    and generalized.
76 // 77 //===----------------------------------------------------------------------===// 78 79 #include "llvm/Analysis/AssumptionCache.h" 80 #include "llvm/Analysis/ScalarEvolution.h" 81 #include "llvm/Analysis/TargetLibraryInfo.h" 82 #include "llvm/Analysis/TargetTransformInfo.h" 83 #include "llvm/Analysis/ValueTracking.h" 84 #include "llvm/IR/Dominators.h" 85 #include "llvm/IR/Module.h" 86 #include "llvm/IR/PatternMatch.h" 87 #include "llvm/Support/Debug.h" 88 #include "llvm/Support/raw_ostream.h" 89 #include "llvm/Transforms/Scalar.h" 90 #include "llvm/Transforms/Utils/Local.h" 91 using namespace llvm; 92 using namespace PatternMatch; 93 94 #define DEBUG_TYPE "nary-reassociate" 95 96 namespace { 97 class NaryReassociate : public FunctionPass { 98 public: 99 static char ID; 100 101 NaryReassociate(): FunctionPass(ID) { 102 initializeNaryReassociatePass(*PassRegistry::getPassRegistry()); 103 } 104 105 bool doInitialization(Module &M) override { 106 DL = &M.getDataLayout(); 107 return false; 108 } 109 bool runOnFunction(Function &F) override; 110 111 void getAnalysisUsage(AnalysisUsage &AU) const override { 112 AU.addPreserved<DominatorTreeWrapperPass>(); 113 AU.addPreserved<ScalarEvolutionWrapperPass>(); 114 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 115 AU.addRequired<AssumptionCacheTracker>(); 116 AU.addRequired<DominatorTreeWrapperPass>(); 117 AU.addRequired<ScalarEvolutionWrapperPass>(); 118 AU.addRequired<TargetLibraryInfoWrapperPass>(); 119 AU.addRequired<TargetTransformInfoWrapperPass>(); 120 AU.setPreservesCFG(); 121 } 122 123 private: 124 // Runs only one iteration of the dominator-based algorithm. See the header 125 // comments for why we need multiple iterations. 126 bool doOneIteration(Function &F); 127 128 // Reassociates I for better CSE. 129 Instruction *tryReassociate(Instruction *I); 130 131 // Reassociate GEP for better CSE. 
132 Instruction *tryReassociateGEP(GetElementPtrInst *GEP); 133 // Try splitting GEP at the I-th index and see whether either part can be 134 // CSE'ed. This is a helper function for tryReassociateGEP. 135 // 136 // \p IndexedType The element type indexed by GEP's I-th index. This is 137 // equivalent to 138 // GEP->getIndexedType(GEP->getPointerOperand(), 0-th index, 139 // ..., i-th index). 140 GetElementPtrInst *tryReassociateGEPAtIndex(GetElementPtrInst *GEP, 141 unsigned I, Type *IndexedType); 142 // Given GEP's I-th index = LHS + RHS, see whether &Base[..][LHS][..] or 143 // &Base[..][RHS][..] can be CSE'ed and rewrite GEP accordingly. 144 GetElementPtrInst *tryReassociateGEPAtIndex(GetElementPtrInst *GEP, 145 unsigned I, Value *LHS, 146 Value *RHS, Type *IndexedType); 147 148 // Reassociate binary operators for better CSE. 149 Instruction *tryReassociateBinaryOp(BinaryOperator *I); 150 151 // A helper function for tryReassociateBinaryOp. LHS and RHS are explicitly 152 // passed. 153 Instruction *tryReassociateBinaryOp(Value *LHS, Value *RHS, 154 BinaryOperator *I); 155 // Rewrites I to (LHS op RHS) if LHS is computed already. 156 Instruction *tryReassociatedBinaryOp(const SCEV *LHS, Value *RHS, 157 BinaryOperator *I); 158 159 // Tries to match Op1 and Op2 by using V. 160 bool matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, Value *&Op2); 161 162 // Gets SCEV for (LHS op RHS). 163 const SCEV *getBinarySCEV(BinaryOperator *I, const SCEV *LHS, 164 const SCEV *RHS); 165 166 // Returns the closest dominator of \c Dominatee that computes 167 // \c CandidateExpr. Returns null if not found. 168 Instruction *findClosestMatchingDominator(const SCEV *CandidateExpr, 169 Instruction *Dominatee); 170 // GetElementPtrInst implicitly sign-extends an index if the index is shorter 171 // than the pointer size. 
This function returns whether Index is shorter than 172 // GEP's pointer size, i.e., whether Index needs to be sign-extended in order 173 // to be an index of GEP. 174 bool requiresSignExtension(Value *Index, GetElementPtrInst *GEP); 175 176 AssumptionCache *AC; 177 const DataLayout *DL; 178 DominatorTree *DT; 179 ScalarEvolution *SE; 180 TargetLibraryInfo *TLI; 181 TargetTransformInfo *TTI; 182 // A lookup table quickly telling which instructions compute the given SCEV. 183 // Note that there can be multiple instructions at different locations 184 // computing to the same SCEV, so we map a SCEV to an instruction list. For 185 // example, 186 // 187 // if (p1) 188 // foo(a + b); 189 // if (p2) 190 // bar(a + b); 191 DenseMap<const SCEV *, SmallVector<WeakVH, 2>> SeenExprs; 192 }; 193 } // anonymous namespace 194 195 char NaryReassociate::ID = 0; 196 INITIALIZE_PASS_BEGIN(NaryReassociate, "nary-reassociate", "Nary reassociation", 197 false, false) 198 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 199 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 200 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 201 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 202 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 203 INITIALIZE_PASS_END(NaryReassociate, "nary-reassociate", "Nary reassociation", 204 false, false) 205 206 FunctionPass *llvm::createNaryReassociatePass() { 207 return new NaryReassociate(); 208 } 209 210 bool NaryReassociate::runOnFunction(Function &F) { 211 if (skipFunction(F)) 212 return false; 213 214 AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); 215 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 216 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 217 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 218 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 219 220 bool Changed = false, ChangedInThisIteration; 221 do { 222 ChangedInThisIteration = 
doOneIteration(F); 223 Changed |= ChangedInThisIteration; 224 } while (ChangedInThisIteration); 225 return Changed; 226 } 227 228 // Whitelist the instruction types NaryReassociate handles for now. 229 static bool isPotentiallyNaryReassociable(Instruction *I) { 230 switch (I->getOpcode()) { 231 case Instruction::Add: 232 case Instruction::GetElementPtr: 233 case Instruction::Mul: 234 return true; 235 default: 236 return false; 237 } 238 } 239 240 bool NaryReassociate::doOneIteration(Function &F) { 241 bool Changed = false; 242 SeenExprs.clear(); 243 // Process the basic blocks in pre-order of the dominator tree. This order 244 // ensures that all bases of a candidate are in Candidates when we process it. 245 for (auto Node = GraphTraits<DominatorTree *>::nodes_begin(DT); 246 Node != GraphTraits<DominatorTree *>::nodes_end(DT); ++Node) { 247 BasicBlock *BB = Node->getBlock(); 248 for (auto I = BB->begin(); I != BB->end(); ++I) { 249 if (SE->isSCEVable(I->getType()) && isPotentiallyNaryReassociable(&*I)) { 250 const SCEV *OldSCEV = SE->getSCEV(&*I); 251 if (Instruction *NewI = tryReassociate(&*I)) { 252 Changed = true; 253 SE->forgetValue(&*I); 254 I->replaceAllUsesWith(NewI); 255 // If SeenExprs constains I's WeakVH, that entry will be replaced with 256 // nullptr. 257 RecursivelyDeleteTriviallyDeadInstructions(&*I, TLI); 258 I = NewI->getIterator(); 259 } 260 // Add the rewritten instruction to SeenExprs; the original instruction 261 // is deleted. 262 const SCEV *NewSCEV = SE->getSCEV(&*I); 263 SeenExprs[NewSCEV].push_back(WeakVH(&*I)); 264 // Ideally, NewSCEV should equal OldSCEV because tryReassociate(I) 265 // is equivalent to I. However, ScalarEvolution::getSCEV may 266 // weaken nsw causing NewSCEV not to equal OldSCEV. For example, suppose 267 // we reassociate 268 // I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4 269 // to 270 // NewI = &a[sext(i)] + sext(j). 
271 // 272 // ScalarEvolution computes 273 // getSCEV(I) = a + 4 * sext(i + j) 274 // getSCEV(newI) = a + 4 * sext(i) + 4 * sext(j) 275 // which are different SCEVs. 276 // 277 // To alleviate this issue of ScalarEvolution not always capturing 278 // equivalence, we add I to SeenExprs[OldSCEV] as well so that we can 279 // map both SCEV before and after tryReassociate(I) to I. 280 // 281 // This improvement is exercised in @reassociate_gep_nsw in nary-gep.ll. 282 if (NewSCEV != OldSCEV) 283 SeenExprs[OldSCEV].push_back(WeakVH(&*I)); 284 } 285 } 286 } 287 return Changed; 288 } 289 290 Instruction *NaryReassociate::tryReassociate(Instruction *I) { 291 switch (I->getOpcode()) { 292 case Instruction::Add: 293 case Instruction::Mul: 294 return tryReassociateBinaryOp(cast<BinaryOperator>(I)); 295 case Instruction::GetElementPtr: 296 return tryReassociateGEP(cast<GetElementPtrInst>(I)); 297 default: 298 llvm_unreachable("should be filtered out by isPotentiallyNaryReassociable"); 299 } 300 } 301 302 static bool isGEPFoldable(GetElementPtrInst *GEP, 303 const TargetTransformInfo *TTI) { 304 SmallVector<const Value*, 4> Indices; 305 for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) 306 Indices.push_back(*I); 307 return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(), 308 Indices) == TargetTransformInfo::TCC_Free; 309 } 310 311 Instruction *NaryReassociate::tryReassociateGEP(GetElementPtrInst *GEP) { 312 // Not worth reassociating GEP if it is foldable. 
313 if (isGEPFoldable(GEP, TTI)) 314 return nullptr; 315 316 gep_type_iterator GTI = gep_type_begin(*GEP); 317 for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) { 318 if (isa<SequentialType>(*GTI++)) { 319 if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1, *GTI)) { 320 return NewGEP; 321 } 322 } 323 } 324 return nullptr; 325 } 326 327 bool NaryReassociate::requiresSignExtension(Value *Index, 328 GetElementPtrInst *GEP) { 329 unsigned PointerSizeInBits = 330 DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace()); 331 return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits; 332 } 333 334 GetElementPtrInst * 335 NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I, 336 Type *IndexedType) { 337 Value *IndexToSplit = GEP->getOperand(I + 1); 338 if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) { 339 IndexToSplit = SExt->getOperand(0); 340 } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) { 341 // zext can be treated as sext if the source is non-negative. 342 if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT)) 343 IndexToSplit = ZExt->getOperand(0); 344 } 345 346 if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) { 347 // If the I-th index needs sext and the underlying add is not equipped with 348 // nsw, we cannot split the add because 349 // sext(LHS + RHS) != sext(LHS) + sext(RHS). 350 if (requiresSignExtension(IndexToSplit, GEP) && 351 computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) != 352 OverflowResult::NeverOverflows) 353 return nullptr; 354 355 Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1); 356 // IndexToSplit = LHS + RHS. 357 if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType)) 358 return NewGEP; 359 // Symmetrically, try IndexToSplit = RHS + LHS. 
360 if (LHS != RHS) { 361 if (auto *NewGEP = 362 tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType)) 363 return NewGEP; 364 } 365 } 366 return nullptr; 367 } 368 369 GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex( 370 GetElementPtrInst *GEP, unsigned I, Value *LHS, Value *RHS, 371 Type *IndexedType) { 372 // Look for GEP's closest dominator that has the same SCEV as GEP except that 373 // the I-th index is replaced with LHS. 374 SmallVector<const SCEV *, 4> IndexExprs; 375 for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) 376 IndexExprs.push_back(SE->getSCEV(*Index)); 377 // Replace the I-th index with LHS. 378 IndexExprs[I] = SE->getSCEV(LHS); 379 if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) && 380 DL->getTypeSizeInBits(LHS->getType()) < 381 DL->getTypeSizeInBits(GEP->getOperand(I)->getType())) { 382 // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to 383 // zext if the source operand is proved non-negative. We should do that 384 // consistently so that CandidateExpr more likely appears before. See 385 // @reassociate_gep_assume for an example of this canonicalization. 386 IndexExprs[I] = 387 SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType()); 388 } 389 const SCEV *CandidateExpr = SE->getGEPExpr( 390 GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()), 391 IndexExprs, GEP->isInBounds()); 392 393 Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP); 394 if (Candidate == nullptr) 395 return nullptr; 396 397 IRBuilder<> Builder(GEP); 398 // Candidate does not necessarily have the same pointer type as GEP. Use 399 // bitcast or pointer cast to make sure they have the same type, so that the 400 // later RAUW doesn't complain. 
401 Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType()); 402 assert(Candidate->getType() == GEP->getType()); 403 404 // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType) 405 uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType); 406 Type *ElementType = GEP->getResultElementType(); 407 uint64_t ElementSize = DL->getTypeAllocSize(ElementType); 408 // Another less rare case: because I is not necessarily the last index of the 409 // GEP, the size of the type at the I-th index (IndexedSize) is not 410 // necessarily divisible by ElementSize. For example, 411 // 412 // #pragma pack(1) 413 // struct S { 414 // int a[3]; 415 // int64 b[8]; 416 // }; 417 // #pragma pack() 418 // 419 // sizeof(S) = 100 is indivisible by sizeof(int64) = 8. 420 // 421 // TODO: bail out on this case for now. We could emit uglygep. 422 if (IndexedSize % ElementSize != 0) 423 return nullptr; 424 425 // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))); 426 Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); 427 if (RHS->getType() != IntPtrTy) 428 RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy); 429 if (IndexedSize != ElementSize) { 430 RHS = Builder.CreateMul( 431 RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize)); 432 } 433 GetElementPtrInst *NewGEP = 434 cast<GetElementPtrInst>(Builder.CreateGEP(Candidate, RHS)); 435 NewGEP->setIsInBounds(GEP->isInBounds()); 436 NewGEP->takeName(GEP); 437 return NewGEP; 438 } 439 440 Instruction *NaryReassociate::tryReassociateBinaryOp(BinaryOperator *I) { 441 Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); 442 if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) 443 return NewI; 444 if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I)) 445 return NewI; 446 return nullptr; 447 } 448 449 Instruction *NaryReassociate::tryReassociateBinaryOp(Value *LHS, Value *RHS, 450 BinaryOperator *I) { 451 Value *A = nullptr, *B = nullptr; 452 // To be conservative, we reassociate I only when it is the only user 
of (A op 453 // B). 454 if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) { 455 // I = (A op B) op RHS 456 // = (A op RHS) op B or (B op RHS) op A 457 const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); 458 const SCEV *RHSExpr = SE->getSCEV(RHS); 459 if (BExpr != RHSExpr) { 460 if (auto *NewI = 461 tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I)) 462 return NewI; 463 } 464 if (AExpr != RHSExpr) { 465 if (auto *NewI = 466 tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I)) 467 return NewI; 468 } 469 } 470 return nullptr; 471 } 472 473 Instruction *NaryReassociate::tryReassociatedBinaryOp(const SCEV *LHSExpr, 474 Value *RHS, 475 BinaryOperator *I) { 476 // Look for the closest dominator LHS of I that computes LHSExpr, and replace 477 // I with LHS op RHS. 478 auto *LHS = findClosestMatchingDominator(LHSExpr, I); 479 if (LHS == nullptr) 480 return nullptr; 481 482 Instruction *NewI = nullptr; 483 switch (I->getOpcode()) { 484 case Instruction::Add: 485 NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I); 486 break; 487 case Instruction::Mul: 488 NewI = BinaryOperator::CreateMul(LHS, RHS, "", I); 489 break; 490 default: 491 llvm_unreachable("Unexpected instruction."); 492 } 493 NewI->takeName(I); 494 return NewI; 495 } 496 497 bool NaryReassociate::matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, 498 Value *&Op2) { 499 switch (I->getOpcode()) { 500 case Instruction::Add: 501 return match(V, m_Add(m_Value(Op1), m_Value(Op2))); 502 case Instruction::Mul: 503 return match(V, m_Mul(m_Value(Op1), m_Value(Op2))); 504 default: 505 llvm_unreachable("Unexpected instruction."); 506 } 507 return false; 508 } 509 510 const SCEV *NaryReassociate::getBinarySCEV(BinaryOperator *I, const SCEV *LHS, 511 const SCEV *RHS) { 512 switch (I->getOpcode()) { 513 case Instruction::Add: 514 return SE->getAddExpr(LHS, RHS); 515 case Instruction::Mul: 516 return SE->getMulExpr(LHS, RHS); 517 default: 518 llvm_unreachable("Unexpected instruction."); 
519 } 520 return nullptr; 521 } 522 523 Instruction * 524 NaryReassociate::findClosestMatchingDominator(const SCEV *CandidateExpr, 525 Instruction *Dominatee) { 526 auto Pos = SeenExprs.find(CandidateExpr); 527 if (Pos == SeenExprs.end()) 528 return nullptr; 529 530 auto &Candidates = Pos->second; 531 // Because we process the basic blocks in pre-order of the dominator tree, a 532 // candidate that doesn't dominate the current instruction won't dominate any 533 // future instruction either. Therefore, we pop it out of the stack. This 534 // optimization makes the algorithm O(n). 535 while (!Candidates.empty()) { 536 // Candidates stores WeakVHs, so a candidate can be nullptr if it's removed 537 // during rewriting. 538 if (Value *Candidate = Candidates.back()) { 539 Instruction *CandidateInstruction = cast<Instruction>(Candidate); 540 if (DT->dominates(CandidateInstruction, Dominatee)) 541 return CandidateInstruction; 542 } 543 Candidates.pop_back(); 544 } 545 return nullptr; 546 } 547