1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// 2 // instrinsics 3 // 4 // The LLVM Compiler Infrastructure 5 // 6 // This file is distributed under the University of Illinois Open Source 7 // License. See LICENSE.TXT for details. 8 // 9 //===----------------------------------------------------------------------===// 10 // 11 // This pass replaces masked memory intrinsics - when unsupported by the target 12 // - with a chain of basic blocks, that deal with the elements one-by-one if the 13 // appropriate mask bit is set. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/CodeGen/TargetSubtargetInfo.h" 20 #include "llvm/IR/BasicBlock.h" 21 #include "llvm/IR/Constant.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/DerivedTypes.h" 24 #include "llvm/IR/Function.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/InstrTypes.h" 27 #include "llvm/IR/Instruction.h" 28 #include "llvm/IR/Instructions.h" 29 #include "llvm/IR/IntrinsicInst.h" 30 #include "llvm/IR/Intrinsics.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/IR/Value.h" 33 #include "llvm/Pass.h" 34 #include "llvm/Support/Casting.h" 35 #include <algorithm> 36 #include <cassert> 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "scalarize-masked-mem-intrin" 41 42 namespace { 43 44 class ScalarizeMaskedMemIntrin : public FunctionPass { 45 const TargetTransformInfo *TTI = nullptr; 46 47 public: 48 static char ID; // Pass identification, replacement for typeid 49 50 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) { 51 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry()); 52 } 53 54 bool runOnFunction(Function &F) override; 55 56 StringRef getPassName() const override { 57 return "Scalarize Masked Memory Intrinsics"; 58 } 59 60 void getAnalysisUsage(AnalysisUsage &AU) const override { 61 AU.addRequired<TargetTransformInfoWrapperPass>(); 62 } 63 64 private: 65 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT); 66 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT); 67 }; 68 69 } // end anonymous namespace 70 71 char ScalarizeMaskedMemIntrin::ID = 0; 72 73 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE, 74 "Scalarize unsupported masked memory intrinsics", false, false) 75 76 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() { 77 return new ScalarizeMaskedMemIntrin(); 78 } 79 80 // Translate a masked load intrinsic like 81 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, 82 // <16 x i1> %mask, <16 x i32> %passthru) 83 // to a chain of basic blocks, with loading element one-by-one if 84 // the appropriate mask bit is set 85 // 86 // %1 = bitcast i8* %addr to i32* 87 // %2 = extractelement <16 x i1> %mask, i32 0 88 // %3 = icmp eq i1 %2, true 89 // br i1 %3, label %cond.load, label %else 90 // 91 // cond.load: ; preds = %0 92 // %4 = getelementptr i32* %1, i32 0 93 // %5 = load i32* %4 94 // %6 = insertelement <16 x i32> undef, i32 %5, i32 0 95 // br label %else 96 // 97 // else: ; preds = %0, %cond.load 98 // %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ] 99 // %7 = extractelement <16 x i1> %mask, i32 1 100 // %8 = icmp eq i1 %7, true 101 // br i1 %8, label %cond.load1, label %else2 102 // 103 // cond.load1: ; preds = %else 104 // %9 = getelementptr i32* %1, i32 1 105 // %10 = load i32* %9 106 // %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1 107 // br label %else2 108 // 109 // else2: ; preds = %else, %cond.load1 110 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 111 // %12 = extractelement <16 x i1> %mask, i32 2 112 // %13 = icmp eq i1 %12, true 113 // br i1 %13, label %cond.load4, label %else5 114 // 115 static void scalarizeMaskedLoad(CallInst *CI) { 116 Value *Ptr = CI->getArgOperand(0); 117 Value *Alignment = CI->getArgOperand(1); 118 Value *Mask = CI->getArgOperand(2); 119 Value *Src0 = CI->getArgOperand(3); 120 121 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 122 VectorType *VecType = dyn_cast<VectorType>(CI->getType()); 123 assert(VecType && "Unexpected return type of masked load intrinsic"); 124 125 Type *EltTy = CI->getType()->getVectorElementType(); 126 127 IRBuilder<> Builder(CI->getContext()); 128 Instruction *InsertPt = CI; 129 BasicBlock *IfBlock = CI->getParent(); 130 BasicBlock *CondBlock = nullptr; 131 BasicBlock *PrevIfBlock = CI->getParent(); 132 133 Builder.SetInsertPoint(InsertPt); 134 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 135 136 // Short-cut if the mask is all-true. 137 bool IsAllOnesMask = 138 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue(); 139 140 if (IsAllOnesMask) { 141 Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal); 142 CI->replaceAllUsesWith(NewI); 143 CI->eraseFromParent(); 144 return; 145 } 146 147 // Adjust alignment for the scalar instruction. 148 AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8); 149 // Bitcast %addr fron i8* to EltTy* 150 Type *NewPtrType = 151 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace()); 152 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 153 unsigned VectorWidth = VecType->getNumElements(); 154 155 Value *UndefVal = UndefValue::get(VecType); 156 157 // The result vector 158 Value *VResult = UndefVal; 159 160 if (isa<ConstantVector>(Mask)) { 161 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 162 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 163 continue; 164 Value *Gep = 165 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 166 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal); 167 VResult = 168 Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx)); 169 } 170 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0); 171 CI->replaceAllUsesWith(NewI); 172 CI->eraseFromParent(); 173 return; 174 } 175 176 PHINode *Phi = nullptr; 177 Value *PrevPhi = UndefVal; 178 179 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 180 // Fill the "else" block, created in the previous iteration 181 // 182 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 183 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 184 // %to_load = icmp eq i1 %mask_1, true 185 // br i1 %to_load, label %cond.load, label %else 186 // 187 if (Idx > 0) { 188 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 189 Phi->addIncoming(VResult, CondBlock); 190 Phi->addIncoming(PrevPhi, PrevIfBlock); 191 PrevPhi = Phi; 192 VResult = Phi; 193 } 194 195 Value *Predicate = 196 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx)); 197 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 198 ConstantInt::get(Predicate->getType(), 1)); 199 200 // Create "cond" block 201 // 202 // %EltAddr = getelementptr i32* %1, i32 0 203 // %Elt = load i32* %EltAddr 204 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 205 // 206 CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load"); 207 Builder.SetInsertPoint(InsertPt); 208 209 Value *Gep = 210 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 211 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal); 212 VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx)); 213 214 // Create "else" block, fill it in the next iteration 215 BasicBlock *NewIfBlock = 216 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 217 Builder.SetInsertPoint(InsertPt); 218 Instruction *OldBr = IfBlock->getTerminator(); 219 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 220 OldBr->eraseFromParent(); 221 PrevIfBlock = IfBlock; 222 IfBlock = NewIfBlock; 223 } 224 225 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select"); 226 Phi->addIncoming(VResult, CondBlock); 227 Phi->addIncoming(PrevPhi, PrevIfBlock); 228 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0); 229 CI->replaceAllUsesWith(NewI); 230 CI->eraseFromParent(); 231 } 232 233 // Translate a masked store intrinsic, like 234 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, 235 // <16 x i1> %mask) 236 // to a chain of basic blocks, that stores element one-by-one if 237 // the appropriate mask bit is set 238 // 239 // %1 = bitcast i8* %addr to i32* 240 // %2 = extractelement <16 x i1> %mask, i32 0 241 // %3 = icmp eq i1 %2, true 242 // br i1 %3, label %cond.store, label %else 243 // 244 // cond.store: ; preds = %0 245 // %4 = extractelement <16 x i32> %val, i32 0 246 // %5 = getelementptr i32* %1, i32 0 247 // store i32 %4, i32* %5 248 // br label %else 249 // 250 // else: ; preds = %0, %cond.store 251 // %6 = extractelement <16 x i1> %mask, i32 1 252 // %7 = icmp eq i1 %6, true 253 // br i1 %7, label %cond.store1, label %else2 254 // 255 // cond.store1: ; preds = %else 256 // %8 = extractelement <16 x i32> %val, i32 1 257 // %9 = getelementptr i32* %1, i32 1 258 // store i32 %8, i32* %9 259 // br label %else2 260 // . . . 261 static void scalarizeMaskedStore(CallInst *CI) { 262 Value *Src = CI->getArgOperand(0); 263 Value *Ptr = CI->getArgOperand(1); 264 Value *Alignment = CI->getArgOperand(2); 265 Value *Mask = CI->getArgOperand(3); 266 267 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 268 VectorType *VecType = dyn_cast<VectorType>(Src->getType()); 269 assert(VecType && "Unexpected data type in masked store intrinsic"); 270 271 Type *EltTy = VecType->getElementType(); 272 273 IRBuilder<> Builder(CI->getContext()); 274 Instruction *InsertPt = CI; 275 BasicBlock *IfBlock = CI->getParent(); 276 Builder.SetInsertPoint(InsertPt); 277 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 278 279 // Short-cut if the mask is all-true. 280 bool IsAllOnesMask = 281 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue(); 282 283 if (IsAllOnesMask) { 284 Builder.CreateAlignedStore(Src, Ptr, AlignVal); 285 CI->eraseFromParent(); 286 return; 287 } 288 289 // Adjust alignment for the scalar instruction. 290 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8); 291 // Bitcast %addr fron i8* to EltTy* 292 Type *NewPtrType = 293 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace()); 294 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 295 unsigned VectorWidth = VecType->getNumElements(); 296 297 if (isa<ConstantVector>(Mask)) { 298 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 299 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 300 continue; 301 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); 302 Value *Gep = 303 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 304 Builder.CreateAlignedStore(OneElt, Gep, AlignVal); 305 } 306 CI->eraseFromParent(); 307 return; 308 } 309 310 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 311 // Fill the "else" block, created in the previous iteration 312 // 313 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 314 // %to_store = icmp eq i1 %mask_1, true 315 // br i1 %to_store, label %cond.store, label %else 316 // 317 Value *Predicate = 318 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx)); 319 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 320 ConstantInt::get(Predicate->getType(), 1)); 321 322 // Create "cond" block 323 // 324 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 325 // %EltAddr = getelementptr i32* %1, i32 0 326 // %store i32 %OneElt, i32* %EltAddr 327 // 328 BasicBlock *CondBlock = 329 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); 330 Builder.SetInsertPoint(InsertPt); 331 332 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); 333 Value *Gep = 334 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 335 Builder.CreateAlignedStore(OneElt, Gep, AlignVal); 336 337 // Create "else" block, fill it in the next iteration 338 BasicBlock *NewIfBlock = 339 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 340 Builder.SetInsertPoint(InsertPt); 341 Instruction *OldBr = IfBlock->getTerminator(); 342 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 343 OldBr->eraseFromParent(); 344 IfBlock = NewIfBlock; 345 } 346 CI->eraseFromParent(); 347 } 348 349 // Translate a masked gather intrinsic like 350 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, 351 // <16 x i1> %Mask, <16 x i32> %Src) 352 // to a chain of basic blocks, with loading element one-by-one if 353 // the appropriate mask bit is set 354 // 355 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind 356 // % Mask0 = extractelement <16 x i1> %Mask, i32 0 357 // % ToLoad0 = icmp eq i1 % Mask0, true 358 // br i1 % ToLoad0, label %cond.load, label %else 359 // 360 // cond.load: 361 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 362 // % Load0 = load i32, i32* % Ptr0, align 4 363 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0 364 // br label %else 365 // 366 // else: 367 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0] 368 // % Mask1 = extractelement <16 x i1> %Mask, i32 1 369 // % ToLoad1 = icmp eq i1 % Mask1, true 370 // br i1 % ToLoad1, label %cond.load1, label %else2 371 // 372 // cond.load1: 373 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 374 // % Load1 = load i32, i32* % Ptr1, align 4 375 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1 376 // br label %else2 377 // . . . 378 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src 379 // ret <16 x i32> %Result 380 static void scalarizeMaskedGather(CallInst *CI) { 381 Value *Ptrs = CI->getArgOperand(0); 382 Value *Alignment = CI->getArgOperand(1); 383 Value *Mask = CI->getArgOperand(2); 384 Value *Src0 = CI->getArgOperand(3); 385 386 VectorType *VecType = dyn_cast<VectorType>(CI->getType()); 387 388 assert(VecType && "Unexpected return type of masked load intrinsic"); 389 390 IRBuilder<> Builder(CI->getContext()); 391 Instruction *InsertPt = CI; 392 BasicBlock *IfBlock = CI->getParent(); 393 BasicBlock *CondBlock = nullptr; 394 BasicBlock *PrevIfBlock = CI->getParent(); 395 Builder.SetInsertPoint(InsertPt); 396 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 397 398 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 399 400 Value *UndefVal = UndefValue::get(VecType); 401 402 // The result vector 403 Value *VResult = UndefVal; 404 unsigned VectorWidth = VecType->getNumElements(); 405 406 // Shorten the way if the mask is a vector of constants. 407 bool IsConstMask = isa<ConstantVector>(Mask); 408 409 if (IsConstMask) { 410 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 411 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 412 continue; 413 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 414 "Ptr" + Twine(Idx)); 415 LoadInst *Load = 416 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); 417 VResult = Builder.CreateInsertElement( 418 VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx)); 419 } 420 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0); 421 CI->replaceAllUsesWith(NewI); 422 CI->eraseFromParent(); 423 return; 424 } 425 426 PHINode *Phi = nullptr; 427 Value *PrevPhi = UndefVal; 428 429 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 430 // Fill the "else" block, created in the previous iteration 431 // 432 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 433 // %ToLoad1 = icmp eq i1 %Mask1, true 434 // br i1 %ToLoad1, label %cond.load, label %else 435 // 436 if (Idx > 0) { 437 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 438 Phi->addIncoming(VResult, CondBlock); 439 Phi->addIncoming(PrevPhi, PrevIfBlock); 440 PrevPhi = Phi; 441 VResult = Phi; 442 } 443 444 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx), 445 "Mask" + Twine(Idx)); 446 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 447 ConstantInt::get(Predicate->getType(), 1), 448 "ToLoad" + Twine(Idx)); 449 450 // Create "cond" block 451 // 452 // %EltAddr = getelementptr i32* %1, i32 0 453 // %Elt = load i32* %EltAddr 454 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 455 // 456 CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); 457 Builder.SetInsertPoint(InsertPt); 458 459 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 460 "Ptr" + Twine(Idx)); 461 LoadInst *Load = 462 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); 463 VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx), 464 "Res" + Twine(Idx)); 465 466 // Create "else" block, fill it in the next iteration 467 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); 468 Builder.SetInsertPoint(InsertPt); 469 Instruction *OldBr = IfBlock->getTerminator(); 470 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 471 OldBr->eraseFromParent(); 472 PrevIfBlock = IfBlock; 473 IfBlock = NewIfBlock; 474 } 475 476 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select"); 477 Phi->addIncoming(VResult, CondBlock); 478 Phi->addIncoming(PrevPhi, PrevIfBlock); 479 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0); 480 CI->replaceAllUsesWith(NewI); 481 CI->eraseFromParent(); 482 } 483 484 // Translate a masked scatter intrinsic, like 485 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4, 486 // <16 x i1> %Mask) 487 // to a chain of basic blocks, that stores element one-by-one if 488 // the appropriate mask bit is set. 489 // 490 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind 491 // % Mask0 = extractelement <16 x i1> % Mask, i32 0 492 // % ToStore0 = icmp eq i1 % Mask0, true 493 // br i1 %ToStore0, label %cond.store, label %else 494 // 495 // cond.store: 496 // % Elt0 = extractelement <16 x i32> %Src, i32 0 497 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 498 // store i32 %Elt0, i32* % Ptr0, align 4 499 // br label %else 500 // 501 // else: 502 // % Mask1 = extractelement <16 x i1> % Mask, i32 1 503 // % ToStore1 = icmp eq i1 % Mask1, true 504 // br i1 % ToStore1, label %cond.store1, label %else2 505 // 506 // cond.store1: 507 // % Elt1 = extractelement <16 x i32> %Src, i32 1 508 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 509 // store i32 % Elt1, i32* % Ptr1, align 4 510 // br label %else2 511 // . . . 512 static void scalarizeMaskedScatter(CallInst *CI) { 513 Value *Src = CI->getArgOperand(0); 514 Value *Ptrs = CI->getArgOperand(1); 515 Value *Alignment = CI->getArgOperand(2); 516 Value *Mask = CI->getArgOperand(3); 517 518 assert(isa<VectorType>(Src->getType()) && 519 "Unexpected data type in masked scatter intrinsic"); 520 assert(isa<VectorType>(Ptrs->getType()) && 521 isa<PointerType>(Ptrs->getType()->getVectorElementType()) && 522 "Vector of pointers is expected in masked scatter intrinsic"); 523 524 IRBuilder<> Builder(CI->getContext()); 525 Instruction *InsertPt = CI; 526 BasicBlock *IfBlock = CI->getParent(); 527 Builder.SetInsertPoint(InsertPt); 528 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 529 530 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 531 unsigned VectorWidth = Src->getType()->getVectorNumElements(); 532 533 // Shorten the way if the mask is a vector of constants. 534 bool IsConstMask = isa<ConstantVector>(Mask); 535 536 if (IsConstMask) { 537 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 538 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 539 continue; 540 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx), 541 "Elt" + Twine(Idx)); 542 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 543 "Ptr" + Twine(Idx)); 544 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 545 } 546 CI->eraseFromParent(); 547 return; 548 } 549 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 550 // Fill the "else" block, created in the previous iteration 551 // 552 // % Mask1 = extractelement <16 x i1> % Mask, i32 Idx 553 // % ToStore = icmp eq i1 % Mask1, true 554 // br i1 % ToStore, label %cond.store, label %else 555 // 556 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx), 557 "Mask" + Twine(Idx)); 558 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 559 ConstantInt::get(Predicate->getType(), 1), 560 "ToStore" + Twine(Idx)); 561 562 // Create "cond" block 563 // 564 // % Elt1 = extractelement <16 x i32> %Src, i32 1 565 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 566 // %store i32 % Elt1, i32* % Ptr1 567 // 568 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store"); 569 Builder.SetInsertPoint(InsertPt); 570 571 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx), 572 "Elt" + Twine(Idx)); 573 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 574 "Ptr" + Twine(Idx)); 575 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 576 577 // Create "else" block, fill it in the next iteration 578 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); 579 Builder.SetInsertPoint(InsertPt); 580 Instruction *OldBr = IfBlock->getTerminator(); 581 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 582 OldBr->eraseFromParent(); 583 IfBlock = NewIfBlock; 584 } 585 CI->eraseFromParent(); 586 } 587 588 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) { 589 bool EverMadeChange = false; 590 591 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 592 593 bool MadeChange = true; 594 while (MadeChange) { 595 MadeChange = false; 596 for (Function::iterator I = F.begin(); I != F.end();) { 597 BasicBlock *BB = &*I++; 598 bool ModifiedDTOnIteration = false; 599 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration); 600 601 // Restart BB iteration if the dominator tree of the Function was changed 602 if (ModifiedDTOnIteration) 603 break; 604 } 605 606 EverMadeChange |= MadeChange; 607 } 608 609 return EverMadeChange; 610 } 611 612 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) { 613 bool MadeChange = false; 614 615 BasicBlock::iterator CurInstIterator = BB.begin(); 616 while (CurInstIterator != BB.end()) { 617 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++)) 618 MadeChange |= optimizeCallInst(CI, ModifiedDT); 619 if (ModifiedDT) 620 return true; 621 } 622 623 return MadeChange; 624 } 625 626 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI, 627 bool &ModifiedDT) { 628 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 629 if (II) { 630 switch (II->getIntrinsicID()) { 631 default: 632 break; 633 case Intrinsic::masked_load: 634 // Scalarize unsupported vector masked load 635 if (!TTI->isLegalMaskedLoad(CI->getType())) { 636 scalarizeMaskedLoad(CI); 637 ModifiedDT = true; 638 return true; 639 } 640 return false; 641 case Intrinsic::masked_store: 642 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) { 643 scalarizeMaskedStore(CI); 644 ModifiedDT = true; 645 return true; 646 } 647 return false; 648 case Intrinsic::masked_gather: 649 if (!TTI->isLegalMaskedGather(CI->getType())) { 650 scalarizeMaskedGather(CI); 651 ModifiedDT = true; 652 return true; 653 } 654 return false; 655 case Intrinsic::masked_scatter: 656 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) { 657 scalarizeMaskedScatter(CI); 658 ModifiedDT = true; 659 return true; 660 } 661 return false; 662 } 663 } 664 665 return false; 666 } 667