Home | History | Annotate | Download | only in AArch64
      1 //===-- AArch64AddressTypePromotion.cpp --- Promote type for addr accesses -==//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This pass tries to promote the computations used to obtain a sign-extended
     11 // value that is used in memory accesses.
     12 // E.g.
     13 // a = add nsw i32 b, 3
     14 // d = sext i32 a to i64
     15 // e = getelementptr ..., i64 d
     16 //
     17 // =>
     18 // f = sext i32 b to i64
     19 // a = add nsw i64 f, 3
     20 // e = getelementptr ..., i64 a
     21 //
     22 // This is legal to do if the computations are marked with either nsw or nuw
     23 // markers. Moreover, the current heuristic is simple: it does not create new
     24 // sext operations, i.e., it gives up when a sext would have forked (e.g., if a
     25 // = add i32 b, c, two sexts are required to promote the computation).
     26 //
     27 // FIXME: This pass may be useful for other targets too.
     28 // ===---------------------------------------------------------------------===//
     29 
     30 #include "AArch64.h"
     31 #include "llvm/ADT/DenseMap.h"
     32 #include "llvm/ADT/SmallPtrSet.h"
     33 #include "llvm/ADT/SmallVector.h"
     34 #include "llvm/IR/Constants.h"
     35 #include "llvm/IR/Dominators.h"
     36 #include "llvm/IR/Function.h"
     37 #include "llvm/IR/Instructions.h"
     38 #include "llvm/IR/Module.h"
     39 #include "llvm/IR/Operator.h"
     40 #include "llvm/Pass.h"
     41 #include "llvm/Support/CommandLine.h"
     42 #include "llvm/Support/Debug.h"
     43 #include "llvm/Support/raw_ostream.h"
     44 
     45 using namespace llvm;
     46 
     47 #define DEBUG_TYPE "aarch64-type-promotion"
     48 
     49 static cl::opt<bool>
     50 EnableAddressTypePromotion("aarch64-type-promotion", cl::Hidden,
     51                            cl::desc("Enable the type promotion pass"),
     52                            cl::init(true));
     53 static cl::opt<bool>
     54 EnableMerge("aarch64-type-promotion-merge", cl::Hidden,
     55             cl::desc("Enable merging of redundant sexts when one is dominating"
     56                      " the other."),
     57             cl::init(true));
     58 
     59 #define AARCH64_TYPE_PROMO_NAME "AArch64 Address Type Promotion"
     60 
     61 //===----------------------------------------------------------------------===//
     62 //                       AArch64AddressTypePromotion
     63 //===----------------------------------------------------------------------===//
     64 
     65 namespace llvm {
     66 void initializeAArch64AddressTypePromotionPass(PassRegistry &);
     67 }
     68 
     69 namespace {
     70 class AArch64AddressTypePromotion : public FunctionPass {
     71 
     72 public:
     73   static char ID;
     74   AArch64AddressTypePromotion()
     75       : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
     76     initializeAArch64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
     77   }
     78 
     79   const char *getPassName() const override {
     80     return AARCH64_TYPE_PROMO_NAME;
     81   }
     82 
     83   /// Iterate over the functions and promote the computation of interesting
     84   // sext instructions.
     85   bool runOnFunction(Function &F) override;
     86 
     87 private:
     88   /// The current function.
     89   Function *Func;
     90   /// Filter out all sexts that does not have this type.
     91   /// Currently initialized with Int64Ty.
     92   Type *ConsideredSExtType;
     93 
     94   // This transformation requires dominator info.
     95   void getAnalysisUsage(AnalysisUsage &AU) const override {
     96     AU.setPreservesCFG();
     97     AU.addRequired<DominatorTreeWrapperPass>();
     98     AU.addPreserved<DominatorTreeWrapperPass>();
     99     FunctionPass::getAnalysisUsage(AU);
    100   }
    101 
    102   typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
    103   typedef SmallVector<Instruction *, 16> Instructions;
    104   typedef DenseMap<Value *, Instructions> ValueToInsts;
    105 
    106   /// Check if it is profitable to move a sext through this instruction.
    107   /// Currently, we consider it is profitable if:
    108   /// - Inst is used only once (no need to insert truncate).
    109   /// - Inst has only one operand that will require a sext operation (we do
    110   ///   do not create new sext operation).
    111   bool shouldGetThrough(const Instruction *Inst);
    112 
    113   /// Check if it is possible and legal to move a sext through this
    114   /// instruction.
    115   /// Current heuristic considers that we can get through:
    116   /// - Arithmetic operation marked with the nsw or nuw flag.
    117   /// - Other sext operation.
    118   /// - Truncate operation if it was just dropping sign extended bits.
    119   bool canGetThrough(const Instruction *Inst);
    120 
    121   /// Move sext operations through safe to sext instructions.
    122   bool propagateSignExtension(Instructions &SExtInsts);
    123 
    124   /// Is this sext should be considered for code motion.
    125   /// We look for sext with ConsideredSExtType and uses in at least one
    126   // GetElementPtrInst.
    127   bool shouldConsiderSExt(const Instruction *SExt) const;
    128 
    129   /// Collect all interesting sext operations, i.e., the ones with the right
    130   /// type and used in memory accesses.
    131   /// More precisely, a sext instruction is considered as interesting if it
    132   /// is used in a "complex" getelementptr or it exits at least another
    133   /// sext instruction that sign extended the same initial value.
    134   /// A getelementptr is considered as "complex" if it has more than 2
    135   // operands.
    136   void analyzeSExtension(Instructions &SExtInsts);
    137 
    138   /// Merge redundant sign extension operations in common dominator.
    139   void mergeSExts(ValueToInsts &ValToSExtendedUses,
    140                   SetOfInstructions &ToRemove);
    141 };
    142 } // end anonymous namespace.
    143 
    144 char AArch64AddressTypePromotion::ID = 0;
    145 
    146 INITIALIZE_PASS_BEGIN(AArch64AddressTypePromotion, "aarch64-type-promotion",
    147                       AARCH64_TYPE_PROMO_NAME, false, false)
    148 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
    149 INITIALIZE_PASS_END(AArch64AddressTypePromotion, "aarch64-type-promotion",
    150                     AARCH64_TYPE_PROMO_NAME, false, false)
    151 
    152 FunctionPass *llvm::createAArch64AddressTypePromotionPass() {
    153   return new AArch64AddressTypePromotion();
    154 }
    155 
    156 bool AArch64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
    157   if (isa<SExtInst>(Inst))
    158     return true;
    159 
    160   const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
    161   if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
    162       (BinOp->hasNoUnsignedWrap() || BinOp->hasNoSignedWrap()))
    163     return true;
    164 
    165   // sext(trunc(sext)) --> sext
    166   if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
    167     const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
    168     // Check that the truncate just drop sign extended bits.
    169     if (Inst->getType()->getIntegerBitWidth() >=
    170             Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
    171         Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
    172             ConsideredSExtType->getIntegerBitWidth())
    173       return true;
    174   }
    175 
    176   return false;
    177 }
    178 
    179 bool AArch64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
    180   // If the type of the sext is the same as the considered one, this sext
    181   // will become useless.
    182   // Otherwise, we will have to do something to preserve the original value,
    183   // unless it is used once.
    184   if (isa<SExtInst>(Inst) &&
    185       (Inst->getType() == ConsideredSExtType || Inst->hasOneUse()))
    186     return true;
    187 
    188   // If the Inst is used more that once, we may need to insert truncate
    189   // operations and we don't do that at the moment.
    190   if (!Inst->hasOneUse())
    191     return false;
    192 
    193   // This truncate is used only once, thus if we can get thourgh, it will become
    194   // useless.
    195   if (isa<TruncInst>(Inst))
    196     return true;
    197 
    198   // If both operands are not constant, a new sext will be created here.
    199   // Current heuristic is: each step should be profitable.
    200   // Therefore we don't allow to increase the number of sext even if it may
    201   // be profitable later on.
    202   if (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)))
    203     return true;
    204 
    205   return false;
    206 }
    207 
    208 static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
    209   return !(isa<SelectInst>(Inst) && OpIdx == 0);
    210 }
    211 
    212 bool
    213 AArch64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
    214   if (SExt->getType() != ConsideredSExtType)
    215     return false;
    216 
    217   for (const User *U : SExt->users()) {
    218     if (isa<GetElementPtrInst>(U))
    219       return true;
    220   }
    221 
    222   return false;
    223 }
    224 
    225 // Input:
    226 // - SExtInsts contains all the sext instructions that are used directly in
    227 //   GetElementPtrInst, i.e., access to memory.
    228 // Algorithm:
    229 // - For each sext operation in SExtInsts:
    230 //   Let var be the operand of sext.
    231 //   while it is profitable (see shouldGetThrough), legal, and safe
    232 //   (see canGetThrough) to move sext through var's definition:
    233 //   * promote the type of var's definition.
    234 //   * fold var into sext uses.
    235 //   * move sext above var's definition.
    236 //   * update sext operand to use the operand of var that should be sign
    237 //     extended (by construction there is only one).
    238 //
    239 //   E.g.,
    240 //   a = ... i32 c, 3
    241 //   b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
    242 //   ...
    243 //   = b
    244 // => Yes, update the code
    245 //   b = sext i32 c to i64
    246 //   a = ... i64 b, 3
    247 //   ...
    248 //   = a
    249 // Iterate on 'c'.
    250 bool
    251 AArch64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
    252   DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");
    253 
    254   bool LocalChange = false;
    255   SetOfInstructions ToRemove;
    256   ValueToInsts ValToSExtendedUses;
    257   while (!SExtInsts.empty()) {
    258     // Get through simple chain.
    259     Instruction *SExt = SExtInsts.pop_back_val();
    260 
    261     DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');
    262 
    263     // If this SExt has already been merged continue.
    264     if (SExt->use_empty() && ToRemove.count(SExt)) {
    265       DEBUG(dbgs() << "No uses => marked as delete\n");
    266       continue;
    267     }
    268 
    269     // Now try to get through the chain of definitions.
    270     while (auto *Inst = dyn_cast<Instruction>(SExt->getOperand(0))) {
    271       DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
    272       if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
    273         // We cannot get through something that is not an Instruction
    274         // or not safe to SExt.
    275         DEBUG(dbgs() << "Cannot get through\n");
    276         break;
    277       }
    278 
    279       LocalChange = true;
    280       // If this is a sign extend, it becomes useless.
    281       if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
    282         DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
    283         // We cannot use replaceAllUsesWith here because we may trigger some
    284         // assertion on the type as all involved sext operation may have not
    285         // been moved yet.
    286         while (!Inst->use_empty()) {
    287           Use &U = *Inst->use_begin();
    288           Instruction *User = dyn_cast<Instruction>(U.getUser());
    289           assert(User && "User of sext is not an Instruction!");
    290           User->setOperand(U.getOperandNo(), SExt);
    291         }
    292         ToRemove.insert(Inst);
    293         SExt->setOperand(0, Inst->getOperand(0));
    294         SExt->moveBefore(Inst);
    295         continue;
    296       }
    297 
    298       // Get through the Instruction:
    299       // 1. Update its type.
    300       // 2. Replace the uses of SExt by Inst.
    301       // 3. Sign extend each operand that needs to be sign extended.
    302 
    303       // Step #1.
    304       Inst->mutateType(SExt->getType());
    305       // Step #2.
    306       SExt->replaceAllUsesWith(Inst);
    307       // Step #3.
    308       Instruction *SExtForOpnd = SExt;
    309 
    310       DEBUG(dbgs() << "Propagate SExt to operands\n");
    311       for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
    312            ++OpIdx) {
    313         DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
    314         if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
    315             !shouldSExtOperand(Inst, OpIdx)) {
    316           DEBUG(dbgs() << "No need to propagate\n");
    317           continue;
    318         }
    319         // Check if we can statically sign extend the operand.
    320         Value *Opnd = Inst->getOperand(OpIdx);
    321         if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
    322           DEBUG(dbgs() << "Statically sign extend\n");
    323           Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
    324                                                          Cst->getSExtValue()));
    325           continue;
    326         }
    327         // UndefValue are typed, so we have to statically sign extend them.
    328         if (isa<UndefValue>(Opnd)) {
    329           DEBUG(dbgs() << "Statically sign extend\n");
    330           Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
    331           continue;
    332         }
    333 
    334         // Otherwise we have to explicity sign extend it.
    335         assert(SExtForOpnd &&
    336                "Only one operand should have been sign extended");
    337 
    338         SExtForOpnd->setOperand(0, Opnd);
    339 
    340         DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
    341         // Move the sign extension before the insertion point.
    342         SExtForOpnd->moveBefore(Inst);
    343         Inst->setOperand(OpIdx, SExtForOpnd);
    344         // If more sext are required, new instructions will have to be created.
    345         SExtForOpnd = nullptr;
    346       }
    347       if (SExtForOpnd == SExt) {
    348         DEBUG(dbgs() << "Sign extension is useless now\n");
    349         ToRemove.insert(SExt);
    350         break;
    351       }
    352     }
    353 
    354     // If the use is already of the right type, connect its uses to its argument
    355     // and delete it.
    356     // This can happen for an Instruction all uses of which are sign extended.
    357     if (!ToRemove.count(SExt) &&
    358         SExt->getType() == SExt->getOperand(0)->getType()) {
    359       DEBUG(dbgs() << "Sign extension is useless, attach its use to "
    360                       "its argument\n");
    361       SExt->replaceAllUsesWith(SExt->getOperand(0));
    362       ToRemove.insert(SExt);
    363     } else
    364       ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
    365   }
    366 
    367   if (EnableMerge)
    368     mergeSExts(ValToSExtendedUses, ToRemove);
    369 
    370   // Remove all instructions marked as ToRemove.
    371   for (Instruction *I: ToRemove)
    372     I->eraseFromParent();
    373   return LocalChange;
    374 }
    375 
    376 void AArch64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
    377                                              SetOfInstructions &ToRemove) {
    378   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    379 
    380   for (auto &Entry : ValToSExtendedUses) {
    381     Instructions &Insts = Entry.second;
    382     Instructions CurPts;
    383     for (Instruction *Inst : Insts) {
    384       if (ToRemove.count(Inst))
    385         continue;
    386       bool inserted = false;
    387       for (auto &Pt : CurPts) {
    388         if (DT.dominates(Inst, Pt)) {
    389           DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
    390                        << *Inst << '\n');
    391           Pt->replaceAllUsesWith(Inst);
    392           ToRemove.insert(Pt);
    393           Pt = Inst;
    394           inserted = true;
    395           break;
    396         }
    397         if (!DT.dominates(Pt, Inst))
    398           // Give up if we need to merge in a common dominator as the
    399           // expermients show it is not profitable.
    400           continue;
    401 
    402         DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
    403                      << *Pt << '\n');
    404         Inst->replaceAllUsesWith(Pt);
    405         ToRemove.insert(Inst);
    406         inserted = true;
    407         break;
    408       }
    409       if (!inserted)
    410         CurPts.push_back(Inst);
    411     }
    412   }
    413 }
    414 
    415 void AArch64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
    416   DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
    417 
    418   DenseMap<Value *, Instruction *> SeenChains;
    419 
    420   for (auto &BB : *Func) {
    421     for (auto &II : BB) {
    422       Instruction *SExt = &II;
    423 
    424       // Collect all sext operation per type.
    425       if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
    426         continue;
    427 
    428       DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');
    429 
    430       // Cases where we actually perform the optimization:
    431       // 1. SExt is used in a getelementptr with more than 2 operand =>
    432       //    likely we can merge some computation if they are done on 64 bits.
    433       // 2. The beginning of the SExt chain is SExt several time. =>
    434       //    code sharing is possible.
    435 
    436       bool insert = false;
    437       // #1.
    438       for (const User *U : SExt->users()) {
    439         const Instruction *Inst = dyn_cast<GetElementPtrInst>(U);
    440         if (Inst && Inst->getNumOperands() > 2) {
    441           DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *Inst
    442                        << '\n');
    443           insert = true;
    444           break;
    445         }
    446       }
    447 
    448       // #2.
    449       // Check the head of the chain.
    450       Instruction *Inst = SExt;
    451       Value *Last;
    452       do {
    453         int OpdIdx = 0;
    454         const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
    455         if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
    456           OpdIdx = 1;
    457         Last = Inst->getOperand(OpdIdx);
    458         Inst = dyn_cast<Instruction>(Last);
    459       } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));
    460 
    461       DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
    462       DenseMap<Value *, Instruction *>::iterator AlreadySeen =
    463           SeenChains.find(Last);
    464       if (insert || AlreadySeen != SeenChains.end()) {
    465         DEBUG(dbgs() << "Insert\n");
    466         SExtInsts.push_back(SExt);
    467         if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
    468           DEBUG(dbgs() << "Insert chain member\n");
    469           SExtInsts.push_back(AlreadySeen->second);
    470           SeenChains[Last] = nullptr;
    471         }
    472       } else {
    473         DEBUG(dbgs() << "Record its chain membership\n");
    474         SeenChains[Last] = SExt;
    475       }
    476     }
    477   }
    478 }
    479 
    480 bool AArch64AddressTypePromotion::runOnFunction(Function &F) {
    481   if (skipFunction(F))
    482     return false;
    483 
    484   if (!EnableAddressTypePromotion || F.isDeclaration())
    485     return false;
    486   Func = &F;
    487   ConsideredSExtType = Type::getInt64Ty(Func->getContext());
    488 
    489   DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n');
    490 
    491   Instructions SExtInsts;
    492   analyzeSExtension(SExtInsts);
    493   return propagateSignExtension(SExtInsts);
    494 }
    495