1 //===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file defines vectorizer utilities. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/EquivalenceClasses.h" 15 #include "llvm/Analysis/DemandedBits.h" 16 #include "llvm/Analysis/LoopInfo.h" 17 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 18 #include "llvm/Analysis/ScalarEvolution.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/VectorUtils.h" 21 #include "llvm/IR/GetElementPtrTypeIterator.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/IR/Value.h" 24 #include "llvm/IR/Constants.h" 25 26 using namespace llvm; 27 using namespace llvm::PatternMatch; 28 29 /// \brief Identify if the intrinsic is trivially vectorizable. 30 /// This method returns true if the intrinsic's argument types are all 31 /// scalars for the scalar form of the intrinsic and all vectors for 32 /// the vector form of the intrinsic. 33 bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { 34 switch (ID) { 35 case Intrinsic::sqrt: 36 case Intrinsic::sin: 37 case Intrinsic::cos: 38 case Intrinsic::exp: 39 case Intrinsic::exp2: 40 case Intrinsic::log: 41 case Intrinsic::log10: 42 case Intrinsic::log2: 43 case Intrinsic::fabs: 44 case Intrinsic::minnum: 45 case Intrinsic::maxnum: 46 case Intrinsic::copysign: 47 case Intrinsic::floor: 48 case Intrinsic::ceil: 49 case Intrinsic::trunc: 50 case Intrinsic::rint: 51 case Intrinsic::nearbyint: 52 case Intrinsic::round: 53 case Intrinsic::bswap: 54 case Intrinsic::ctpop: 55 case Intrinsic::pow: 56 case Intrinsic::fma: 57 case Intrinsic::fmuladd: 58 case Intrinsic::ctlz: 59 case Intrinsic::cttz: 60 case Intrinsic::powi: 61 return true; 62 default: 63 return false; 64 } 65 } 66 67 /// \brief Identifies if the intrinsic has a scalar operand. It check for 68 /// ctlz,cttz and powi special intrinsics whose argument is scalar. 69 bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, 70 unsigned ScalarOpdIdx) { 71 switch (ID) { 72 case Intrinsic::ctlz: 73 case Intrinsic::cttz: 74 case Intrinsic::powi: 75 return (ScalarOpdIdx == 1); 76 default: 77 return false; 78 } 79 } 80 81 /// \brief Check call has a unary float signature 82 /// It checks following: 83 /// a) call should have a single argument 84 /// b) argument type should be floating point type 85 /// c) call instruction type and argument type should be same 86 /// d) call should only reads memory. 87 /// If all these condition is met then return ValidIntrinsicID 88 /// else return not_intrinsic. 89 Intrinsic::ID 90 llvm::checkUnaryFloatSignature(const CallInst &I, 91 Intrinsic::ID ValidIntrinsicID) { 92 if (I.getNumArgOperands() != 1 || 93 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 94 I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory()) 95 return Intrinsic::not_intrinsic; 96 97 return ValidIntrinsicID; 98 } 99 100 /// \brief Check call has a binary float signature 101 /// It checks following: 102 /// a) call should have 2 arguments. 103 /// b) arguments type should be floating point type 104 /// c) call instruction type and arguments type should be same 105 /// d) call should only reads memory. 106 /// If all these condition is met then return ValidIntrinsicID 107 /// else return not_intrinsic. 108 Intrinsic::ID 109 llvm::checkBinaryFloatSignature(const CallInst &I, 110 Intrinsic::ID ValidIntrinsicID) { 111 if (I.getNumArgOperands() != 2 || 112 !I.getArgOperand(0)->getType()->isFloatingPointTy() || 113 !I.getArgOperand(1)->getType()->isFloatingPointTy() || 114 I.getType() != I.getArgOperand(0)->getType() || 115 I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory()) 116 return Intrinsic::not_intrinsic; 117 118 return ValidIntrinsicID; 119 } 120 121 /// \brief Returns intrinsic ID for call. 122 /// For the input call instruction it finds mapping intrinsic and returns 123 /// its ID, in case it does not found it return not_intrinsic. 124 Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI, 125 const TargetLibraryInfo *TLI) { 126 // If we have an intrinsic call, check if it is trivially vectorizable. 127 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { 128 Intrinsic::ID ID = II->getIntrinsicID(); 129 if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || 130 ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) 131 return ID; 132 return Intrinsic::not_intrinsic; 133 } 134 135 if (!TLI) 136 return Intrinsic::not_intrinsic; 137 138 LibFunc::Func Func; 139 Function *F = CI->getCalledFunction(); 140 // We're going to make assumptions on the semantics of the functions, check 141 // that the target knows that it's available in this environment and it does 142 // not have local linkage. 143 if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) 144 return Intrinsic::not_intrinsic; 145 146 // Otherwise check if we have a call to a function that can be turned into a 147 // vector intrinsic. 148 switch (Func) { 149 default: 150 break; 151 case LibFunc::sin: 152 case LibFunc::sinf: 153 case LibFunc::sinl: 154 return checkUnaryFloatSignature(*CI, Intrinsic::sin); 155 case LibFunc::cos: 156 case LibFunc::cosf: 157 case LibFunc::cosl: 158 return checkUnaryFloatSignature(*CI, Intrinsic::cos); 159 case LibFunc::exp: 160 case LibFunc::expf: 161 case LibFunc::expl: 162 return checkUnaryFloatSignature(*CI, Intrinsic::exp); 163 case LibFunc::exp2: 164 case LibFunc::exp2f: 165 case LibFunc::exp2l: 166 return checkUnaryFloatSignature(*CI, Intrinsic::exp2); 167 case LibFunc::log: 168 case LibFunc::logf: 169 case LibFunc::logl: 170 return checkUnaryFloatSignature(*CI, Intrinsic::log); 171 case LibFunc::log10: 172 case LibFunc::log10f: 173 case LibFunc::log10l: 174 return checkUnaryFloatSignature(*CI, Intrinsic::log10); 175 case LibFunc::log2: 176 case LibFunc::log2f: 177 case LibFunc::log2l: 178 return checkUnaryFloatSignature(*CI, Intrinsic::log2); 179 case LibFunc::fabs: 180 case LibFunc::fabsf: 181 case LibFunc::fabsl: 182 return checkUnaryFloatSignature(*CI, Intrinsic::fabs); 183 case LibFunc::fmin: 184 case LibFunc::fminf: 185 case LibFunc::fminl: 186 return checkBinaryFloatSignature(*CI, Intrinsic::minnum); 187 case LibFunc::fmax: 188 case LibFunc::fmaxf: 189 case LibFunc::fmaxl: 190 return checkBinaryFloatSignature(*CI, Intrinsic::maxnum); 191 case LibFunc::copysign: 192 case LibFunc::copysignf: 193 case LibFunc::copysignl: 194 return checkBinaryFloatSignature(*CI, Intrinsic::copysign); 195 case LibFunc::floor: 196 case LibFunc::floorf: 197 case LibFunc::floorl: 198 return checkUnaryFloatSignature(*CI, Intrinsic::floor); 199 case LibFunc::ceil: 200 case LibFunc::ceilf: 201 case LibFunc::ceill: 202 return checkUnaryFloatSignature(*CI, Intrinsic::ceil); 203 case LibFunc::trunc: 204 case LibFunc::truncf: 205 case LibFunc::truncl: 206 return checkUnaryFloatSignature(*CI, Intrinsic::trunc); 207 case LibFunc::rint: 208 case LibFunc::rintf: 209 case LibFunc::rintl: 210 return checkUnaryFloatSignature(*CI, Intrinsic::rint); 211 case LibFunc::nearbyint: 212 case LibFunc::nearbyintf: 213 case LibFunc::nearbyintl: 214 return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); 215 case LibFunc::round: 216 case LibFunc::roundf: 217 case LibFunc::roundl: 218 return checkUnaryFloatSignature(*CI, Intrinsic::round); 219 case LibFunc::pow: 220 case LibFunc::powf: 221 case LibFunc::powl: 222 return checkBinaryFloatSignature(*CI, Intrinsic::pow); 223 } 224 225 return Intrinsic::not_intrinsic; 226 } 227 228 /// \brief Find the operand of the GEP that should be checked for consecutive 229 /// stores. This ignores trailing indices that have no effect on the final 230 /// pointer. 231 unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { 232 const DataLayout &DL = Gep->getModule()->getDataLayout(); 233 unsigned LastOperand = Gep->getNumOperands() - 1; 234 unsigned GEPAllocSize = DL.getTypeAllocSize( 235 cast<PointerType>(Gep->getType()->getScalarType())->getElementType()); 236 237 // Walk backwards and try to peel off zeros. 238 while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { 239 // Find the type we're currently indexing into. 240 gep_type_iterator GEPTI = gep_type_begin(Gep); 241 std::advance(GEPTI, LastOperand - 1); 242 243 // If it's a type with the same allocation size as the result of the GEP we 244 // can peel off the zero index. 245 if (DL.getTypeAllocSize(*GEPTI) != GEPAllocSize) 246 break; 247 --LastOperand; 248 } 249 250 return LastOperand; 251 } 252 253 /// \brief If the argument is a GEP, then returns the operand identified by 254 /// getGEPInductionOperand. However, if there is some other non-loop-invariant 255 /// operand, it returns that instead. 256 Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 257 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 258 if (!GEP) 259 return Ptr; 260 261 unsigned InductionOperand = getGEPInductionOperand(GEP); 262 263 // Check that all of the gep indices are uniform except for our induction 264 // operand. 265 for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) 266 if (i != InductionOperand && 267 !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp)) 268 return Ptr; 269 return GEP->getOperand(InductionOperand); 270 } 271 272 /// \brief If a value has only one user that is a CastInst, return it. 273 Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) { 274 Value *UniqueCast = nullptr; 275 for (User *U : Ptr->users()) { 276 CastInst *CI = dyn_cast<CastInst>(U); 277 if (CI && CI->getType() == Ty) { 278 if (!UniqueCast) 279 UniqueCast = CI; 280 else 281 return nullptr; 282 } 283 } 284 return UniqueCast; 285 } 286 287 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic 288 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise. 289 Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { 290 auto *PtrTy = dyn_cast<PointerType>(Ptr->getType()); 291 if (!PtrTy || PtrTy->isAggregateType()) 292 return nullptr; 293 294 // Try to remove a gep instruction to make the pointer (actually index at this 295 // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the 296 // pointer, otherwise, we are analyzing the index. 297 Value *OrigPtr = Ptr; 298 299 // The size of the pointer access. 300 int64_t PtrAccessSize = 1; 301 302 Ptr = stripGetElementPtr(Ptr, SE, Lp); 303 const SCEV *V = SE->getSCEV(Ptr); 304 305 if (Ptr != OrigPtr) 306 // Strip off casts. 307 while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) 308 V = C->getOperand(); 309 310 const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); 311 if (!S) 312 return nullptr; 313 314 V = S->getStepRecurrence(*SE); 315 if (!V) 316 return nullptr; 317 318 // Strip off the size of access multiplication if we are still analyzing the 319 // pointer. 320 if (OrigPtr == Ptr) { 321 const DataLayout &DL = Lp->getHeader()->getModule()->getDataLayout(); 322 DL.getTypeAllocSize(PtrTy->getElementType()); 323 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { 324 if (M->getOperand(0)->getSCEVType() != scConstant) 325 return nullptr; 326 327 const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); 328 329 // Huge step value - give up. 330 if (APStepVal.getBitWidth() > 64) 331 return nullptr; 332 333 int64_t StepVal = APStepVal.getSExtValue(); 334 if (PtrAccessSize != StepVal) 335 return nullptr; 336 V = M->getOperand(1); 337 } 338 } 339 340 // Strip off casts. 341 Type *StripedOffRecurrenceCast = nullptr; 342 if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { 343 StripedOffRecurrenceCast = C->getType(); 344 V = C->getOperand(); 345 } 346 347 // Look for the loop invariant symbolic value. 348 const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); 349 if (!U) 350 return nullptr; 351 352 Value *Stride = U->getValue(); 353 if (!Lp->isLoopInvariant(Stride)) 354 return nullptr; 355 356 // If we have stripped off the recurrence cast we have to make sure that we 357 // return the value that is used in this loop so that we can replace it later. 358 if (StripedOffRecurrenceCast) 359 Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); 360 361 return Stride; 362 } 363 364 /// \brief Given a vector and an element number, see if the scalar value is 365 /// already around as a register, for example if it were inserted then extracted 366 /// from the vector. 367 Value *llvm::findScalarElement(Value *V, unsigned EltNo) { 368 assert(V->getType()->isVectorTy() && "Not looking at a vector?"); 369 VectorType *VTy = cast<VectorType>(V->getType()); 370 unsigned Width = VTy->getNumElements(); 371 if (EltNo >= Width) // Out of range access. 372 return UndefValue::get(VTy->getElementType()); 373 374 if (Constant *C = dyn_cast<Constant>(V)) 375 return C->getAggregateElement(EltNo); 376 377 if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { 378 // If this is an insert to a variable element, we don't know what it is. 379 if (!isa<ConstantInt>(III->getOperand(2))) 380 return nullptr; 381 unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); 382 383 // If this is an insert to the element we are looking for, return the 384 // inserted value. 385 if (EltNo == IIElt) 386 return III->getOperand(1); 387 388 // Otherwise, the insertelement doesn't modify the value, recurse on its 389 // vector input. 390 return findScalarElement(III->getOperand(0), EltNo); 391 } 392 393 if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { 394 unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); 395 int InEl = SVI->getMaskValue(EltNo); 396 if (InEl < 0) 397 return UndefValue::get(VTy->getElementType()); 398 if (InEl < (int)LHSWidth) 399 return findScalarElement(SVI->getOperand(0), InEl); 400 return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); 401 } 402 403 // Extract a value from a vector add operation with a constant zero. 404 Value *Val = nullptr; Constant *Con = nullptr; 405 if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) 406 if (Constant *Elt = Con->getAggregateElement(EltNo)) 407 if (Elt->isNullValue()) 408 return findScalarElement(Val, EltNo); 409 410 // Otherwise, we don't know. 411 return nullptr; 412 } 413 414 /// \brief Get splat value if the input is a splat vector or return nullptr. 415 /// This function is not fully general. It checks only 2 cases: 416 /// the input value is (1) a splat constants vector or (2) a sequence 417 /// of instructions that broadcast a single value into a vector. 418 /// 419 const llvm::Value *llvm::getSplatValue(const Value *V) { 420 421 if (auto *C = dyn_cast<Constant>(V)) 422 if (isa<VectorType>(V->getType())) 423 return C->getSplatValue(); 424 425 auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V); 426 if (!ShuffleInst) 427 return nullptr; 428 // All-zero (or undef) shuffle mask elements. 429 for (int MaskElt : ShuffleInst->getShuffleMask()) 430 if (MaskElt != 0 && MaskElt != -1) 431 return nullptr; 432 // The first shuffle source is 'insertelement' with index 0. 433 auto *InsertEltInst = 434 dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); 435 if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || 436 !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue()) 437 return nullptr; 438 439 return InsertEltInst->getOperand(1); 440 } 441 442 MapVector<Instruction *, uint64_t> 443 llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, 444 const TargetTransformInfo *TTI) { 445 446 // DemandedBits will give us every value's live-out bits. But we want 447 // to ensure no extra casts would need to be inserted, so every DAG 448 // of connected values must have the same minimum bitwidth. 449 EquivalenceClasses<Value *> ECs; 450 SmallVector<Value *, 16> Worklist; 451 SmallPtrSet<Value *, 4> Roots; 452 SmallPtrSet<Value *, 16> Visited; 453 DenseMap<Value *, uint64_t> DBits; 454 SmallPtrSet<Instruction *, 4> InstructionSet; 455 MapVector<Instruction *, uint64_t> MinBWs; 456 457 // Determine the roots. We work bottom-up, from truncs or icmps. 458 bool SeenExtFromIllegalType = false; 459 for (auto *BB : Blocks) 460 for (auto &I : *BB) { 461 InstructionSet.insert(&I); 462 463 if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && 464 !TTI->isTypeLegal(I.getOperand(0)->getType())) 465 SeenExtFromIllegalType = true; 466 467 // Only deal with non-vector integers up to 64-bits wide. 468 if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && 469 !I.getType()->isVectorTy() && 470 I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { 471 // Don't make work for ourselves. If we know the loaded type is legal, 472 // don't add it to the worklist. 473 if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) 474 continue; 475 476 Worklist.push_back(&I); 477 Roots.insert(&I); 478 } 479 } 480 // Early exit. 481 if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) 482 return MinBWs; 483 484 // Now proceed breadth-first, unioning values together. 485 while (!Worklist.empty()) { 486 Value *Val = Worklist.pop_back_val(); 487 Value *Leader = ECs.getOrInsertLeaderValue(Val); 488 489 if (Visited.count(Val)) 490 continue; 491 Visited.insert(Val); 492 493 // Non-instructions terminate a chain successfully. 494 if (!isa<Instruction>(Val)) 495 continue; 496 Instruction *I = cast<Instruction>(Val); 497 498 // If we encounter a type that is larger than 64 bits, we can't represent 499 // it so bail out. 500 if (DB.getDemandedBits(I).getBitWidth() > 64) 501 return MapVector<Instruction *, uint64_t>(); 502 503 uint64_t V = DB.getDemandedBits(I).getZExtValue(); 504 DBits[Leader] |= V; 505 506 // Casts, loads and instructions outside of our range terminate a chain 507 // successfully. 508 if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || 509 !InstructionSet.count(I)) 510 continue; 511 512 // Unsafe casts terminate a chain unsuccessfully. We can't do anything 513 // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to 514 // transform anything that relies on them. 515 if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || 516 !I->getType()->isIntegerTy()) { 517 DBits[Leader] |= ~0ULL; 518 continue; 519 } 520 521 // We don't modify the types of PHIs. Reductions will already have been 522 // truncated if possible, and inductions' sizes will have been chosen by 523 // indvars. 524 if (isa<PHINode>(I)) 525 continue; 526 527 if (DBits[Leader] == ~0ULL) 528 // All bits demanded, no point continuing. 529 continue; 530 531 for (Value *O : cast<User>(I)->operands()) { 532 ECs.unionSets(Leader, O); 533 Worklist.push_back(O); 534 } 535 } 536 537 // Now we've discovered all values, walk them to see if there are 538 // any users we didn't see. If there are, we can't optimize that 539 // chain. 540 for (auto &I : DBits) 541 for (auto *U : I.first->users()) 542 if (U->getType()->isIntegerTy() && DBits.count(U) == 0) 543 DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; 544 545 for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { 546 uint64_t LeaderDemandedBits = 0; 547 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) 548 LeaderDemandedBits |= DBits[*MI]; 549 550 uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - 551 llvm::countLeadingZeros(LeaderDemandedBits); 552 // Round up to a power of 2 553 if (!isPowerOf2_64((uint64_t)MinBW)) 554 MinBW = NextPowerOf2(MinBW); 555 for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { 556 if (!isa<Instruction>(*MI)) 557 continue; 558 Type *Ty = (*MI)->getType(); 559 if (Roots.count(*MI)) 560 Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); 561 if (MinBW < Ty->getScalarSizeInBits()) 562 MinBWs[cast<Instruction>(*MI)] = MinBW; 563 } 564 } 565 566 return MinBWs; 567 } 568