Home | History | Annotate | Download | only in InstCombine
      1 //===- InstCombineShifts.cpp ----------------------------------------------===//
      2 //
      3 //                     The LLVM Compiler Infrastructure
      4 //
      5 // This file is distributed under the University of Illinois Open Source
      6 // License. See LICENSE.TXT for details.
      7 //
      8 //===----------------------------------------------------------------------===//
      9 //
     10 // This file implements the visitShl, visitLShr, and visitAShr functions.
     11 //
     12 //===----------------------------------------------------------------------===//
     13 
     14 #include "InstCombineInternal.h"
     15 #include "llvm/Analysis/ConstantFolding.h"
     16 #include "llvm/Analysis/InstructionSimplify.h"
     17 #include "llvm/IR/IntrinsicInst.h"
     18 #include "llvm/IR/PatternMatch.h"
     19 using namespace llvm;
     20 using namespace PatternMatch;
     21 
     22 #define DEBUG_TYPE "instcombine"
     23 
     24 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
     25   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
     26   assert(Op0->getType() == Op1->getType());
     27 
     28   // See if we can fold away this shift.
     29   if (SimplifyDemandedInstructionBits(I))
     30     return &I;
     31 
     32   // Try to fold constant and into select arguments.
     33   if (isa<Constant>(Op0))
     34     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
     35       if (Instruction *R = FoldOpIntoSelect(I, SI))
     36         return R;
     37 
     38   if (Constant *CUI = dyn_cast<Constant>(Op1))
     39     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
     40       return Res;
     41 
     42   // (C1 shift (A add C2)) -> (C1 shift C2) shift A)
     43   // iff A and C2 are both positive.
     44   Value *A;
     45   Constant *C;
     46   if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C))))
     47     if (isKnownNonNegative(A, DL, 0, &AC, &I, &DT) &&
     48         isKnownNonNegative(C, DL, 0, &AC, &I, &DT))
     49       return BinaryOperator::Create(
     50           I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), Op0, C), A);
     51 
     52   // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2.
     53   // Because shifts by negative values (which could occur if A were negative)
     54   // are undefined.
     55   const APInt *B;
     56   if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) {
     57     // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
     58     // demand the sign bit (and many others) here??
     59     Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1),
     60                                    Op1->getName());
     61     I.setOperand(1, Rem);
     62     return &I;
     63   }
     64 
     65   return nullptr;
     66 }
     67 
     68 /// Return true if we can simplify two logical (either left or right) shifts
     69 /// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
     70 static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
     71                                     Instruction *InnerShift, InstCombiner &IC,
     72                                     Instruction *CxtI) {
     73   assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
     74 
     75   // We need constant scalar or constant splat shifts.
     76   const APInt *InnerShiftConst;
     77   if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
     78     return false;
     79 
     80   // Two logical shifts in the same direction:
     81   // shl (shl X, C1), C2 -->  shl X, C1 + C2
     82   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
     83   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
     84   if (IsInnerShl == IsOuterShl)
     85     return true;
     86 
     87   // Equal shift amounts in opposite directions become bitwise 'and':
     88   // lshr (shl X, C), C --> and X, C'
     89   // shl (lshr X, C), C --> and X, C'
     90   if (*InnerShiftConst == OuterShAmt)
     91     return true;
     92 
     93   // If the 2nd shift is bigger than the 1st, we can fold:
     94   // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3
     95   // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
     96   // but it isn't profitable unless we know the and'd out bits are already zero.
     97   // Also, check that the inner shift is valid (less than the type width) or
     98   // we'll crash trying to produce the bit mask for the 'and'.
     99   unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
    100   if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
    101     unsigned InnerShAmt = InnerShiftConst->getZExtValue();
    102     unsigned MaskShift =
    103         IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
    104     APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
    105     if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
    106       return true;
    107   }
    108 
    109   return false;
    110 }
    111 
    112 /// See if we can compute the specified value, but shifted logically to the left
    113 /// or right by some number of bits. This should return true if the expression
    114 /// can be computed for the same cost as the current expression tree. This is
    115 /// used to eliminate extraneous shifting from things like:
    116 ///      %C = shl i128 %A, 64
    117 ///      %D = shl i128 %B, 96
    118 ///      %E = or i128 %C, %D
    119 ///      %F = lshr i128 %E, 64
    120 /// where the client will ask if E can be computed shifted right by 64-bits. If
    121 /// this succeeds, getShiftedValue() will be called to produce the value.
    122 static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
    123                                InstCombiner &IC, Instruction *CxtI) {
    124   // We can always evaluate constants shifted.
    125   if (isa<Constant>(V))
    126     return true;
    127 
    128   Instruction *I = dyn_cast<Instruction>(V);
    129   if (!I) return false;
    130 
    131   // If this is the opposite shift, we can directly reuse the input of the shift
    132   // if the needed bits are already zero in the input.  This allows us to reuse
    133   // the value which means that we don't care if the shift has multiple uses.
    134   //  TODO:  Handle opposite shift by exact value.
    135   ConstantInt *CI = nullptr;
    136   if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
    137       (!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
    138     if (CI->getValue() == NumBits) {
    139       // TODO: Check that the input bits are already zero with MaskedValueIsZero
    140 #if 0
    141       // If this is a truncate of a logical shr, we can truncate it to a smaller
    142       // lshr iff we know that the bits we would otherwise be shifting in are
    143       // already zeros.
    144       uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
    145       uint32_t BitWidth = Ty->getScalarSizeInBits();
    146       if (MaskedValueIsZero(I->getOperand(0),
    147             APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) &&
    148           CI->getLimitedValue(BitWidth) < BitWidth) {
    149         return CanEvaluateTruncated(I->getOperand(0), Ty);
    150       }
    151 #endif
    152 
    153     }
    154   }
    155 
    156   // We can't mutate something that has multiple uses: doing so would
    157   // require duplicating the instruction in general, which isn't profitable.
    158   if (!I->hasOneUse()) return false;
    159 
    160   switch (I->getOpcode()) {
    161   default: return false;
    162   case Instruction::And:
    163   case Instruction::Or:
    164   case Instruction::Xor:
    165     // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
    166     return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
    167            canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
    168 
    169   case Instruction::Shl:
    170   case Instruction::LShr:
    171     return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
    172 
    173   case Instruction::Select: {
    174     SelectInst *SI = cast<SelectInst>(I);
    175     Value *TrueVal = SI->getTrueValue();
    176     Value *FalseVal = SI->getFalseValue();
    177     return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
    178            canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
    179   }
    180   case Instruction::PHI: {
    181     // We can change a phi if we can change all operands.  Note that we never
    182     // get into trouble with cyclic PHIs here because we only consider
    183     // instructions with a single use.
    184     PHINode *PN = cast<PHINode>(I);
    185     for (Value *IncValue : PN->incoming_values())
    186       if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
    187         return false;
    188     return true;
    189   }
    190   }
    191 }
    192 
    193 /// Fold OuterShift (InnerShift X, C1), C2.
    194 /// See canEvaluateShiftedShift() for the constraints on these instructions.
    195 static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
    196                                bool IsOuterShl,
    197                                InstCombiner::BuilderTy &Builder) {
    198   bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
    199   Type *ShType = InnerShift->getType();
    200   unsigned TypeWidth = ShType->getScalarSizeInBits();
    201 
    202   // We only accept shifts-by-a-constant in canEvaluateShifted().
    203   const APInt *C1;
    204   match(InnerShift->getOperand(1), m_APInt(C1));
    205   unsigned InnerShAmt = C1->getZExtValue();
    206 
    207   // Change the shift amount and clear the appropriate IR flags.
    208   auto NewInnerShift = [&](unsigned ShAmt) {
    209     InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
    210     if (IsInnerShl) {
    211       InnerShift->setHasNoUnsignedWrap(false);
    212       InnerShift->setHasNoSignedWrap(false);
    213     } else {
    214       InnerShift->setIsExact(false);
    215     }
    216     return InnerShift;
    217   };
    218 
    219   // Two logical shifts in the same direction:
    220   // shl (shl X, C1), C2 -->  shl X, C1 + C2
    221   // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
    222   if (IsInnerShl == IsOuterShl) {
    223     // If this is an oversized composite shift, then unsigned shifts get 0.
    224     if (InnerShAmt + OuterShAmt >= TypeWidth)
    225       return Constant::getNullValue(ShType);
    226 
    227     return NewInnerShift(InnerShAmt + OuterShAmt);
    228   }
    229 
    230   // Equal shift amounts in opposite directions become bitwise 'and':
    231   // lshr (shl X, C), C --> and X, C'
    232   // shl (lshr X, C), C --> and X, C'
    233   if (InnerShAmt == OuterShAmt) {
    234     APInt Mask = IsInnerShl
    235                      ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
    236                      : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
    237     Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
    238                                    ConstantInt::get(ShType, Mask));
    239     if (auto *AndI = dyn_cast<Instruction>(And)) {
    240       AndI->moveBefore(InnerShift);
    241       AndI->takeName(InnerShift);
    242     }
    243     return And;
    244   }
    245 
    246   assert(InnerShAmt > OuterShAmt &&
    247          "Unexpected opposite direction logical shift pair");
    248 
    249   // In general, we would need an 'and' for this transform, but
    250   // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
    251   // lshr (shl X, C1), C2 -->  shl X, C1 - C2
    252   // shl (lshr X, C1), C2 --> lshr X, C1 - C2
    253   return NewInnerShift(InnerShAmt - OuterShAmt);
    254 }
    255 
    256 /// When canEvaluateShifted() returns true for an expression, this function
    257 /// inserts the new computation that produces the shifted value.
    258 static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
    259                               InstCombiner &IC, const DataLayout &DL) {
    260   // We can always evaluate constants shifted.
    261   if (Constant *C = dyn_cast<Constant>(V)) {
    262     if (isLeftShift)
    263       V = IC.Builder.CreateShl(C, NumBits);
    264     else
    265       V = IC.Builder.CreateLShr(C, NumBits);
    266     // If we got a constantexpr back, try to simplify it with TD info.
    267     if (auto *C = dyn_cast<Constant>(V))
    268       if (auto *FoldedC =
    269               ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo()))
    270         V = FoldedC;
    271     return V;
    272   }
    273 
    274   Instruction *I = cast<Instruction>(V);
    275   IC.Worklist.Add(I);
    276 
    277   switch (I->getOpcode()) {
    278   default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
    279   case Instruction::And:
    280   case Instruction::Or:
    281   case Instruction::Xor:
    282     // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
    283     I->setOperand(
    284         0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
    285     I->setOperand(
    286         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
    287     return I;
    288 
    289   case Instruction::Shl:
    290   case Instruction::LShr:
    291     return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
    292                             IC.Builder);
    293 
    294   case Instruction::Select:
    295     I->setOperand(
    296         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
    297     I->setOperand(
    298         2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
    299     return I;
    300   case Instruction::PHI: {
    301     // We can change a phi if we can change all operands.  Note that we never
    302     // get into trouble with cyclic PHIs here because we only consider
    303     // instructions with a single use.
    304     PHINode *PN = cast<PHINode>(I);
    305     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
    306       PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
    307                                               isLeftShift, IC, DL));
    308     return PN;
    309   }
    310   }
    311 }
    312 
    313 // If this is a bitwise operator or add with a constant RHS we might be able
    314 // to pull it through a shift.
    315 static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
    316                                          BinaryOperator *BO,
    317                                          const APInt &C) {
    318   bool IsValid = true;     // Valid only for And, Or Xor,
    319   bool HighBitSet = false; // Transform ifhigh bit of constant set?
    320 
    321   switch (BO->getOpcode()) {
    322   default: IsValid = false; break;   // Do not perform transform!
    323   case Instruction::Add:
    324     IsValid = Shift.getOpcode() == Instruction::Shl;
    325     break;
    326   case Instruction::Or:
    327   case Instruction::Xor:
    328     HighBitSet = false;
    329     break;
    330   case Instruction::And:
    331     HighBitSet = true;
    332     break;
    333   }
    334 
    335   // If this is a signed shift right, and the high bit is modified
    336   // by the logical operation, do not perform the transformation.
    337   // The HighBitSet boolean indicates the value of the high bit of
    338   // the constant which would cause it to be modified for this
    339   // operation.
    340   //
    341   if (IsValid && Shift.getOpcode() == Instruction::AShr)
    342     IsValid = C.isNegative() == HighBitSet;
    343 
    344   return IsValid;
    345 }
    346 
    347 Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
    348                                                BinaryOperator &I) {
    349   bool isLeftShift = I.getOpcode() == Instruction::Shl;
    350 
    351   const APInt *Op1C;
    352   if (!match(Op1, m_APInt(Op1C)))
    353     return nullptr;
    354 
    355   // See if we can propagate this shift into the input, this covers the trivial
    356   // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
    357   if (I.getOpcode() != Instruction::AShr &&
    358       canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
    359     LLVM_DEBUG(
    360         dbgs() << "ICE: GetShiftedValue propagating shift through expression"
    361                   " to eliminate shift:\n  IN: "
    362                << *Op0 << "\n  SH: " << I << "\n");
    363 
    364     return replaceInstUsesWith(
    365         I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
    366   }
    367 
    368   // See if we can simplify any instructions used by the instruction whose sole
    369   // purpose is to compute bits we don't care about.
    370   unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
    371 
    372   assert(!Op1C->uge(TypeBits) &&
    373          "Shift over the type width should have been removed already");
    374 
    375   if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
    376     return FoldedShift;
    377 
    378   // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2))
    379   if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) {
    380     Instruction *TrOp = dyn_cast<Instruction>(TI->getOperand(0));
    381     // If 'shift2' is an ashr, we would have to get the sign bit into a funny
    382     // place.  Don't try to do this transformation in this case.  Also, we
    383     // require that the input operand is a shift-by-constant so that we have
    384     // confidence that the shifts will get folded together.  We could do this
    385     // xform in more cases, but it is unlikely to be profitable.
    386     if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
    387         isa<ConstantInt>(TrOp->getOperand(1))) {
    388       // Okay, we'll do this xform.  Make the shift of shift.
    389       Constant *ShAmt =
    390           ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
    391       // (shift2 (shift1 & 0x00FF), c2)
    392       Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
    393 
    394       // For logical shifts, the truncation has the effect of making the high
    395       // part of the register be zeros.  Emulate this by inserting an AND to
    396       // clear the top bits as needed.  This 'and' will usually be zapped by
    397       // other xforms later if dead.
    398       unsigned SrcSize = TrOp->getType()->getScalarSizeInBits();
    399       unsigned DstSize = TI->getType()->getScalarSizeInBits();
    400       APInt MaskV(APInt::getLowBitsSet(SrcSize, DstSize));
    401 
    402       // The mask we constructed says what the trunc would do if occurring
    403       // between the shifts.  We want to know the effect *after* the second
    404       // shift.  We know that it is a logical shift by a constant, so adjust the
    405       // mask as appropriate.
    406       if (I.getOpcode() == Instruction::Shl)
    407         MaskV <<= Op1C->getZExtValue();
    408       else {
    409         assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
    410         MaskV.lshrInPlace(Op1C->getZExtValue());
    411       }
    412 
    413       // shift1 & 0x00FF
    414       Value *And = Builder.CreateAnd(NSh,
    415                                      ConstantInt::get(I.getContext(), MaskV),
    416                                      TI->getName());
    417 
    418       // Return the value truncated to the interesting size.
    419       return new TruncInst(And, I.getType());
    420     }
    421   }
    422 
    423   if (Op0->hasOneUse()) {
    424     if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
    425       // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
    426       Value *V1, *V2;
    427       ConstantInt *CC;
    428       switch (Op0BO->getOpcode()) {
    429       default: break;
    430       case Instruction::Add:
    431       case Instruction::And:
    432       case Instruction::Or:
    433       case Instruction::Xor: {
    434         // These operators commute.
    435         // Turn (Y + (X >> C)) << C  ->  (X + (Y << C)) & (~0 << C)
    436         if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
    437             match(Op0BO->getOperand(1), m_Shr(m_Value(V1),
    438                   m_Specific(Op1)))) {
    439           Value *YS =         // (Y << C)
    440             Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
    441           // (X + (Y << C))
    442           Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1,
    443                                          Op0BO->getOperand(1)->getName());
    444           unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
    445 
    446           APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
    447           Constant *Mask = ConstantInt::get(I.getContext(), Bits);
    448           if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
    449             Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
    450           return BinaryOperator::CreateAnd(X, Mask);
    451         }
    452 
    453         // Turn (Y + ((X >> C) & CC)) << C  ->  ((X & (CC << C)) + (Y << C))
    454         Value *Op0BOOp1 = Op0BO->getOperand(1);
    455         if (isLeftShift && Op0BOOp1->hasOneUse() &&
    456             match(Op0BOOp1,
    457                   m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
    458                         m_ConstantInt(CC)))) {
    459           Value *YS =   // (Y << C)
    460             Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
    461           // X & (CC << C)
    462           Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
    463                                         V1->getName()+".mask");
    464           return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM);
    465         }
    466         LLVM_FALLTHROUGH;
    467       }
    468 
    469       case Instruction::Sub: {
    470         // Turn ((X >> C) + Y) << C  ->  (X + (Y << C)) & (~0 << C)
    471         if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
    472             match(Op0BO->getOperand(0), m_Shr(m_Value(V1),
    473                   m_Specific(Op1)))) {
    474           Value *YS =  // (Y << C)
    475             Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
    476           // (X + (Y << C))
    477           Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS,
    478                                          Op0BO->getOperand(0)->getName());
    479           unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
    480 
    481           APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
    482           Constant *Mask = ConstantInt::get(I.getContext(), Bits);
    483           if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
    484             Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
    485           return BinaryOperator::CreateAnd(X, Mask);
    486         }
    487 
    488         // Turn (((X >> C)&CC) + Y) << C  ->  (X + (Y << C)) & (CC << C)
    489         if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
    490             match(Op0BO->getOperand(0),
    491                   m_And(m_OneUse(m_Shr(m_Value(V1), m_Value(V2))),
    492                         m_ConstantInt(CC))) && V2 == Op1) {
    493           Value *YS = // (Y << C)
    494             Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
    495           // X & (CC << C)
    496           Value *XM = Builder.CreateAnd(V1, ConstantExpr::getShl(CC, Op1),
    497                                         V1->getName()+".mask");
    498 
    499           return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS);
    500         }
    501 
    502         break;
    503       }
    504       }
    505 
    506 
    507       // If the operand is a bitwise operator with a constant RHS, and the
    508       // shift is the only use, we can pull it out of the shift.
    509       const APInt *Op0C;
    510       if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
    511         if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) {
    512           Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
    513                                      cast<Constant>(Op0BO->getOperand(1)), Op1);
    514 
    515           Value *NewShift =
    516             Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1);
    517           NewShift->takeName(Op0BO);
    518 
    519           return BinaryOperator::Create(Op0BO->getOpcode(), NewShift,
    520                                         NewRHS);
    521         }
    522       }
    523 
    524       // If the operand is a subtract with a constant LHS, and the shift
    525       // is the only use, we can pull it out of the shift.
    526       // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
    527       if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
    528           match(Op0BO->getOperand(0), m_APInt(Op0C))) {
    529         Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
    530                                    cast<Constant>(Op0BO->getOperand(0)), Op1);
    531 
    532         Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1);
    533         NewShift->takeName(Op0BO);
    534 
    535         return BinaryOperator::CreateSub(NewRHS, NewShift);
    536       }
    537     }
    538 
    539     // If we have a select that conditionally executes some binary operator,
    540     // see if we can pull it the select and operator through the shift.
    541     //
    542     // For example, turning:
    543     //   shl (select C, (add X, C1), X), C2
    544     // Into:
    545     //   Y = shl X, C2
    546     //   select C, (add Y, C1 << C2), Y
    547     Value *Cond;
    548     BinaryOperator *TBO;
    549     Value *FalseVal;
    550     if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)),
    551                             m_Value(FalseVal)))) {
    552       const APInt *C;
    553       if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
    554           match(TBO->getOperand(1), m_APInt(C)) &&
    555           canShiftBinOpWithConstantRHS(I, TBO, *C)) {
    556         Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
    557                                        cast<Constant>(TBO->getOperand(1)), Op1);
    558 
    559         Value *NewShift =
    560           Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1);
    561         Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift,
    562                                            NewRHS);
    563         return SelectInst::Create(Cond, NewOp, NewShift);
    564       }
    565     }
    566 
    567     BinaryOperator *FBO;
    568     Value *TrueVal;
    569     if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal),
    570                             m_OneUse(m_BinOp(FBO))))) {
    571       const APInt *C;
    572       if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
    573           match(FBO->getOperand(1), m_APInt(C)) &&
    574           canShiftBinOpWithConstantRHS(I, FBO, *C)) {
    575         Constant *NewRHS = ConstantExpr::get(I.getOpcode(),
    576                                        cast<Constant>(FBO->getOperand(1)), Op1);
    577 
    578         Value *NewShift =
    579           Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1);
    580         Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift,
    581                                            NewRHS);
    582         return SelectInst::Create(Cond, NewShift, NewOp);
    583       }
    584     }
    585   }
    586 
    587   return nullptr;
    588 }
    589 
    590 Instruction *InstCombiner::visitShl(BinaryOperator &I) {
    591   if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
    592                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
    593                                  SQ.getWithInstruction(&I)))
    594     return replaceInstUsesWith(I, V);
    595 
    596   if (Instruction *X = foldShuffledBinop(I))
    597     return X;
    598 
    599   if (Instruction *V = commonShiftTransforms(I))
    600     return V;
    601 
    602   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
    603   Type *Ty = I.getType();
    604   const APInt *ShAmtAPInt;
    605   if (match(Op1, m_APInt(ShAmtAPInt))) {
    606     unsigned ShAmt = ShAmtAPInt->getZExtValue();
    607     unsigned BitWidth = Ty->getScalarSizeInBits();
    608 
    609     // shl (zext X), ShAmt --> zext (shl X, ShAmt)
    610     // This is only valid if X would have zeros shifted out.
    611     Value *X;
    612     if (match(Op0, m_ZExt(m_Value(X)))) {
    613       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
    614       if (ShAmt < SrcWidth &&
    615           MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
    616         return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty);
    617     }
    618 
    619     // (X >> C) << C --> X & (-1 << C)
    620     if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
    621       APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
    622       return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    623     }
    624 
    625     // FIXME: we do not yet transform non-exact shr's. The backend (DAGCombine)
    626     // needs a few fixes for the rotate pattern recognition first.
    627     const APInt *ShOp1;
    628     if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1))))) {
    629       unsigned ShrAmt = ShOp1->getZExtValue();
    630       if (ShrAmt < ShAmt) {
    631         // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
    632         Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
    633         auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
    634         NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
    635         NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
    636         return NewShl;
    637       }
    638       if (ShrAmt > ShAmt) {
    639         // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
    640         Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
    641         auto *NewShr = BinaryOperator::Create(
    642             cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
    643         NewShr->setIsExact(true);
    644         return NewShr;
    645       }
    646     }
    647 
    648     if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
    649       unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
    650       // Oversized shifts are simplified to zero in InstSimplify.
    651       if (AmtSum < BitWidth)
    652         // (X << C1) << C2 --> X << (C1 + C2)
    653         return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
    654     }
    655 
    656     // If the shifted-out value is known-zero, then this is a NUW shift.
    657     if (!I.hasNoUnsignedWrap() &&
    658         MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
    659       I.setHasNoUnsignedWrap();
    660       return &I;
    661     }
    662 
    663     // If the shifted-out value is all signbits, then this is a NSW shift.
    664     if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
    665       I.setHasNoSignedWrap();
    666       return &I;
    667     }
    668   }
    669 
    670   // Transform  (x >> y) << y  to  x & (-1 << y)
    671   // Valid for any type of right-shift.
    672   Value *X;
    673   if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
    674     Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    675     Value *Mask = Builder.CreateShl(AllOnes, Op1);
    676     return BinaryOperator::CreateAnd(Mask, X);
    677   }
    678 
    679   Constant *C1;
    680   if (match(Op1, m_Constant(C1))) {
    681     Constant *C2;
    682     Value *X;
    683     // (C2 << X) << C1 --> (C2 << C1) << X
    684     if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
    685       return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);
    686 
    687     // (X * C2) << C1 --> X * (C2 << C1)
    688     if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
    689       return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
    690   }
    691 
    692   return nullptr;
    693 }
    694 
    695 Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
    696   if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
    697                                   SQ.getWithInstruction(&I)))
    698     return replaceInstUsesWith(I, V);
    699 
    700   if (Instruction *X = foldShuffledBinop(I))
    701     return X;
    702 
    703   if (Instruction *R = commonShiftTransforms(I))
    704     return R;
    705 
    706   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
    707   Type *Ty = I.getType();
    708   const APInt *ShAmtAPInt;
    709   if (match(Op1, m_APInt(ShAmtAPInt))) {
    710     unsigned ShAmt = ShAmtAPInt->getZExtValue();
    711     unsigned BitWidth = Ty->getScalarSizeInBits();
    712     auto *II = dyn_cast<IntrinsicInst>(Op0);
    713     if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
    714         (II->getIntrinsicID() == Intrinsic::ctlz ||
    715          II->getIntrinsicID() == Intrinsic::cttz ||
    716          II->getIntrinsicID() == Intrinsic::ctpop)) {
    717       // ctlz.i32(x)>>5  --> zext(x == 0)
    718       // cttz.i32(x)>>5  --> zext(x == 0)
    719       // ctpop.i32(x)>>5 --> zext(x == -1)
    720       bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
    721       Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
    722       Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
    723       return new ZExtInst(Cmp, Ty);
    724     }
    725 
    726     Value *X;
    727     const APInt *ShOp1;
    728     if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
    729       unsigned ShlAmt = ShOp1->getZExtValue();
    730       if (ShlAmt < ShAmt) {
    731         Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
    732         if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
    733           // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
    734           auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
    735           NewLShr->setIsExact(I.isExact());
    736           return NewLShr;
    737         }
    738         // (X << C1) >>u C2  --> (X >>u (C2 - C1)) & (-1 >> C2)
    739         Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
    740         APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
    741         return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
    742       }
    743       if (ShlAmt > ShAmt) {
    744         Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
    745         if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
    746           // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
    747           auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
    748           NewShl->setHasNoUnsignedWrap(true);
    749           return NewShl;
    750         }
    751         // (X << C1) >>u C2  --> X << (C1 - C2) & (-1 >> C2)
    752         Value *NewShl = Builder.CreateShl(X, ShiftDiff);
    753         APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
    754         return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
    755       }
    756       assert(ShlAmt == ShAmt);
    757       // (X << C) >>u C --> X & (-1 >>u C)
    758       APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
    759       return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
    760     }
    761 
    762     if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
    763         (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
    764       assert(ShAmt < X->getType()->getScalarSizeInBits() &&
    765              "Big shift not simplified to zero?");
    766       // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
    767       Value *NewLShr = Builder.CreateLShr(X, ShAmt);
    768       return new ZExtInst(NewLShr, Ty);
    769     }
    770 
    771     if (match(Op0, m_SExt(m_Value(X))) &&
    772         (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
    773       // Are we moving the sign bit to the low bit and widening with high zeros?
    774       unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
    775       if (ShAmt == BitWidth - 1) {
    776         // lshr (sext i1 X to iN), N-1 --> zext X to iN
    777         if (SrcTyBitWidth == 1)
    778           return new ZExtInst(X, Ty);
    779 
    780         // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
    781         if (Op0->hasOneUse()) {
    782           Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
    783           return new ZExtInst(NewLShr, Ty);
    784         }
    785       }
    786 
    787       // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
    788       if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) {
    789         // The new shift amount can't be more than the narrow source type.
    790         unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1);
    791         Value *AShr = Builder.CreateAShr(X, NewShAmt);
    792         return new ZExtInst(AShr, Ty);
    793       }
    794     }
    795 
    796     if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
    797       unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
    798       // Oversized shifts are simplified to zero in InstSimplify.
    799       if (AmtSum < BitWidth)
    800         // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
    801         return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
    802     }
    803 
    804     // If the shifted-out value is known-zero, then this is an exact shift.
    805     if (!I.isExact() &&
    806         MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
    807       I.setIsExact();
    808       return &I;
    809     }
    810   }
    811 
    812   // Transform  (x << y) >> y  to  x & (-1 >> y)
    813   Value *X;
    814   if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
    815     Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
    816     Value *Mask = Builder.CreateLShr(AllOnes, Op1);
    817     return BinaryOperator::CreateAnd(Mask, X);
    818   }
    819 
    820   return nullptr;
    821 }
    822 
    823 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
    824   if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
    825                                   SQ.getWithInstruction(&I)))
    826     return replaceInstUsesWith(I, V);
    827 
    828   if (Instruction *X = foldShuffledBinop(I))
    829     return X;
    830 
    831   if (Instruction *R = commonShiftTransforms(I))
    832     return R;
    833 
    834   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
    835   Type *Ty = I.getType();
    836   unsigned BitWidth = Ty->getScalarSizeInBits();
    837   const APInt *ShAmtAPInt;
    838   if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
    839     unsigned ShAmt = ShAmtAPInt->getZExtValue();
    840 
    841     // If the shift amount equals the difference in width of the destination
    842     // and source scalar types:
    843     // ashr (shl (zext X), C), C --> sext X
    844     Value *X;
    845     if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
    846         ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
    847       return new SExtInst(X, Ty);
    848 
    849     // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
    850     // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
    851     const APInt *ShOp1;
    852     if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
    853         ShOp1->ult(BitWidth)) {
    854       unsigned ShlAmt = ShOp1->getZExtValue();
    855       if (ShlAmt < ShAmt) {
    856         // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
    857         Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
    858         auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
    859         NewAShr->setIsExact(I.isExact());
    860         return NewAShr;
    861       }
    862       if (ShlAmt > ShAmt) {
    863         // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
    864         Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
    865         auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
    866         NewShl->setHasNoSignedWrap(true);
    867         return NewShl;
    868       }
    869     }
    870 
    871     if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
    872         ShOp1->ult(BitWidth)) {
    873       unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
    874       // Oversized arithmetic shifts replicate the sign bit.
    875       AmtSum = std::min(AmtSum, BitWidth - 1);
    876       // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
    877       return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
    878     }
    879 
    880     if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
    881         (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
    882       // ashr (sext X), C --> sext (ashr X, C')
    883       Type *SrcTy = X->getType();
    884       ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
    885       Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
    886       return new SExtInst(NewSh, Ty);
    887     }
    888 
    889     // If the shifted-out value is known-zero, then this is an exact shift.
    890     if (!I.isExact() &&
    891         MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
    892       I.setIsExact();
    893       return &I;
    894     }
    895   }
    896 
    897   // See if we can turn a signed shr into an unsigned shr.
    898   if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
    899     return BinaryOperator::CreateLShr(Op0, Op1);
    900 
    901   return nullptr;
    902 }
    903