Home | History | Annotate | Download | only in CodeGen
      1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
      2 //                                    intrinsics
      3 //
      4 //                     The LLVM Compiler Infrastructure
      5 //
      6 // This file is distributed under the University of Illinois Open Source
      7 // License. See LICENSE.TXT for details.
      8 //
      9 //===----------------------------------------------------------------------===//
     10 //
     11 // This pass replaces masked memory intrinsics - when unsupported by the target
     12 // - with a chain of basic blocks, that deal with the elements one-by-one if the
     13 // appropriate mask bit is set.
     14 //
     15 //===----------------------------------------------------------------------===//
     16 
     17 #include "llvm/ADT/Twine.h"
     18 #include "llvm/Analysis/TargetTransformInfo.h"
     19 #include "llvm/CodeGen/TargetSubtargetInfo.h"
     20 #include "llvm/IR/BasicBlock.h"
     21 #include "llvm/IR/Constant.h"
     22 #include "llvm/IR/Constants.h"
     23 #include "llvm/IR/DerivedTypes.h"
     24 #include "llvm/IR/Function.h"
     25 #include "llvm/IR/IRBuilder.h"
     26 #include "llvm/IR/InstrTypes.h"
     27 #include "llvm/IR/Instruction.h"
     28 #include "llvm/IR/Instructions.h"
     29 #include "llvm/IR/IntrinsicInst.h"
     30 #include "llvm/IR/Intrinsics.h"
     31 #include "llvm/IR/Type.h"
     32 #include "llvm/IR/Value.h"
     33 #include "llvm/Pass.h"
     34 #include "llvm/Support/Casting.h"
     35 #include <algorithm>
     36 #include <cassert>
     37 
     38 using namespace llvm;
     39 
     40 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
     41 
     42 namespace {
     43 
     44 class ScalarizeMaskedMemIntrin : public FunctionPass {
     45   const TargetTransformInfo *TTI = nullptr;
     46 
     47 public:
     48   static char ID; // Pass identification, replacement for typeid
     49 
     50   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
     51     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
     52   }
     53 
     54   bool runOnFunction(Function &F) override;
     55 
     56   StringRef getPassName() const override {
     57     return "Scalarize Masked Memory Intrinsics";
     58   }
     59 
     60   void getAnalysisUsage(AnalysisUsage &AU) const override {
     61     AU.addRequired<TargetTransformInfoWrapperPass>();
     62   }
     63 
     64 private:
     65   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
     66   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
     67 };
     68 
     69 } // end anonymous namespace
     70 
     71 char ScalarizeMaskedMemIntrin::ID = 0;
     72 
     73 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
     74                 "Scalarize unsupported masked memory intrinsics", false, false)
     75 
     76 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
     77   return new ScalarizeMaskedMemIntrin();
     78 }
     79 
     80 // Translate a masked load intrinsic like
     81 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
     82 //                               <16 x i1> %mask, <16 x i32> %passthru)
     83 // to a chain of basic blocks, with loading element one-by-one if
     84 // the appropriate mask bit is set
     85 //
     86 //  %1 = bitcast i8* %addr to i32*
     87 //  %2 = extractelement <16 x i1> %mask, i32 0
     88 //  %3 = icmp eq i1 %2, true
     89 //  br i1 %3, label %cond.load, label %else
     90 //
     91 // cond.load:                                        ; preds = %0
     92 //  %4 = getelementptr i32* %1, i32 0
     93 //  %5 = load i32* %4
     94 //  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
     95 //  br label %else
     96 //
     97 // else:                                             ; preds = %0, %cond.load
     98 //  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
     99 //  %7 = extractelement <16 x i1> %mask, i32 1
    100 //  %8 = icmp eq i1 %7, true
    101 //  br i1 %8, label %cond.load1, label %else2
    102 //
    103 // cond.load1:                                       ; preds = %else
    104 //  %9 = getelementptr i32* %1, i32 1
    105 //  %10 = load i32* %9
    106 //  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
    107 //  br label %else2
    108 //
    109 // else2:                                          ; preds = %else, %cond.load1
    110 //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    111 //  %12 = extractelement <16 x i1> %mask, i32 2
    112 //  %13 = icmp eq i1 %12, true
    113 //  br i1 %13, label %cond.load4, label %else5
    114 //
    115 static void scalarizeMaskedLoad(CallInst *CI) {
    116   Value *Ptr = CI->getArgOperand(0);
    117   Value *Alignment = CI->getArgOperand(1);
    118   Value *Mask = CI->getArgOperand(2);
    119   Value *Src0 = CI->getArgOperand(3);
    120 
    121   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
    122   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
    123   assert(VecType && "Unexpected return type of masked load intrinsic");
    124 
    125   Type *EltTy = CI->getType()->getVectorElementType();
    126 
    127   IRBuilder<> Builder(CI->getContext());
    128   Instruction *InsertPt = CI;
    129   BasicBlock *IfBlock = CI->getParent();
    130   BasicBlock *CondBlock = nullptr;
    131   BasicBlock *PrevIfBlock = CI->getParent();
    132 
    133   Builder.SetInsertPoint(InsertPt);
    134   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    135 
    136   // Short-cut if the mask is all-true.
    137   bool IsAllOnesMask =
    138       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
    139 
    140   if (IsAllOnesMask) {
    141     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    142     CI->replaceAllUsesWith(NewI);
    143     CI->eraseFromParent();
    144     return;
    145   }
    146 
    147   // Adjust alignment for the scalar instruction.
    148   AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
    149   // Bitcast %addr fron i8* to EltTy*
    150   Type *NewPtrType =
    151       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
    152   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
    153   unsigned VectorWidth = VecType->getNumElements();
    154 
    155   Value *UndefVal = UndefValue::get(VecType);
    156 
    157   // The result vector
    158   Value *VResult = UndefVal;
    159 
    160   if (isa<ConstantVector>(Mask)) {
    161     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    162       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
    163         continue;
    164       Value *Gep =
    165           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    166       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    167       VResult =
    168           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    169     }
    170     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    171     CI->replaceAllUsesWith(NewI);
    172     CI->eraseFromParent();
    173     return;
    174   }
    175 
    176   PHINode *Phi = nullptr;
    177   Value *PrevPhi = UndefVal;
    178 
    179   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    180     // Fill the "else" block, created in the previous iteration
    181     //
    182     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    183     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    184     //  %to_load = icmp eq i1 %mask_1, true
    185     //  br i1 %to_load, label %cond.load, label %else
    186     //
    187     if (Idx > 0) {
    188       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    189       Phi->addIncoming(VResult, CondBlock);
    190       Phi->addIncoming(PrevPhi, PrevIfBlock);
    191       PrevPhi = Phi;
    192       VResult = Phi;
    193     }
    194 
    195     Value *Predicate =
    196         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    197     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
    198                                     ConstantInt::get(Predicate->getType(), 1));
    199 
    200     // Create "cond" block
    201     //
    202     //  %EltAddr = getelementptr i32* %1, i32 0
    203     //  %Elt = load i32* %EltAddr
    204     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    205     //
    206     CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    207     Builder.SetInsertPoint(InsertPt);
    208 
    209     Value *Gep =
    210         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    211     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    212     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    213 
    214     // Create "else" block, fill it in the next iteration
    215     BasicBlock *NewIfBlock =
    216         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    217     Builder.SetInsertPoint(InsertPt);
    218     Instruction *OldBr = IfBlock->getTerminator();
    219     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    220     OldBr->eraseFromParent();
    221     PrevIfBlock = IfBlock;
    222     IfBlock = NewIfBlock;
    223   }
    224 
    225   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
    226   Phi->addIncoming(VResult, CondBlock);
    227   Phi->addIncoming(PrevPhi, PrevIfBlock);
    228   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
    229   CI->replaceAllUsesWith(NewI);
    230   CI->eraseFromParent();
    231 }
    232 
    233 // Translate a masked store intrinsic, like
    234 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
    235 //                               <16 x i1> %mask)
    236 // to a chain of basic blocks, that stores element one-by-one if
    237 // the appropriate mask bit is set
    238 //
    239 //   %1 = bitcast i8* %addr to i32*
    240 //   %2 = extractelement <16 x i1> %mask, i32 0
    241 //   %3 = icmp eq i1 %2, true
    242 //   br i1 %3, label %cond.store, label %else
    243 //
    244 // cond.store:                                       ; preds = %0
    245 //   %4 = extractelement <16 x i32> %val, i32 0
    246 //   %5 = getelementptr i32* %1, i32 0
    247 //   store i32 %4, i32* %5
    248 //   br label %else
    249 //
    250 // else:                                             ; preds = %0, %cond.store
    251 //   %6 = extractelement <16 x i1> %mask, i32 1
    252 //   %7 = icmp eq i1 %6, true
    253 //   br i1 %7, label %cond.store1, label %else2
    254 //
    255 // cond.store1:                                      ; preds = %else
    256 //   %8 = extractelement <16 x i32> %val, i32 1
    257 //   %9 = getelementptr i32* %1, i32 1
    258 //   store i32 %8, i32* %9
    259 //   br label %else2
    260 //   . . .
    261 static void scalarizeMaskedStore(CallInst *CI) {
    262   Value *Src = CI->getArgOperand(0);
    263   Value *Ptr = CI->getArgOperand(1);
    264   Value *Alignment = CI->getArgOperand(2);
    265   Value *Mask = CI->getArgOperand(3);
    266 
    267   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
    268   VectorType *VecType = dyn_cast<VectorType>(Src->getType());
    269   assert(VecType && "Unexpected data type in masked store intrinsic");
    270 
    271   Type *EltTy = VecType->getElementType();
    272 
    273   IRBuilder<> Builder(CI->getContext());
    274   Instruction *InsertPt = CI;
    275   BasicBlock *IfBlock = CI->getParent();
    276   Builder.SetInsertPoint(InsertPt);
    277   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    278 
    279   // Short-cut if the mask is all-true.
    280   bool IsAllOnesMask =
    281       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
    282 
    283   if (IsAllOnesMask) {
    284     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    285     CI->eraseFromParent();
    286     return;
    287   }
    288 
    289   // Adjust alignment for the scalar instruction.
    290   AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
    291   // Bitcast %addr fron i8* to EltTy*
    292   Type *NewPtrType =
    293       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
    294   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
    295   unsigned VectorWidth = VecType->getNumElements();
    296 
    297   if (isa<ConstantVector>(Mask)) {
    298     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    299       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
    300         continue;
    301       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
    302       Value *Gep =
    303           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    304       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    305     }
    306     CI->eraseFromParent();
    307     return;
    308   }
    309 
    310   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    311     // Fill the "else" block, created in the previous iteration
    312     //
    313     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    314     //  %to_store = icmp eq i1 %mask_1, true
    315     //  br i1 %to_store, label %cond.store, label %else
    316     //
    317     Value *Predicate =
    318         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    319     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
    320                                     ConstantInt::get(Predicate->getType(), 1));
    321 
    322     // Create "cond" block
    323     //
    324     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    325     //  %EltAddr = getelementptr i32* %1, i32 0
    326     //  %store i32 %OneElt, i32* %EltAddr
    327     //
    328     BasicBlock *CondBlock =
    329         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    330     Builder.SetInsertPoint(InsertPt);
    331 
    332     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
    333     Value *Gep =
    334         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    335     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    336 
    337     // Create "else" block, fill it in the next iteration
    338     BasicBlock *NewIfBlock =
    339         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    340     Builder.SetInsertPoint(InsertPt);
    341     Instruction *OldBr = IfBlock->getTerminator();
    342     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    343     OldBr->eraseFromParent();
    344     IfBlock = NewIfBlock;
    345   }
    346   CI->eraseFromParent();
    347 }
    348 
    349 // Translate a masked gather intrinsic like
    350 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
    351 //                               <16 x i1> %Mask, <16 x i32> %Src)
    352 // to a chain of basic blocks, with loading element one-by-one if
    353 // the appropriate mask bit is set
    354 //
    355 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
    356 // % Mask0 = extractelement <16 x i1> %Mask, i32 0
    357 // % ToLoad0 = icmp eq i1 % Mask0, true
    358 // br i1 % ToLoad0, label %cond.load, label %else
    359 //
    360 // cond.load:
    361 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
    362 // % Load0 = load i32, i32* % Ptr0, align 4
    363 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
    364 // br label %else
    365 //
    366 // else:
    367 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
    368 // % Mask1 = extractelement <16 x i1> %Mask, i32 1
    369 // % ToLoad1 = icmp eq i1 % Mask1, true
    370 // br i1 % ToLoad1, label %cond.load1, label %else2
    371 //
    372 // cond.load1:
    373 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    374 // % Load1 = load i32, i32* % Ptr1, align 4
    375 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
    376 // br label %else2
    377 // . . .
    378 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
    379 // ret <16 x i32> %Result
    380 static void scalarizeMaskedGather(CallInst *CI) {
    381   Value *Ptrs = CI->getArgOperand(0);
    382   Value *Alignment = CI->getArgOperand(1);
    383   Value *Mask = CI->getArgOperand(2);
    384   Value *Src0 = CI->getArgOperand(3);
    385 
    386   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
    387 
    388   assert(VecType && "Unexpected return type of masked load intrinsic");
    389 
    390   IRBuilder<> Builder(CI->getContext());
    391   Instruction *InsertPt = CI;
    392   BasicBlock *IfBlock = CI->getParent();
    393   BasicBlock *CondBlock = nullptr;
    394   BasicBlock *PrevIfBlock = CI->getParent();
    395   Builder.SetInsertPoint(InsertPt);
    396   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
    397 
    398   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    399 
    400   Value *UndefVal = UndefValue::get(VecType);
    401 
    402   // The result vector
    403   Value *VResult = UndefVal;
    404   unsigned VectorWidth = VecType->getNumElements();
    405 
    406   // Shorten the way if the mask is a vector of constants.
    407   bool IsConstMask = isa<ConstantVector>(Mask);
    408 
    409   if (IsConstMask) {
    410     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    411       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
    412         continue;
    413       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
    414                                                 "Ptr" + Twine(Idx));
    415       LoadInst *Load =
    416           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    417       VResult = Builder.CreateInsertElement(
    418           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    419     }
    420     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    421     CI->replaceAllUsesWith(NewI);
    422     CI->eraseFromParent();
    423     return;
    424   }
    425 
    426   PHINode *Phi = nullptr;
    427   Value *PrevPhi = UndefVal;
    428 
    429   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    430     // Fill the "else" block, created in the previous iteration
    431     //
    432     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
    433     //  %ToLoad1 = icmp eq i1 %Mask1, true
    434     //  br i1 %ToLoad1, label %cond.load, label %else
    435     //
    436     if (Idx > 0) {
    437       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    438       Phi->addIncoming(VResult, CondBlock);
    439       Phi->addIncoming(PrevPhi, PrevIfBlock);
    440       PrevPhi = Phi;
    441       VResult = Phi;
    442     }
    443 
    444     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
    445                                                     "Mask" + Twine(Idx));
    446     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
    447                                     ConstantInt::get(Predicate->getType(), 1),
    448                                     "ToLoad" + Twine(Idx));
    449 
    450     // Create "cond" block
    451     //
    452     //  %EltAddr = getelementptr i32* %1, i32 0
    453     //  %Elt = load i32* %EltAddr
    454     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    455     //
    456     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    457     Builder.SetInsertPoint(InsertPt);
    458 
    459     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
    460                                               "Ptr" + Twine(Idx));
    461     LoadInst *Load =
    462         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    463     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
    464                                           "Res" + Twine(Idx));
    465 
    466     // Create "else" block, fill it in the next iteration
    467     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    468     Builder.SetInsertPoint(InsertPt);
    469     Instruction *OldBr = IfBlock->getTerminator();
    470     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    471     OldBr->eraseFromParent();
    472     PrevIfBlock = IfBlock;
    473     IfBlock = NewIfBlock;
    474   }
    475 
    476   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
    477   Phi->addIncoming(VResult, CondBlock);
    478   Phi->addIncoming(PrevPhi, PrevIfBlock);
    479   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
    480   CI->replaceAllUsesWith(NewI);
    481   CI->eraseFromParent();
    482 }
    483 
    484 // Translate a masked scatter intrinsic, like
    485 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
    486 //                                  <16 x i1> %Mask)
    487 // to a chain of basic blocks, that stores element one-by-one if
    488 // the appropriate mask bit is set.
    489 //
    490 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
    491 // % Mask0 = extractelement <16 x i1> % Mask, i32 0
    492 // % ToStore0 = icmp eq i1 % Mask0, true
    493 // br i1 %ToStore0, label %cond.store, label %else
    494 //
    495 // cond.store:
    496 // % Elt0 = extractelement <16 x i32> %Src, i32 0
    497 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
    498 // store i32 %Elt0, i32* % Ptr0, align 4
    499 // br label %else
    500 //
    501 // else:
    502 // % Mask1 = extractelement <16 x i1> % Mask, i32 1
    503 // % ToStore1 = icmp eq i1 % Mask1, true
    504 // br i1 % ToStore1, label %cond.store1, label %else2
    505 //
    506 // cond.store1:
    507 // % Elt1 = extractelement <16 x i32> %Src, i32 1
    508 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    509 // store i32 % Elt1, i32* % Ptr1, align 4
    510 // br label %else2
    511 //   . . .
    512 static void scalarizeMaskedScatter(CallInst *CI) {
    513   Value *Src = CI->getArgOperand(0);
    514   Value *Ptrs = CI->getArgOperand(1);
    515   Value *Alignment = CI->getArgOperand(2);
    516   Value *Mask = CI->getArgOperand(3);
    517 
    518   assert(isa<VectorType>(Src->getType()) &&
    519          "Unexpected data type in masked scatter intrinsic");
    520   assert(isa<VectorType>(Ptrs->getType()) &&
    521          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
    522          "Vector of pointers is expected in masked scatter intrinsic");
    523 
    524   IRBuilder<> Builder(CI->getContext());
    525   Instruction *InsertPt = CI;
    526   BasicBlock *IfBlock = CI->getParent();
    527   Builder.SetInsertPoint(InsertPt);
    528   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
    529 
    530   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
    531   unsigned VectorWidth = Src->getType()->getVectorNumElements();
    532 
    533   // Shorten the way if the mask is a vector of constants.
    534   bool IsConstMask = isa<ConstantVector>(Mask);
    535 
    536   if (IsConstMask) {
    537     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    538       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
    539         continue;
    540       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
    541                                                    "Elt" + Twine(Idx));
    542       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
    543                                                 "Ptr" + Twine(Idx));
    544       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    545     }
    546     CI->eraseFromParent();
    547     return;
    548   }
    549   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    550     // Fill the "else" block, created in the previous iteration
    551     //
    552     //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    553     //  % ToStore = icmp eq i1 % Mask1, true
    554     //  br i1 % ToStore, label %cond.store, label %else
    555     //
    556     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
    557                                                     "Mask" + Twine(Idx));
    558     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
    559                                     ConstantInt::get(Predicate->getType(), 1),
    560                                     "ToStore" + Twine(Idx));
    561 
    562     // Create "cond" block
    563     //
    564     //  % Elt1 = extractelement <16 x i32> %Src, i32 1
    565     //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    566     //  %store i32 % Elt1, i32* % Ptr1
    567     //
    568     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    569     Builder.SetInsertPoint(InsertPt);
    570 
    571     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
    572                                                  "Elt" + Twine(Idx));
    573     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
    574                                               "Ptr" + Twine(Idx));
    575     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    576 
    577     // Create "else" block, fill it in the next iteration
    578     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    579     Builder.SetInsertPoint(InsertPt);
    580     Instruction *OldBr = IfBlock->getTerminator();
    581     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    582     OldBr->eraseFromParent();
    583     IfBlock = NewIfBlock;
    584   }
    585   CI->eraseFromParent();
    586 }
    587 
    588 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
    589   bool EverMadeChange = false;
    590 
    591   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    592 
    593   bool MadeChange = true;
    594   while (MadeChange) {
    595     MadeChange = false;
    596     for (Function::iterator I = F.begin(); I != F.end();) {
    597       BasicBlock *BB = &*I++;
    598       bool ModifiedDTOnIteration = false;
    599       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
    600 
    601       // Restart BB iteration if the dominator tree of the Function was changed
    602       if (ModifiedDTOnIteration)
    603         break;
    604     }
    605 
    606     EverMadeChange |= MadeChange;
    607   }
    608 
    609   return EverMadeChange;
    610 }
    611 
    612 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
    613   bool MadeChange = false;
    614 
    615   BasicBlock::iterator CurInstIterator = BB.begin();
    616   while (CurInstIterator != BB.end()) {
    617     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
    618       MadeChange |= optimizeCallInst(CI, ModifiedDT);
    619     if (ModifiedDT)
    620       return true;
    621   }
    622 
    623   return MadeChange;
    624 }
    625 
    626 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
    627                                                 bool &ModifiedDT) {
    628   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
    629   if (II) {
    630     switch (II->getIntrinsicID()) {
    631     default:
    632       break;
    633     case Intrinsic::masked_load:
    634       // Scalarize unsupported vector masked load
    635       if (!TTI->isLegalMaskedLoad(CI->getType())) {
    636         scalarizeMaskedLoad(CI);
    637         ModifiedDT = true;
    638         return true;
    639       }
    640       return false;
    641     case Intrinsic::masked_store:
    642       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
    643         scalarizeMaskedStore(CI);
    644         ModifiedDT = true;
    645         return true;
    646       }
    647       return false;
    648     case Intrinsic::masked_gather:
    649       if (!TTI->isLegalMaskedGather(CI->getType())) {
    650         scalarizeMaskedGather(CI);
    651         ModifiedDT = true;
    652         return true;
    653       }
    654       return false;
    655     case Intrinsic::masked_scatter:
    656       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
    657         scalarizeMaskedScatter(CI);
    658         ModifiedDT = true;
    659         return true;
    660       }
    661       return false;
    662     }
    663   }
    664 
    665   return false;
    666 }
    667