diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
commit | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch) | |
tree | f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | |
parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) | |
download | src-c0981da47d5696fe36474fcf86b4ce03ae3ff818.tar.gz src-c0981da47d5696fe36474fcf86b4ce03ae3ff818.zip |
Vendor import of llvm-project main llvmorg-14-init-10186-gff7f2cfa959b (branch: vendor/llvm-project/llvmorg-14-init-10186-gff7f2cfa959b)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 549 |
1 file changed, 272 insertions, 277 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index ca5e473fdecb..06421d553915 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -41,7 +41,7 @@ bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1, (Sh0->getType()->getScalarSizeInBits() - 1) + (Sh1->getType()->getScalarSizeInBits() - 1); APInt MaximalRepresentableShiftAmount = - APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits()); return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount); } @@ -172,8 +172,8 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( // There are many variants to this pattern: // a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt // b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt -// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt -// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt +// c) (x & (-1 l>> MaskShAmt)) << ShiftShAmt +// d) (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt // e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt // f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt // All these patterns can be simplified to just: @@ -213,11 +213,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); // (~(-1 << maskNbits)) auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); - // (-1 >> MaskShAmt) - auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt)); - // ((-1 << MaskShAmt) >> MaskShAmt) + // (-1 l>> MaskShAmt) + auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt)); + // ((-1 << MaskShAmt) l>> MaskShAmt) auto MaskD = - m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); + m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); Value *X; Constant *NewMask; @@ 
-240,7 +240,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // that shall remain in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with final + // completely unknown. Replace the `undef` shift amounts with final // shift bitwidth to ensure that the value remains undef when creating the // subsequent shift op. SumOfShAmts = Constant::replaceUndefsWith( @@ -272,7 +272,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // shall be unset in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with negated + // completely unknown. Replace the `undef` shift amounts with negated // bitwidth of innermost shift to ensure that the value remains undef when // creating the subsequent shift op. unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); @@ -346,9 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // TODO: Remove the one-use check if the other logic operand (Y) is constant. 
Value *X, *Y; auto matchFirstShift = [&](Value *V) { - BinaryOperator *BO; APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode && + return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); @@ -661,23 +660,22 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { - bool isLeftShift = I.getOpcode() == Instruction::Shl; - const APInt *Op1C; if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. + bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && - canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( - I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole @@ -686,202 +684,72 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, unsigned TypeBits = Ty->getScalarSizeInBits(); assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); + (void)TypeBits; if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; - // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) - if (auto *TI = dyn_cast<TruncInst>(Op0)) { - // 
If 'shift2' is an ashr, we would have to get the sign bit into a funny - // place. Don't try to do this transformation in this case. Also, we - // require that the input operand is a shift-by-constant so that we have - // confidence that the shifts will get folded together. We could do this - // xform in more cases, but it is unlikely to be profitable. - const APInt *TrShiftAmt; - if (I.isLogicalShift() && - match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) { - auto *TrOp = cast<Instruction>(TI->getOperand(0)); - Type *SrcTy = TrOp->getType(); - - // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy); - // (shift2 (shift1 & 0x00FF), c2) - Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); - - // For logical shifts, the truncation has the effect of making the high - // part of the register be zeros. Emulate this by inserting an AND to - // clear the top bits as needed. This 'and' will usually be zapped by - // other xforms later if dead. - unsigned SrcSize = SrcTy->getScalarSizeInBits(); - Constant *MaskV = - ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits)); - - // The mask we constructed says what the trunc would do if occurring - // between the shifts. We want to know the effect *after* the second - // shift. We know that it is a logical shift by a constant, so adjust the - // mask as appropriate. - MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt); - // shift1 & 0x00FF - Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName()); - // Return the value truncated to the interesting size. 
- return new TruncInst(And, Ty); - } - } - - if (Op0->hasOneUse()) { - if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - Value *V1; - const APInt *CC; - switch (Op0BO->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { - // These operators commute. - // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && - match(Op0BO->getOperand(1), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, - Op0BO->getOperand(1)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) - Value *Op0BOOp1 = Op0BO->getOperand(1); - if (isLeftShift && Op0BOOp1->hasOneUse() && - match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); - } - LLVM_FALLTHROUGH; - } - - case Instruction::Sub: { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, - 
Op0BO->getOperand(0)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), - m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); - } - - break; - } - } + if (!Op0->hasOneUse()) + return nullptr; - // If the operand is a bitwise operator with a constant RHS, and the - // shift is the only use, we can pull it out of the shift. - const APInt *Op0C; - if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { - if (canShiftBinOpWithConstantRHS(I, Op0BO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(1)), Op1); + if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) { + // If the operand is a bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. + const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); - Value *NewShift = + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); - NewShift->takeName(Op0BO); - - return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, - NewRHS); - } - } - - // If the operand is a subtract with a constant LHS, and the shift - // is the only use, we can pull it out of the shift. 
- // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) - if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && - match(Op0BO->getOperand(0), m_APInt(Op0C))) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(0)), Op1); - - Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); NewShift->takeName(Op0BO); - return BinaryOperator::CreateSub(NewRHS, NewShift); + return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS); } } + } - // If we have a select that conditionally executes some binary operator, - // see if we can pull it the select and operator through the shift. - // - // For example, turning: - // shl (select C, (add X, C1), X), C2 - // Into: - // Y = shl X, C2 - // select C, (add Y, C1 << C2), Y - Value *Cond; - BinaryOperator *TBO; - Value *FalseVal; - if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), - m_Value(FalseVal)))) { - const APInt *C; - if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && - match(TBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, TBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(TBO->getOperand(1)), Op1); - - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); - Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewOp, NewShift); - } + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. 
+ // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); } + } - BinaryOperator *FBO; - Value *TrueVal; - if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), - m_OneUse(m_BinOp(FBO))))) { - const APInt *C; - if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && - match(FBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, FBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(FBO->getOperand(1)), Op1); - - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); - Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewShift, NewOp); - } + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); + + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); } } @@ -908,41 +776,41 @@ Instruction 
*InstCombinerImpl::visitShl(BinaryOperator &I) { Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); - // shl (zext X), ShAmt --> zext (shl X, ShAmt) + // shl (zext X), C --> zext (shl X, C) // This is only valid if X would have zeros shifted out. Value *X; if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt < SrcWidth && - MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) - return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); + if (ShAmtC < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I)) + return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty); } // (X >> C) << C --> X & (-1 << C) if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - const APInt *ShOp1; - if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + const APInt *C1; + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); return NewShl; } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>?exact 
C1) << C2 --> X >>?exact (C1 - C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); auto *NewShr = BinaryOperator::Create( cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); NewShr->setIsExact(true); @@ -950,49 +818,135 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); Builder.Insert(NewShl); - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>? C1) << C2 --> X >>? (C1 - C2) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>? C1) << C --> (X >>? 
(C1 - C)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); auto *OldShr = cast<BinaryOperator>(Op0); auto *NewShr = BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff); NewShr->setIsExact(OldShr->isExact()); Builder.Insert(NewShr); - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask)); } } - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { - unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Similar to above, but look through an intermediate trunc instruction. + BinaryOperator *Shr; + if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) && + match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) { + // The larger shift direction survives through the transform. + unsigned ShrAmtC = C1->getZExtValue(); + unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC; + Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff); + auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl; + + // If C1 > C: + // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) && (-1 << C) + // If C > C1: + // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) && (-1 << C) + Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff"); + Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff"); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask)); + } + + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + unsigned AmtSum = ShAmtC + C1->getZExtValue(); // Oversized shifts are simplified to zero in InstSimplify. 
if (AmtSum < BitWidth) // (X << C1) << C2 --> X << (C1 + C2) return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); } + // If we have an opposite shift by the same amount, we may be able to + // reorder binops and shifts to eliminate math/logic. + auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) { + switch (BinOpcode) { + default: + return false; + case Instruction::Add: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Sub: + // NOTE: Sub is not commutable and the tranforms below may not be valid + // when the shift-right is operand 1 (RHS) of the sub. + return true; + } + }; + BinaryOperator *Op0BO; + if (match(Op0, m_OneUse(m_BinOp(Op0BO))) && + isSuitableBinOpcode(Op0BO->getOpcode())) { + // Commute so shift-right is on LHS of the binop. + // (Y bop (X >> C)) << C -> ((X >> C) bop Y) << C + // (Y bop ((X >> C) & CC)) << C -> (((X >> C) & CC) bop Y) << C + Value *Shr = Op0BO->getOperand(0); + Value *Y = Op0BO->getOperand(1); + Value *X; + const APInt *CC; + if (Op0BO->isCommutative() && Y->hasOneUse() && + (match(Y, m_Shr(m_Value(), m_Specific(Op1))) || + match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))), + m_APInt(CC))))) + std::swap(Shr, Y); + + // ((X >> C) bop Y) << C -> (X bop (Y << C)) & (~0 << C) + if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // (X bop (Y << C)) + Value *B = + Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName()); + unsigned Op1Val = C->getLimitedValue(BitWidth); + APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val); + Constant *Mask = ConstantInt::get(Ty, Bits); + return BinaryOperator::CreateAnd(B, Mask); + } + + // (((X >> C) & CC) bop Y) << C -> (X & (CC << C)) bop (Y << C) + if (match(Shr, + m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))), + m_APInt(CC))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + 
// X & (CC << C) + Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)), + X->getName() + ".mask"); + return BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + } + } + + // (C1 - X) << C --> (C1 << C) - (X << C) + if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) { + Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C)); + Value *NewShift = Builder.CreateShl(X, Op1); + return BinaryOperator::CreateSub(NewLHS, NewShift); + } + // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, + &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted-out value is all signbits, then this is a NSW shift. - if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { + if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { I.setHasNoSignedWrap(); return &I; } @@ -1048,12 +1002,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); - if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || II->getIntrinsicID() == Intrinsic::cttz || II->getIntrinsicID() == Intrinsic::ctpop)) { @@ -1067,78 +1021,81 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } Value *X; - const APInt *ShOp1; - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { - if (ShOp1->ult(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - 
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + const APInt *C1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + if (C1->ult(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) + // (X <<nuw C1) >>u C --> X >>u (C - C1) auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); NewLShr->setIsExact(I.isExact()); return NewLShr; } - // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) + // (X << C1) >>u C --> (X >>u (C - C1)) & (-1 >> C) Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); } - if (ShOp1->ugt(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + if (C1->ugt(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) + // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); return NewShl; } - // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) + // (X << C1) >>u C --> X << (C1 - C) & (-1 >> C) Value *NewShl = Builder.CreateShl(X, ShiftDiff); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - assert(*ShOp1 == ShAmt); + assert(*C1 == ShAmtC); // (X << C) >>u C --> X & (-1 >>u C) - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - 
ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - assert(ShAmt < X->getType()->getScalarSizeInBits() && + assert(ShAmtC < X->getType()->getScalarSizeInBits() && "Big shift not simplified to zero?"); // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN - Value *NewLShr = Builder.CreateLShr(X, ShAmt); + Value *NewLShr = Builder.CreateLShr(X, ShAmtC); return new ZExtInst(NewLShr, Ty); } - if (match(Op0, m_SExt(m_Value(X))) && - (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - // Are we moving the sign bit to the low bit and widening with high zeros? + if (match(Op0, m_SExt(m_Value(X)))) { unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt == BitWidth - 1) { - // lshr (sext i1 X to iN), N-1 --> zext X to iN - if (SrcTyBitWidth == 1) - return new ZExtInst(X, Ty); + // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0) + if (SrcTyBitWidth == 1) { + auto *NewC = ConstantInt::get( + Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } - // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN - if (Op0->hasOneUse()) { + if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) && + Op0->hasOneUse()) { + // Are we moving the sign bit to the low bit and widening with high + // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN + if (ShAmtC == BitWidth - 1) { Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); return new ZExtInst(NewLShr, Ty); } - } - // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN - if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) { - // The new shift amount can't be more than the narrow source type. 
- unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1); - Value *AShr = Builder.CreateAShr(X, NewShAmt); - return new ZExtInst(AShr, Ty); + // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN + if (ShAmtC == BitWidth - SrcTyBitWidth) { + // The new shift amount can't be more than the narrow source type. + unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1); + Value *AShr = Builder.CreateAShr(X, NewShAmt); + return new ZExtInst(AShr, Ty); + } } } Value *Y; - if (ShAmt == BitWidth - 1) { + if (ShAmtC == BitWidth - 1) { // lshr i32 or(X,-X), 31 --> zext (X != 0) if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X))))) return new ZExtInst(Builder.CreateIsNotNull(X), Ty); @@ -1150,32 +1107,55 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // Check if a number is negative and odd: // lshr i32 (srem X, 2), 31 --> and (X >> 31), X if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) { - Value *Signbit = Builder.CreateLShr(X, ShAmt); + Value *Signbit = Builder.CreateLShr(X, ShAmtC); return BinaryOperator::CreateAnd(Signbit, X); } } - if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { - unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // (X >>u C1) >>u C --> X >>u (C1 + C) + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { // Oversized shifts are simplified to zero in InstSimplify. 
+ unsigned AmtSum = ShAmtC + C1->getZExtValue(); if (AmtSum < BitWidth) - // (X >>u C1) >>u C2 --> X >>u (C1 + C2) return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); } + Instruction *TruncSrc; + if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) && + match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + + // If the combined shift fits in the source width: + // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC + // + // If the first shift covers the number of bits truncated, then the + // mask instruction is eliminated (and so the use check is relaxed). + if (AmtSum < SrcWidth && + (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) { + Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift"); + Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName()); + + // If the first shift does not cover the number of bits truncated, then + // we require a mask to get rid of high bits in the result. + APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC)); + } + } + // Look for a "splat" mul pattern - it replicates bits across each half of // a value, so a right shift is just a mask of the low bits: // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 // TODO: Generalize to allow more than just half-width shifts? const APInt *MulC; if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) && - ShAmt * 2 == BitWidth && (*MulC - 1).isPowerOf2() && - MulC->logBase2() == ShAmt) + ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); // If the shifted-out value is known-zero, then this is an exact shift. 
if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { I.setIsExact(); return &I; } @@ -1346,6 +1326,22 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { } } + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` + // as the pattern to splat the lowest bit. + // FIXME: iff X is already masked, we don't need the one-use check. + Value *X; + if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) && + match(Op0, m_OneUse(m_Shl(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1))))) { + Constant *Mask = ConstantInt::get(Ty, 1); + // Retain the knowledge about the ignored lanes. + Mask = Constant::mergeUndefsWith( + Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)), + cast<Constant>(cast<Instruction>(Op0)->getOperand(1))); + X = Builder.CreateAnd(X, Mask); + return BinaryOperator::CreateNeg(X); + } + if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I)) return R; @@ -1354,7 +1350,6 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateLShr(Op0, Op1); // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 - Value *X; if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { // Note that we must drop 'exact'-ness of the shift! // Note that we can't keep undef's in -1 vector constant! |