//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "ARM.h"
#include "ARMSubtarget.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

namespace {
  struct OpChain;
  struct BinOpChain;
  struct Reduction;

  using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<Instruction*, 8>;
  using PMACPair        = std::pair<BinOpChain*, BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*, 16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;

  struct OpChain {
    Instruction   *Root;
    ValueList     AllValues;
    MemInstList   VecLd;    // List of all load instructions.
    MemLocList    MemLocs;  // All memory locations read by this tree.
    bool          ReadOnly = true;

    OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
    virtual ~OpChain() = default;

    void SetMemoryLocations() {
      const auto Size = MemoryLocation::UnknownSize;
      for (auto *V : AllValues) {
        if (auto *I = dyn_cast<Instruction>(V)) {
          if (I->mayWriteToMemory())
            ReadOnly = false;
          if (auto *Ld = dyn_cast<LoadInst>(V))
            MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
        }
      }
    }

    unsigned size() const { return AllValues.size(); }
  };

  // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
  // 'Reduction' contains the phi-node and accumulator statement from where we
  // start pattern matching, and 'BinOpChain' the multiplication
  // instructions that are candidates for parallel execution.
  struct BinOpChain : public OpChain {
    ValueList     LHS;      // List of all (narrow) left hand operands.
    ValueList     RHS;      // List of all (narrow) right hand operands.

    BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
      OpChain(I, lhs), LHS(lhs), RHS(rhs) {
        for (auto *V : RHS)
          AllValues.push_back(V);
      }
  };

  struct Reduction {
    PHINode         *Phi;             // The Phi-node from where we start
                                      // pattern matching.
    Instruction     *AccIntAdd;       // The accumulating integer add statement,
                                      // i.e., the reduction statement.

    OpChainList     MACCandidates;    // The MAC candidates associated with
                                      // this reduction statement.
    Reduction(PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { }
  };

  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;

    bool InsertParallelMACs(Reduction &Reduction, PMACPairList &PMACPairs);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    PMACPairList CreateParallelMACPairs(OpChainList &Candidates);
    Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                 Instruction *Acc, Instruction *InsertAfter);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
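    ///
    /// For example, smlad r0, r1, r2, r3 computes:
    ///   r0 = r3 + (sign_extend(r1[15:0])  * sign_extend(r2[15:0]))
    ///           + (sign_extend(r1[31:16]) * sign_extend(r2[31:16]))
    /// and smladx swaps the halfwords of r2 before multiplying.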
    bool MatchSMLAD(Function &F);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
      bool Changes = false;

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n\n");
      Changes = MatchSMLAD(F);
      return Changes;
    }
  };
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
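//
// For example, with MaxBitWidth = 16, a match is a narrow load extended to
// i32 (shown in simplified IR):
//   %ld  = load i16
//   %sxt = sext i16 %ld to i32
// in which case both %ld and %sxt are appended to the value list VL.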
template<unsigned MaxBitWidth>
static bool IsNarrowSequence(Value *V, ValueList &VL) {
  LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
  ConstantInt *CInt;

  if (match(V, m_ConstantInt(CInt))) {
    // TODO: if a constant is used, it needs to fit within the bit width.
    return false;
  }

  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return false;

  Value *Val, *LHS, *RHS;
  if (match(V, m_Trunc(m_Value(Val)))) {
    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
      return IsNarrowSequence<MaxBitWidth>(Val, VL);
  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
    // TODO: we need to implement sadd16/sadd8 for this, which would enable us
    // to also do the rewrite for smlad8.ll, but that is unsupported for now.
    LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
    return false;
  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
      LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
        cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
      return false;
    }

    if (match(Val, m_Load(m_Value()))) {
      LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
      VL.push_back(Val);
      VL.push_back(I);
      return true;
    }
  }
  LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
  return false;
}


// Element-by-element comparison of Value lists returning true if they are
// instructions with the same opcode or constants with the same value.
static bool AreSymmetrical(const ValueList &VL0,
                           const ValueList &VL1) {
  if (VL0.size() != VL1.size()) {
    LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                      << VL0.size() << " != " << VL1.size() << "\n");
    return false;
  }

  const unsigned Pairs = VL0.size();
  LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");

  for (unsigned i = 0; i < Pairs; ++i) {
    const Value *V0 = VL0[i];
    const Value *V1 = VL1[i];
    const auto *Inst0 = dyn_cast<Instruction>(V0);
    const auto *Inst1 = dyn_cast<Instruction>(V1);

    LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
               dbgs() << "mul1: "; V0->dump();
               dbgs() << "mul2: "; V1->dump());

    if (!Inst0 || !Inst1)
      return false;

    if (Inst0->isSameOperationAs(Inst1)) {
      LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
      continue;
    }

    const APInt *C0, *C1;
    // Compare the constant values, not the APInt pointers bound by m_APInt.
    if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) &&
          APInt::isSameValue(*C0, *C1)))
      return false;
  }

  LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
  return true;
}

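// Check that two memory accesses are both simple (non-volatile, non-atomic)
// and refer to consecutive memory locations; if so, record both in VecMem as
// a candidate pair for a single wider access.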
template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  MemInstList &VecMem, const DataLayout &DL,
                                  ScalarEvolution &SE) {
  if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
    LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
    return false;
  }
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
    VecMem.push_back(MemOp0);
    VecMem.push_back(MemOp1);
    LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
    return true;
  }
  LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
  return false;
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
    dbgs() << "Ld0:"; Ld0->dump();
    dbgs() << "Ld1:"; Ld1->dump();
  );

  if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
    return false;
  }

  return AreSequentialAccesses<LoadInst>(Ld0, Ld1, VecMem, *DL, *SE);
}

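// Pair up the MAC candidates of a reduction: two multiplies can be executed
// in parallel when their (narrow) operand lists are symmetrical and their
// operands are loaded from consecutive memory, so that each pair can later be
// replaced by a single SMLAD call.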
PMACPairList
ARMParallelDSP::CreateParallelMACPairs(OpChainList &Candidates) {
  const unsigned Elems = Candidates.size();
  PMACPairList PMACPairs;

  if (Elems < 2)
    return PMACPairs;

  // TODO: for now we simply try to match consecutive pairs i and i+1.
  // We can compare all elements, but then we need to compare and evaluate
  // different solutions.
  for (unsigned i = 0; i < Elems - 1; i += 2) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
    BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[i+1].get());
    const Instruction *Mul0 = PMul0->Root;
    const Instruction *Mul1 = PMul1->Root;

    if (Mul0 == Mul1)
      continue;

    LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
               dbgs() << "- "; Mul0->dump();
               dbgs() << "- "; Mul1->dump());

    const ValueList &Mul0_LHS = PMul0->LHS;
    const ValueList &Mul0_RHS = PMul0->RHS;
    const ValueList &Mul1_LHS = PMul1->LHS;
    const ValueList &Mul1_RHS = PMul1->RHS;

    if (!AreSymmetrical(Mul0_LHS, Mul1_LHS) ||
        !AreSymmetrical(Mul0_RHS, Mul1_RHS))
      continue;

    LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
    // The first elements of each vector should be loads with sexts. If we
    // find that these are two pairs of consecutive loads, then they can be
    // transformed into two wider loads and the users can be replaced with DSP
    // intrinsics.
    for (unsigned x = 0; x < Mul0_LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(Mul0_LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(Mul1_LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(Mul0_RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(Mul1_RHS[x]);

      LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n";
                 dbgs() << "\t mul1: "; Mul0_LHS[x]->dump();
                 dbgs() << "\t mul2: "; Mul1_LHS[x]->dump();
                 dbgs() << "and operands " << x + 2 << ":\n";
                 dbgs() << "\t mul1: "; Mul0_RHS[x]->dump();
                 dbgs() << "\t mul2: "; Mul1_RHS[x]->dump());

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd) &&
          AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        PMACPairs.push_back(std::make_pair(PMul0, PMul1));
      }
    }
  }
  return PMACPairs;
}

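// Replace each pair of parallel MACs with a call to the SMLAD intrinsic. The
// accumulator of one call feeds the next, and the final accumulator replaces
// all uses of the original reduction add.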
bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction,
                                        PMACPairList &PMACPairs) {
  Instruction *Acc = Reduction.Phi;
  Instruction *InsertAfter = Reduction.AccIntAdd;

  for (auto &Pair : PMACPairs) {
    LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
               dbgs() << "- "; Pair.first->Root->dump();
               dbgs() << "- "; Pair.second->Root->dump());
    auto *VecLd0 = cast<LoadInst>(Pair.first->VecLd[0]);
    auto *VecLd1 = cast<LoadInst>(Pair.second->VecLd[0]);
    Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, InsertAfter);
    InsertAfter = Acc;
  }

  if (Acc != Reduction.Phi) {
    LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
    Reduction.AccIntAdd->replaceAllUsesWith(Acc);
    return true;
  }
  return false;
}

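// Find candidate reductions in the loop header: for each 32-bit integer phi,
// use RecurrenceDescriptor to check that it is an integer add recurrence, and
// record the phi together with its accumulating add (the value incoming from
// the latch).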
static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
                            ReductionList &Reductions) {
  RecurrenceDescriptor RecDesc;
  const bool HasFnNoNaNAttr =
    F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
  const BasicBlock *Latch = TheLoop->getLoopLatch();

  // We need a preheader as getIncomingValueForBlock assumes there is one.
  if (!TheLoop->getLoopPreheader()) {
    LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
    return;
  }

  for (PHINode &Phi : Header->phis()) {
    const auto *Ty = Phi.getType();
    if (!Ty->isIntegerTy(32))
      continue;

    const bool IsReduction =
      RecurrenceDescriptor::AddReductionVar(&Phi,
                                            RecurrenceDescriptor::RK_IntegerAdd,
                                            TheLoop, HasFnNoNaNAttr, RecDesc);
    if (!IsReduction)
      continue;

    Instruction *Acc =
      dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
    if (!Acc)
      continue;

    Reductions.push_back(Reduction(&Phi, Acc));
  }

  LLVM_DEBUG(
    dbgs() << "\nAccumulating integer additions (reductions) found:\n";
    for (auto &R : Reductions) {
      dbgs() << "-  "; R.Phi->dump();
      dbgs() << "-> "; R.AccIntAdd->dump();
    }
  );
}

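// Record operand MulOpNum of the accumulating add Acc as a MAC candidate,
// but only when both multiply operands are narrow (16-bit) sequences.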
static void AddMACCandidate(OpChainList &Candidates,
                            const Instruction *Acc,
                            Value *MulOp0, Value *MulOp1, int MulOpNum) {
  Instruction *Mul = dyn_cast<Instruction>(Acc->getOperand(MulOpNum));
  LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
  ValueList LHS;
  ValueList RHS;
  if (IsNarrowSequence<16>(MulOp0, LHS) &&
      IsNarrowSequence<16>(MulOp1, RHS)) {
    LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
    Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
  }
}

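// Starting from the accumulating add of a reduction, walk up the add chain
// and collect every multiply that feeds it as a MAC candidate. The list is
// reversed at the end so the candidates appear in program order.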
static void MatchParallelMACSequences(Reduction &R,
                                      OpChainList &Candidates) {
  const Instruction *Acc = R.AccIntAdd;
  Value *A, *MulOp0, *MulOp1;
  LLVM_DEBUG(dbgs() << "\n- Analysing:\t"; Acc->dump());

  // Pattern 1: the accumulator is the RHS of the add.
  while (match(Acc, m_Add(m_Mul(m_Value(MulOp0), m_Value(MulOp1)),
                          m_Value(A)))) {
    AddMACCandidate(Candidates, Acc, MulOp0, MulOp1, 0);
    Acc = dyn_cast<Instruction>(A);
  }
  // Pattern 2: the accumulator is the LHS of the add.
  while (match(Acc, m_Add(m_Value(A),
                          m_Mul(m_Value(MulOp0), m_Value(MulOp1))))) {
    AddMACCandidate(Candidates, Acc, MulOp0, MulOp1, 1);
    Acc = dyn_cast<Instruction>(A);
  }

  // The last mul in the chain has a slightly different pattern:
  // the mul is the first operand.
  if (match(Acc, m_Add(m_Mul(m_Value(MulOp0), m_Value(MulOp1)), m_Value(A))))
    AddMACCandidate(Candidates, Acc, MulOp0, MulOp1, 0);

  // Because we start at the bottom of the chain, and we work our way up,
  // the muls are added in reverse program order to the list.
  std::reverse(Candidates.begin(), Candidates.end());
}

// Collect the instructions in the header block that may read or write memory;
// these are the instructions that can potentially alias with the MAC operands.
static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
                            Instructions &Writes) {
  for (auto &I : *Header) {
    if (I.mayReadFromMemory())
      Reads.push_back(&I);
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
  }
}

// Check whether the statements in the basic block that write to memory alias
// with the memory locations accessed by the MAC-chains.
// TODO: we need the read statements when we accept more complicated chains.
static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
                       Instructions &Writes, OpChainList &MACCandidates) {
  LLVM_DEBUG(dbgs() << "Alias checks:\n");
  for (auto &MAC : MACCandidates) {
    LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());

    // At the moment we only allow simple chains, i.e. chains that consist of
    // reads and accumulate their result with an integer add; they don't write
    // memory themselves, and we simply bail if they do.
    if (!MAC->ReadOnly)
      return true;

    // Now, for all writes in the basic block, check that they don't alias with
    // the memory locations accessed by our MAC-chain:
    for (auto *I : Writes) {
      LLVM_DEBUG(dbgs() << "- "; I->dump());
      assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
      for (auto &MemLoc : MAC->MemLocs) {
        if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
                                          ModRefInfo::ModRef))) {
          LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
          return true;
        }
      }
    }
  }

  LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
  return false;
}

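// Sanity-check the operand lists of the MAC candidates and record their
// memory locations: each narrow operand should contribute a load plus an
// extend, and the loads must sit at the even positions of the LHS/RHS lists.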
static bool CheckMACMemory(OpChainList &Candidates) {
  for (auto &C : Candidates) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
    if (C->size() < 4) {
      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
      return false;
    }
    C->SetMemoryLocations();
    ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;

    // Use +=2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }
  return true;
}

// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
// multiplications.
// To use SMLAD:
// 1) we first need to find integer add reduction PHIs,
// 2) then from the PHI, look for this pattern:
//
// acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr.h r0
// ldr.h r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Function &F) {
  BasicBlock *Header = L->getHeader();
  LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
             dbgs() << "Header block:\n"; Header->dump();
             dbgs() << "Loop info:\n\n"; L->dump());

  bool Changed = false;
  ReductionList Reductions;
  MatchReductions(F, L, Header, Reductions);

  for (auto &R : Reductions) {
    OpChainList MACCandidates;
    MatchParallelMACSequences(R, MACCandidates);
    if (!CheckMACMemory(MACCandidates))
      continue;

    R.MACCandidates = std::move(MACCandidates);

    LLVM_DEBUG(dbgs() << "MAC candidates:\n";
      for (auto &M : R.MACCandidates)
        M->Root->dump();
      dbgs() << "\n";);
  }

  // Collect all instructions that may read or write memory. Our alias
  // analysis checks bail out if any of these instructions aliases with an
  // instruction from the MAC-chain.
  Instructions Reads, Writes;
  AliasCandidates(Header, Reads, Writes);

  for (auto &R : Reductions) {
    if (AreAliased(AA, Reads, Writes, R.MACCandidates))
      return false;
    PMACPairList PMACPairs = CreateParallelMACPairs(R.MACCandidates);
    Changed |= InsertParallelMACs(R, PMACPairs);
  }

  LLVM_DEBUG(if (Changed) { dbgs() << "Header block:\n"; Header->dump(); });
  return Changed;
}

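// Replace the narrow load pointed to by VecLd with a load of the accumulator
// type (i32) from the same address: the two consecutive i16 values then end
// up in the bottom and top halfwords of the loaded i32.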
static void CreateLoadIns(IRBuilder<NoFolder> &IRB, Instruction *Acc,
                          LoadInst **VecLd) {
  const Type *AccTy = Acc->getType();
  const unsigned AddrSpace = (*VecLd)->getPointerAddressSpace();

  Value *VecPtr = IRB.CreateBitCast((*VecLd)->getPointerOperand(),
                                    AccTy->getPointerTo(AddrSpace));
  *VecLd = IRB.CreateAlignedLoad(VecPtr, (*VecLd)->getAlignment());
}

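// Build the call to the llvm.arm.smlad intrinsic: widen the two load pairs
// into i32 loads and pass them, together with the running accumulator Acc, to
// the intrinsic. The call is inserted directly after InsertAfter and becomes
// the new accumulator.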
Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                             Instruction *Acc,
                                             Instruction *InsertAfter) {
  LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n";
             dbgs() << "- "; VecLd0->dump();
             dbgs() << "- "; VecLd1->dump();
             dbgs() << "- "; Acc->dump());

  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                              ++BasicBlock::iterator(InsertAfter));

  // Replace the reduction chain with an intrinsic call.
  CreateLoadIns(Builder, Acc, &VecLd0);
  CreateLoadIns(Builder, Acc, &VecLd1);
  Value *Args[] = { VecLd0, VecLd1, Acc };
  Function *SMLAD = Intrinsic::getDeclaration(M, Intrinsic::arm_smlad);
  CallInst *Call = Builder.CreateCall(SMLAD, Args);
  ++NumSMLAD;
  return Call;
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)
    673