Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
15 files changed, 3339 insertions, 2058 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index b68efc993723..91ca44e0f11e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -797,7 +797,7 @@ static Value *checkForNegativeOperand(BinaryOperator &I, // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2)) // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2)) if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) - if (C1->countTrailingZeros() == 0) + if (C1->countr_zero() == 0) if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { Value *NewOr = Builder.CreateOr(Z, ~(*C2)); return Builder.CreateSub(RHS, NewOr, "sub"); @@ -880,8 +880,15 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1); // ~X + C --> (C-1) - X - if (match(Op0, m_Not(m_Value(X)))) - return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); + if (match(Op0, m_Not(m_Value(X)))) { + // ~X + C has NSW and (C-1) won't overflow => (C-1)-X can have NSW + auto *COne = ConstantInt::get(Op1C->getType(), 1); + bool WillNotSOV = willNotOverflowSignedSub(Op1C, COne, Add); + BinaryOperator *Res = + BinaryOperator::CreateSub(ConstantExpr::getSub(Op1C, COne), X); + Res->setHasNoSignedWrap(Add.hasNoSignedWrap() && WillNotSOV); + return Res; + } // (iN X s>> (N - 1)) + 1 --> zext (X > -1) const APInt *C; @@ -975,6 +982,16 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } + // Fold (add (zext (add X, -1)), 1) -> (zext X) if X is non-zero. + // TODO: There's a general form for any constant on the outer add. + if (C->isOne()) { + if (match(Op0, m_ZExt(m_Add(m_Value(X), m_AllOnes())))) { + const SimplifyQuery Q = SQ.getWithInstruction(&Add); + if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ZExtInst(X, Ty); + } + } + return nullptr; } @@ -1366,6 +1383,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *X = foldNoWrapAdd(I, Builder)) return X; + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1)) @@ -1421,6 +1441,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { Value *Sub = Builder.CreateSub(A, B); return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2)); } + + // Canonicalize a constant sub operand as an add operand for better folding: + // (C1 - A) + B --> (B - A) + C1 + if (match(&I, m_c_Add(m_OneUse(m_Sub(m_ImmConstant(C1), m_Value(A))), + m_Value(B)))) { + Value *Sub = Builder.CreateSub(B, A, "reass.sub"); + return BinaryOperator::CreateAdd(Sub, C1); + } } // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) @@ -1439,7 +1467,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) && - C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) { + C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) { Constant *NewMask = ConstantInt::get(RHS->getType(), *C1 - 1); return BinaryOperator::CreateAnd(A, NewMask); } @@ -1451,6 +1479,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A)))))) return new ZExtInst(B, LHS->getType()); + // zext(A) + sext(A) --> 0 if A is i1 + if
(match(&I, m_c_BinOp(m_ZExt(m_Value(A)), m_SExt(m_Deferred(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1515,7 +1548,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { const APInt *NegPow2C; if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))), m_Value(B)))) { - Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros()); + Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countr_zero()); Value *Shl = Builder.CreateShl(A, ShiftAmtC); return BinaryOperator::CreateSub(B, Shl); } @@ -1536,6 +1569,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Instruction *Ashr = foldAddToAshr(I)) return Ashr; + // min(A, B) + max(A, B) => A + B. + if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)), + m_c_SMin(m_Deferred(A), m_Deferred(B))), + m_c_Add(m_UMax(m_Value(A), m_Value(B)), + m_c_UMin(m_Deferred(A), m_Deferred(B)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I); + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. @@ -1575,6 +1615,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateOr(A, B)})); + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return Changed ? &I : nullptr; } @@ -1786,6 +1832,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); } + // minimum(X, Y) + maximum(X, Y) => X + Y. + if (match(&I, + m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), + m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), + m_Deferred(Y))))) { + BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I); + // We cannot preserve ninf if the nnan flag is not set. + // If X is NaN and Y is Inf then the original program had NaN + NaN, + // while the optimized version has NaN + Inf, which is poison under the + // ninf flag.
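As an aside, the integer identity behind the min/max-plus-max folds in visitAdd and visitFAdd here, min(A, B) + max(A, B) == A + B, is easy to spot-check outside the compiler. The following standalone C++ sketch (illustrative only, not part of the patch) brute-forces it over i8 for both the signed and unsigned orderings; the floating-point minimum/maximum variant additionally needs the flag adjustment described in the comment above, since NaN and Inf operands interact with minimum/maximum differently than with a raw fadd.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int A = -128; A < 128; ++A)
    for (int B = -128; B < 128; ++B) {
      int8_t SA = (int8_t)A, SB = (int8_t)B;
      // smin(A, B) + smax(A, B) == A + B, modulo 2^8 like the IR add.
      assert((int8_t)(std::min(SA, SB) + std::max(SA, SB)) ==
             (int8_t)(SA + SB));
      uint8_t UA = (uint8_t)A, UB = (uint8_t)B;
      // umin(A, B) + umax(A, B) == A + B as well.
      assert((uint8_t)(std::min(UA, UB) + std::max(UA, UB)) ==
             (uint8_t)(UA + UB));
    }
  return 0;
}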
+ if (!Result->hasNoNaNs()) + Result->setHasNoInfs(false); + return Result; + } + return nullptr; } @@ -1956,8 +2016,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { Constant *C2; // C-(X+C2) --> (C-C2)-X - if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) - return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) { + // C-C2 never overflows; if C-(X+C2) and (X+C2) both have NSW, + // then (C-C2)-X can have NSW + bool WillNotSOV = willNotOverflowSignedSub(C, C2, I); + BinaryOperator *Res = + BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); + auto *OBO1 = cast<OverflowingBinaryOperator>(Op1); + Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() && + WillNotSOV); + return Res; + } } auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { @@ -2325,7 +2394,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { const APInt *AddC, *AndC; if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { - unsigned Cttz = AddC->countTrailingZeros(); + unsigned Cttz = AddC->countr_zero(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); if ((HighMask & *AndC).isZero()) return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); @@ -2388,6 +2457,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return replaceInstUsesWith(I, Mul); } + // max(X,Y) nsw/nuw - min(X,Y) --> abs(X nsw - Y) + if (match(Op0, m_OneUse(m_c_SMax(m_Value(X), m_Value(Y)))) && + match(Op1, m_OneUse(m_c_SMin(m_Specific(X), m_Specific(Y))))) { + if (I.hasNoUnsignedWrap() || I.hasNoSignedWrap()) { + Value *Sub = + Builder.CreateSub(X, Y, "sub", /*HasNUW=*/false, /*HasNSW=*/true); + Value *Call = + Builder.CreateBinaryIntrinsic(Intrinsic::abs, Sub, Builder.getTrue()); + return replaceInstUsesWith(I, Call); + } + } + + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return TryToNarrowDeduceFlags(); } @@ -2567,7 +2651,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // Note that if this fsub was really an fneg, the fadd with -0.0 will get // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. - if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { + if (I.hasNoSignedZeros() || cannotBeNegativeZero(Op0, SQ.DL, SQ.TLI)) { if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 97a001b2ed32..8a1fb6b7f17e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -625,7 +625,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, return RHS; } - if (Mask & BMask_Mixed) { + if (Mask & (BMask_Mixed | BMask_NotMixed)) { + // Mixed: // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of // C and E, which are shared by both B and D, don't contradict. // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D having a single bit set.
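The Mixed merge described in the comment above rests on a finite identity that can be checked exhaustively: whenever B & C == C, D & E == E, and (B & D) & (C ^ E) == 0 all hold, the two masked equalities collapse into a single compare against the unioned mask. A standalone C++ sketch (illustrative only, not part of the patch):

#include <cassert>

int main() {
  for (unsigned B = 0; B < 16; ++B)
    for (unsigned C = 0; C < 16; ++C)
      for (unsigned D = 0; D < 16; ++D)
        for (unsigned E = 0; E < 16; ++E) {
          // Preconditions that foldLogOpOfMaskedICmps establishes first.
          if ((B & C) != C || (D & E) != E || ((B & D) & (C ^ E)) != 0)
            continue;
          for (unsigned X = 0; X < 16; ++X) {
            bool TwoCmps = (X & B) == C && (X & D) == E;
            bool OneCmp = (X & (B | D)) == (C | E);
            assert(TwoCmps == OneCmp);
          }
        }
  return 0;
}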
+ + // NotMixed: + // (icmp ne (A & B), C) & (icmp ne (A & D), E) + // -> (icmp ne (A & (B & D)), (C & E)) + // Check the intersection (B & D) for inequality. + // Assume that (B & D) == B || (B & D) == D, i.e. B/D is a subset of D/B, + // and (B & D) & (C ^ E) == 0, i.e. the bits of C and E that are shared by + // both B and D don't contradict. + // Note that we can assume (~B & C) == 0 && (~D & E) == 0; an earlier + // fold should already have eliminated these icmps otherwise. + const APInt *OldConstC, *OldConstE; if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; - const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; - const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; + auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * { + CC = IsNot ? CmpInst::getInversePredicate(CC) : CC; + const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC; + const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE; - // If there is a conflict, we should actually return a false for the - // whole construct. - if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return ConstantInt::get(LHS->getType(), !IsAnd); + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd); - Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewAnd = Builder.CreateAnd(A, NewOr1); - Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); - return Builder.CreateICmp(NewCC, NewAnd, NewOr2); - } + if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB)) + return nullptr; + APInt BD, CE; + if (IsNot) { + BD = *ConstB & *ConstD; + CE = ConstC & ConstE; + } else { + BD = *ConstB | *ConstD; + CE = ConstC | ConstE; + } + Value *NewAnd = Builder.CreateAnd(A, BD); + Value *CEVal = ConstantInt::get(A->getType(), CE); + return Builder.CreateICmp(CC, CEVal, NewAnd); + }; + + if (Mask & BMask_Mixed) + return FoldBMixed(NewCC, false); + if (Mask & BMask_NotMixed) // could also be an 'else if' + return FoldBMixed(NewCC, true); + } return nullptr; } @@ -928,6 +955,108 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd, return nullptr; } +/// Try to fold (icmp(A & B) == 0) & (icmp(A & D) != E) into (icmp A u< D) iff +/// B is a contiguous set of ones starting from the most significant bit +/// (negative power of 2), D and E are equal, and D is a contiguous set of ones +/// starting at the most significant zero bit in B. Parameter B supports masking +/// using undef/poison in either scalar or vector values. +static Value *foldNegativePower2AndShiftedMask( + Value *A, Value *B, Value *D, Value *E, ICmpInst::Predicate PredL, + ICmpInst::Predicate PredR, InstCombiner::BuilderTy &Builder) { + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); + if (PredL != ICmpInst::ICMP_EQ || PredR != ICmpInst::ICMP_NE) + return nullptr; + + if (!match(B, m_NegatedPower2()) || !match(D, m_ShiftedMask()) || + !match(E, m_ShiftedMask())) + return nullptr; + + // Test scalar arguments for conversion. B has been validated earlier to be a + // negative power of two and thus is guaranteed to have one or more contiguous + // ones starting from the MSB followed by zero or more contiguous zeros. D has + // been validated earlier to be a shifted set of one or more contiguous ones. + // In order to match, B's leading ones and D's leading zeros should be equal.
The + // predicate that B be a negative power of 2 prevents the condition of there + // ever being zero leading ones. Thus 0 == 0 cannot occur. The predicate that + // D always be a shifted mask prevents the condition of D equaling 0. This + // prevents matching the condition where B contains the maximum number of + // leading one bits (-1) and D contains the maximum number of leading zero + // bits (0). + auto isReducible = [](const Value *B, const Value *D, const Value *E) { + const APInt *BCst, *DCst, *ECst; + return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) && + match(E, m_APInt(ECst)) && *DCst == *ECst && + (isa<UndefValue>(B) || + (BCst->countLeadingOnes() == DCst->countLeadingZeros())); + }; + + // Test vector type arguments for conversion. + if (const auto *BVTy = dyn_cast<VectorType>(B->getType())) { + const auto *BFVTy = dyn_cast<FixedVectorType>(BVTy); + const auto *BConst = dyn_cast<Constant>(B); + const auto *DConst = dyn_cast<Constant>(D); + const auto *EConst = dyn_cast<Constant>(E); + + if (!BFVTy || !BConst || !DConst || !EConst) + return nullptr; + + for (unsigned I = 0; I != BFVTy->getNumElements(); ++I) { + const auto *BElt = BConst->getAggregateElement(I); + const auto *DElt = DConst->getAggregateElement(I); + const auto *EElt = EConst->getAggregateElement(I); + + if (!BElt || !DElt || !EElt) + return nullptr; + if (!isReducible(BElt, DElt, EElt)) + return nullptr; + } + } else { + // Test scalar type arguments for conversion. + if (!isReducible(B, D, E)) + return nullptr; + } + return Builder.CreateICmp(ICmpInst::ICMP_ULT, A, D); +} + +/// Try to fold ((icmp X u< P) & (icmp(X & M) != M)) or ((icmp X s> -1) & +/// (icmp(X & M) != M)) into (icmp X u< M). Where P is a power of 2, M < P, and +/// M is a contiguous shifted mask starting at the right most significant zero +/// bit in P. SGT is supported as when P is the largest representable power of +/// 2, an earlier optimization converts the expression into (icmp X s> -1). +/// Parameter P supports masking using undef/poison in either scalar or vector +/// values. +static Value *foldPowerOf2AndShiftedMask(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool JoinedByAnd, + InstCombiner::BuilderTy &Builder) { + if (!JoinedByAnd) + return nullptr; + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; + ICmpInst::Predicate CmpPred0 = Cmp0->getPredicate(), + CmpPred1 = Cmp1->getPredicate(); + // Assuming P is a 2^n, getMaskedTypeForICmpPair will normalize (icmp X u< + // 2^n) into (icmp (X & ~(2^n-1)) == 0) and (icmp X s> -1) into (icmp (X & + // SignMask) == 0). + std::optional<std::pair<unsigned, unsigned>> MaskPair = + getMaskedTypeForICmpPair(A, B, C, D, E, Cmp0, Cmp1, CmpPred0, CmpPred1); + if (!MaskPair) + return nullptr; + + const auto compareBMask = BMask_NotMixed | BMask_NotAllOnes; + unsigned CmpMask0 = MaskPair->first; + unsigned CmpMask1 = MaskPair->second; + if ((CmpMask0 & Mask_AllZeros) && (CmpMask1 == compareBMask)) { + if (Value *V = foldNegativePower2AndShiftedMask(A, B, D, E, CmpPred0, + CmpPred1, Builder)) + return V; + } else if ((CmpMask0 == compareBMask) && (CmpMask1 & Mask_AllZeros)) { + if (Value *V = foldNegativePower2AndShiftedMask(A, D, B, C, CmpPred1, + CmpPred0, Builder)) + return V; + } + return nullptr; +} + /// Commuted variants are assumed to be handled by calling this function again /// with the parameters swapped. 
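The equivalence behind foldNegativePower2AndShiftedMask can be brute-forced the same way. The standalone C++ sketch below (illustrative only, not part of the patch; the loop bounds encode the documented preconditions) checks on i8 that when B is a negated power of two and D is a shifted mask whose leading zeros equal B's leading ones, the pair (A & B) == 0 && (A & D) != D is exactly the unsigned compare A u< D:

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Shift = 1; Shift < 8; ++Shift) {
    uint8_t B = (uint8_t)(0xFFu << Shift); // negated power of two
    for (unsigned Lo = 0; Lo < Shift; ++Lo) {
      // Shifted mask whose top bit is the most significant zero bit of B.
      uint8_t D = (uint8_t)(((1u << Shift) - 1) & ~((1u << Lo) - 1));
      assert(std::countl_zero(D) == std::countl_one(B));
      for (unsigned A = 0; A < 256; ++A) {
        bool TwoCmps = (A & B) == 0 && (A & D) != D;
        bool Ult = A < D;
        assert(TwoCmps == Ult);
      }
    }
  }
  return 0;
}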
static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, @@ -1313,9 +1442,44 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, return Right; } + // Turn at least two fcmps with constants into llvm.is.fpclass. + // + // If we can represent a combined value test with one class call, we can + // potentially eliminate 4-6 instructions. If we can represent a test with a + // single fcmp with fneg and fabs, that's likely a better canonical form. + if (LHS->hasOneUse() && RHS->hasOneUse()) { + auto [ClassValRHS, ClassMaskRHS] = + fcmpToClassTest(PredR, *RHS->getFunction(), RHS0, RHS1); + if (ClassValRHS) { + auto [ClassValLHS, ClassMaskLHS] = + fcmpToClassTest(PredL, *LHS->getFunction(), LHS0, LHS1); + if (ClassValLHS == ClassValRHS) { + unsigned CombinedMask = IsAnd ? (ClassMaskLHS & ClassMaskRHS) + : (ClassMaskLHS | ClassMaskRHS); + return Builder.CreateIntrinsic( + Intrinsic::is_fpclass, {ClassValLHS->getType()}, + {ClassValLHS, Builder.getInt32(CombinedMask)}); + } + } + } + return nullptr; } +/// Match an fcmp against a special value that performs a test possible by +/// llvm.is.fpclass. +static bool matchIsFPClassLikeFCmp(Value *Op, Value *&ClassVal, + uint64_t &ClassMask) { + auto *FCmp = dyn_cast<FCmpInst>(Op); + if (!FCmp || !FCmp->hasOneUse()) + return false; + + std::tie(ClassVal, ClassMask) = + fcmpToClassTest(FCmp->getPredicate(), *FCmp->getParent()->getParent(), + FCmp->getOperand(0), FCmp->getOperand(1)); + return ClassVal != nullptr; +} + /// or (is_fpclass x, mask0), (is_fpclass x, mask1) /// -> is_fpclass x, (mask0 | mask1) /// and (is_fpclass x, mask0), (is_fpclass x, mask1) @@ -1324,13 +1488,25 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, /// -> is_fpclass x, (mask0 ^ mask1) Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, Value *Op0, Value *Op1) { - Value *ClassVal; + Value *ClassVal0 = nullptr; + Value *ClassVal1 = nullptr; uint64_t ClassMask0, ClassMask1; - if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( - m_Value(ClassVal), m_ConstantInt(ClassMask0)))) && + // Restrict to folding one fcmp into one is.fpclass for now, don't introduce a + // new class. + // + // TODO: Support forming is.fpclass out of 2 separate fcmps when codegen is + // better. 
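The mask algebra used below is sound because llvm.is.fpclass tests a one-hot value: the operand falls into exactly one of the ten floating-point classes, and one-hot membership composes bitwise, so and-ing two tests intersects the masks, or-ing unions them, and xor-ing takes the symmetric difference. A standalone C++ sketch (illustrative only, not part of the patch; the ten bits stand in for the ten FP class bits):

#include <cassert>

int main() {
  for (unsigned Bit = 0; Bit < 10; ++Bit) {
    unsigned Cls = 1u << Bit; // the (one-hot) class of some value x
    for (unsigned M0 = 0; M0 < 1024; ++M0)
      for (unsigned M1 = 0; M1 < 1024; ++M1) {
        bool T0 = (Cls & M0) != 0, T1 = (Cls & M1) != 0;
        assert((T0 && T1) == ((Cls & (M0 & M1)) != 0)); // and
        assert((T0 || T1) == ((Cls & (M0 | M1)) != 0)); // or
        assert((T0 != T1) == ((Cls & (M0 ^ M1)) != 0)); // xor
      }
  }
  return 0;
}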
+ + bool IsLHSClass = + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(ClassVal0), m_ConstantInt(ClassMask0)))); + bool IsRHSClass = match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( - m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) { + m_Value(ClassVal1), m_ConstantInt(ClassMask1)))); + if ((((IsLHSClass || matchIsFPClassLikeFCmp(Op0, ClassVal0, ClassMask0)) && + (IsRHSClass || matchIsFPClassLikeFCmp(Op1, ClassVal1, ClassMask1)))) && + ClassVal0 == ClassVal1) { unsigned NewClassMask; switch (BO.getOpcode()) { case Instruction::And: @@ -1346,11 +1522,24 @@ Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, llvm_unreachable("not a binary logic operator"); } - // TODO: Also check for special fcmps - auto *II = cast<IntrinsicInst>(Op0); - II->setArgOperand( - 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); - return replaceInstUsesWith(BO, II); + if (IsLHSClass) { + auto *II = cast<IntrinsicInst>(Op0); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + if (IsRHSClass) { + auto *II = cast<IntrinsicInst>(Op1); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + CallInst *NewClass = + Builder.CreateIntrinsic(Intrinsic::is_fpclass, {ClassVal0->getType()}, + {ClassVal0, Builder.getInt32(NewClassMask)}); + return replaceInstUsesWith(BO, NewClass); } return nullptr; @@ -1523,6 +1712,39 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // fold bitwise(A >> BW - 1, zext(icmp)) (BW is the scalar bits of the + // type of A) + // -> bitwise(zext(A < 0), zext(icmp)) + // -> zext(bitwise(A < 0, icmp)) + auto FoldBitwiseICmpZeroWithICmp = [&](Value *Op0, + Value *Op1) -> Instruction * { + ICmpInst::Predicate Pred; + Value *A; + bool IsMatched = + match(Op0, + m_OneUse(m_LShr( + m_Value(A), + m_SpecificInt(Op0->getType()->getScalarSizeInBits() - 1)))) && + match(Op1, m_OneUse(m_ZExt(m_ICmp(Pred, m_Value(), m_Value())))); + + if (!IsMatched) + return nullptr; + + auto *ICmpL = + Builder.CreateICmpSLT(A, Constant::getNullValue(A->getType())); + auto *ICmpR = cast<ZExtInst>(Op1)->getOperand(0); + auto *BitwiseOp = Builder.CreateBinOp(LogicOpc, ICmpL, ICmpR); + + return new ZExtInst(BitwiseOp, Op0->getType()); + }; + + if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op0, Op1)) + return Ret; + + if (auto *Ret = FoldBitwiseICmpZeroWithICmp(Op1, Op0)) + return Ret; + CastInst *Cast0 = dyn_cast<CastInst>(Op0); if (!Cast0) return nullptr; @@ -1906,16 +2128,16 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I, return nullptr; unsigned Width = Ty->getScalarSizeInBits(); - unsigned LastOneMath = Width - C2->countTrailingZeros(); + unsigned LastOneMath = Width - C2->countr_zero(); switch (OpC) { case Instruction::And: - if (C->countLeadingOnes() < LastOneMath) + if (C->countl_one() < LastOneMath) return nullptr; break; case Instruction::Xor: case Instruction::Or: - if (C->countLeadingZeros() < LastOneMath) + if (C->countl_zero() < LastOneMath) return nullptr; break; default: @@ -1923,7 +2145,51 @@ static Instruction *canonicalizeLogicFirst(BinaryOperator &I, } Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C)); - return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2)); + 
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, NewBinOp, + ConstantInt::get(Ty, *C2), Op0); +} + +// binop(shift(ShiftedC1, ShAmt), shift(ShiftedC2, add(ShAmt, AddC))) -> +// shift(binop(ShiftedC1, shift(ShiftedC2, AddC)), ShAmt) +// where both shifts are the same and AddC is a valid shift amount. +Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) { + assert((I.isBitwiseLogicOp() || I.getOpcode() == Instruction::Add) && + "Unexpected opcode"); + + Value *ShAmt; + Constant *ShiftedC1, *ShiftedC2, *AddC; + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + if (!match(&I, + m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)), + m_Shift(m_ImmConstant(ShiftedC2), + m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC)))))) + return nullptr; + + // Make sure the add constant is a valid shift amount. + if (!match(AddC, + m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(BitWidth, BitWidth)))) + return nullptr; + + // Avoid constant expressions. + auto *Op0Inst = dyn_cast<Instruction>(I.getOperand(0)); + auto *Op1Inst = dyn_cast<Instruction>(I.getOperand(1)); + if (!Op0Inst || !Op1Inst) + return nullptr; + + // Both shifts must be the same. + Instruction::BinaryOps ShiftOp = + static_cast<Instruction::BinaryOps>(Op0Inst->getOpcode()); + if (ShiftOp != Op1Inst->getOpcode()) + return nullptr; + + // For adds, only left shifts are supported. + if (I.getOpcode() == Instruction::Add && ShiftOp != Instruction::Shl) + return nullptr; + + Value *NewC = Builder.CreateBinOp( + I.getOpcode(), ShiftedC1, Builder.CreateBinOp(ShiftOp, ShiftedC2, AddC)); + return BinaryOperator::Create(ShiftOp, NewC, ShAmt); } // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches @@ -1964,6 +2230,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *X, *Y; @@ -2033,7 +2302,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) { // If we add zeros to every bit below a mask, the add has no effect: // (X + AddC) & LowMaskC --> X & LowMaskC - unsigned Ctlz = C->countLeadingZeros(); + unsigned Ctlz = C->countl_zero(); APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); if ((*AddC & LowMask).isZero()) return BinaryOperator::CreateAnd(X, Op1); @@ -2150,7 +2419,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { const APInt *C3 = C; Value *X; if (C3->isPowerOf2()) { - Constant *Log2C3 = ConstantInt::get(Ty, C3->countTrailingZeros()); + Constant *Log2C3 = ConstantInt::get(Ty, C3->countr_zero()); if (match(Op0, m_OneUse(m_LShr(m_Shl(m_ImmConstant(C1), m_Value(X)), m_ImmConstant(C2)))) && match(C1, m_Power2())) { @@ -2407,6 +2676,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } @@ -2718,34 +2990,47 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, return nullptr; } -// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) -// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) -static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, - bool IsAnd, bool IsLogical, - IRBuilderBase &Builder) { +// (icmp eq X, C) | (icmp ult 
Other, (X - C)) -> (icmp ule Other, (X - (C + 1))) +// (icmp ne X, C) & (icmp uge Other, (X - C)) -> (icmp ugt Other, (X - (C + 1))) +static Value *foldAndOrOfICmpEqConstantAndICmp(ICmpInst *LHS, ICmpInst *RHS, + bool IsAnd, bool IsLogical, + IRBuilderBase &Builder) { + Value *LHS0 = LHS->getOperand(0); + Value *RHS0 = RHS->getOperand(0); + Value *RHS1 = RHS->getOperand(1); + ICmpInst::Predicate LPred = IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); ICmpInst::Predicate RPred = IsAnd ? RHS->getInversePredicate() : RHS->getPredicate(); - Value *LHS0 = LHS->getOperand(0); - if (LPred != ICmpInst::ICMP_EQ || !match(LHS->getOperand(1), m_Zero()) || + + const APInt *CInt; + if (LPred != ICmpInst::ICMP_EQ || + !match(LHS->getOperand(1), m_APIntAllowUndef(CInt)) || !LHS0->getType()->isIntOrIntVectorTy() || !(LHS->hasOneUse() || RHS->hasOneUse())) return nullptr; + auto MatchRHSOp = [LHS0, CInt](const Value *RHSOp) { + return match(RHSOp, + m_Add(m_Specific(LHS0), m_SpecificIntAllowUndef(-*CInt))) || + (CInt->isZero() && RHSOp == LHS0); + }; + Value *Other; - if (RPred == ICmpInst::ICMP_ULT && RHS->getOperand(1) == LHS0) - Other = RHS->getOperand(0); - else if (RPred == ICmpInst::ICMP_UGT && RHS->getOperand(0) == LHS0) - Other = RHS->getOperand(1); + if (RPred == ICmpInst::ICMP_ULT && MatchRHSOp(RHS1)) + Other = RHS0; + else if (RPred == ICmpInst::ICMP_UGT && MatchRHSOp(RHS0)) + Other = RHS1; else return nullptr; if (IsLogical) Other = Builder.CreateFreeze(Other); + return Builder.CreateICmp( IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, - Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), + Builder.CreateSub(LHS0, ConstantInt::get(LHS0->getType(), *CInt + 1)), Other); } @@ -2792,12 +3077,12 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return V; if (Value *V = - foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) + foldAndOrOfICmpEqConstantAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) return V; // We can treat logical like bitwise here, because both operands are used on // the LHS, and as such poison from both will propagate. - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, - /*IsLogical*/ false, Builder)) + if (Value *V = foldAndOrOfICmpEqConstantAndICmp(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder)) return V; if (Value *V = @@ -2836,6 +3121,9 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldIsPowerOf2(LHS, RHS, IsAnd, Builder)) return V; + if (Value *V = foldPowerOf2AndShiftedMask(LHS, RHS, IsAnd, Builder)) + return V; + // TODO: Verify whether this is safe for logical and/or. if (!IsLogical) { if (Value *X = foldUnsignedUnderflowCheck(LHS, RHS, IsAnd, Q, Builder)) @@ -2849,7 +3137,7 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + // TODO: Remove this and below when foldLogOpOfMaskedICmps can handle undefs. if (!IsLogical && PredL == (IsAnd ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && PredL == PredR && match(LHS1, m_ZeroInt()) && match(RHS1, m_ZeroInt()) && LHS0->getType() == RHS0->getType()) { @@ -2858,6 +3146,16 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Constant::getNullValue(NewOr->getType())); } + // (icmp ne A, -1) | (icmp ne B, -1) --> (icmp ne (A&B), -1) + // (icmp eq A, -1) & (icmp eq B, -1) --> (icmp eq (A&B), -1) + if (!IsLogical && PredL == (IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + PredL == PredR && match(LHS1, m_AllOnes()) && match(RHS1, m_AllOnes()) && + LHS0->getType() == RHS0->getType()) { + Value *NewAnd = Builder.CreateAnd(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewAnd, + Constant::getAllOnesValue(LHS0->getType())); + } + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSC || !RHSC) return nullptr; @@ -2998,6 +3296,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Concat = matchOrConcat(I, Builder)) return replaceInstUsesWith(I, Concat); + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && @@ -3416,6 +3717,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } @@ -3715,6 +4019,24 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, return nullptr; } +static bool canFreelyInvert(InstCombiner &IC, Value *Op, + Instruction *IgnoredUser) { + auto *I = dyn_cast<Instruction>(Op); + return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) && + InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser); +} + +static Value *freelyInvert(InstCombinerImpl &IC, Value *Op, + Instruction *IgnoredUser) { + auto *I = cast<Instruction>(Op); + IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef()); + Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not"); + Op->replaceUsesWithIf(NotOp, + [NotOp](Use &U) { return U.getUser() != NotOp; }); + IC.freelyInvertAllUsersOf(NotOp, IgnoredUser); + return NotOp; +} + // Transform // z = ~(x &/| y) // into: @@ -3739,28 +4061,11 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) { return false; // And can the operands be adapted? 
- for (Value *Op : {Op0, Op1}) - if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) && - (match(Op, m_ImmConstant()) || - (isa<Instruction>(Op) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op), - /*IgnoredUser=*/&I))))) - return false; + if (!canFreelyInvert(*this, Op0, &I) || !canFreelyInvert(*this, Op1, &I)) + return false; - for (Value **Op : {&Op0, &Op1}) { - Value *NotOp; - if (auto *C = dyn_cast<Constant>(*Op)) { - NotOp = ConstantExpr::getNot(C); - } else { - Builder.SetInsertPoint( - &*cast<Instruction>(*Op)->getInsertionPointAfterDef()); - NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not"); - (*Op)->replaceUsesWithIf( - NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; }); - freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I); - } - *Op = NotOp; - } + Op0 = freelyInvert(*this, Op0, &I); + Op1 = freelyInvert(*this, Op1, &I); Builder.SetInsertPoint(I.getInsertionPointAfterDef()); Value *NewLogicOp; @@ -3794,20 +4099,11 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { Value *NotOp0 = nullptr; Value *NotOp1 = nullptr; Value **OpToInvert = nullptr; - if (match(Op0, m_Not(m_Value(NotOp0))) && - InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) && - (match(Op1, m_ImmConstant()) || - (isa<Instruction>(Op1) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1), - /*IgnoredUser=*/&I)))) { + if (match(Op0, m_Not(m_Value(NotOp0))) && canFreelyInvert(*this, Op1, &I)) { Op0 = NotOp0; OpToInvert = &Op1; } else if (match(Op1, m_Not(m_Value(NotOp1))) && - InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) && - (match(Op0, m_ImmConstant()) || - (isa<Instruction>(Op0) && - InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0), - /*IgnoredUser=*/&I)))) { + canFreelyInvert(*this, Op0, &I)) { Op1 = NotOp1; OpToInvert = &Op0; } else return false; @@ -3817,19 +4113,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) return false; - if (auto *C = dyn_cast<Constant>(*OpToInvert)) { - *OpToInvert = ConstantExpr::getNot(C); - } else { - Builder.SetInsertPoint( - &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef()); - Value *NotOpToInvert = - Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not"); - (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) { - return U.getUser() != NotOpToInvert; - }); - freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I); - *OpToInvert = NotOpToInvert; - } + *OpToInvert = freelyInvert(*this, *OpToInvert, &I); Builder.SetInsertPoint(&*I.getInsertionPointAfterDef()); Value *NewBinOp; @@ -3896,8 +4180,8 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); - // Bit-hack form of a signbit test: - // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN + // Bit-hack form of a signbit test for iN type: + // ~(X >>s (N - 1)) --> sext i1 (X > -1) to iN unsigned FullShift = Ty->getScalarSizeInBits() - 1; if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) { Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg"); @@ -4071,6 +4355,9 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *R = foldNot(I)) return R; + if (Instruction *R = foldBinOpShiftWithShift(I)) + return R; + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) // This is a special case in haveNoCommonBitsSet, but the computeKnownBits //
calls in there are unnecessary as SimplifyDemandedInstructionBits should @@ -4280,6 +4567,23 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { } } + // (A & B) ^ (A | C) --> A ? ~B : C -- There are 4 commuted variants. + if (I.getType()->isIntOrIntVectorTy(1) && + match(Op0, m_OneUse(m_LogicalAnd(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_LogicalOr(m_Value(C), m_Value(D))))) { + bool NeedFreeze = isa<SelectInst>(Op0) && isa<SelectInst>(Op1) && B == D; + if (B == C || B == D) + std::swap(A, B); + if (A == C) + std::swap(C, D); + if (A == D) { + if (NeedFreeze) + A = Builder.CreateFreeze(A); + Value *NotB = Builder.CreateNot(B); + return SelectInst::Create(A, NotB, C); + } + } + if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (Value *V = foldXorOfICmps(LHS, RHS, I)) @@ -4313,5 +4617,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I)) return Folded; + if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) + return Res; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index e73667f9c02e..cba282cea72b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -116,24 +116,10 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { return &RMWI; } - AtomicOrdering Ordering = RMWI.getOrdering(); - assert(Ordering != AtomicOrdering::NotAtomic && - Ordering != AtomicOrdering::Unordered && + assert(RMWI.getOrdering() != AtomicOrdering::NotAtomic && + RMWI.getOrdering() != AtomicOrdering::Unordered && "AtomicRMWs don't make sense with Unordered or NotAtomic"); - // Any atomicrmw xchg with no uses can be converted to a atomic store if the - // ordering is compatible. - if (RMWI.getOperation() == AtomicRMWInst::Xchg && - RMWI.use_empty()) { - if (Ordering != AtomicOrdering::Release && - Ordering != AtomicOrdering::Monotonic) - return nullptr; - new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), - /*isVolatile*/ false, RMWI.getAlign(), Ordering, - RMWI.getSyncScopeID(), &RMWI); - return eraseInstFromFunction(RMWI); - } - if (!isIdempotentRMW(RMWI)) return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index fbf1327143a8..d3ec6a7aa667 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/AttributeMask.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -439,9 +440,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType()); ElementCount VF = WideLoadTy->getElementCount(); - Constant *EC = - ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue()); - Value *RunTimeVF = VF.isScalable() ? 
Builder.CreateVScale(EC) : EC; + Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF); Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1)); Value *Extract = Builder.CreateExtractElement(II.getArgOperand(0), LastLane); @@ -533,16 +532,15 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType())); } - // If the operand is a select with constant arm(s), try to hoist ctlz/cttz. - if (auto *Sel = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = IC.FoldOpIntoSelect(II, Sel)) - return R; - if (IsTZ) { // cttz(-x) -> cttz(x) if (match(Op0, m_Neg(m_Value(X)))) return IC.replaceOperand(II, 0, X); + // cttz(-x & x) -> cttz(x) + if (match(Op0, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) + return IC.replaceOperand(II, 0, X); + // cttz(sext(x)) -> cttz(zext(x)) if (match(Op0, m_OneUse(m_SExt(m_Value(X))))) { auto *Zext = IC.Builder.CreateZExt(X, II.getType()); @@ -599,8 +597,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { } // Add range metadata since known bits can't completely reflect what we know. - // TODO: Handle splat vectors. - auto *IT = dyn_cast<IntegerType>(Op0->getType()); + auto *IT = cast<IntegerType>(Op0->getType()->getScalarType()); if (IT && IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { Metadata *LowAndHigh[] = { ConstantAsMetadata::get(ConstantInt::get(IT, DefiniteZeros)), @@ -657,11 +654,6 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { return CastInst::Create(Instruction::ZExt, NarrowPop, Ty); } - // If the operand is a select with constant arm(s), try to hoist ctpop. - if (auto *Sel = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = IC.FoldOpIntoSelect(II, Sel)) - return R; - KnownBits Known(BitWidth); IC.computeKnownBits(Op0, Known, 0, &II); @@ -683,12 +675,8 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { Constant::getNullValue(Ty)), Ty); - // FIXME: Try to simplify vectors of integers. - auto *IT = dyn_cast<IntegerType>(Ty); - if (!IT) - return nullptr; - // Add range metadata since known bits can't completely reflect what we know. 
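Both cttz folds above (cttz(-x) and cttz(-x & x)) rest on the fact that two's-complement negation preserves every bit up to and including the lowest set bit. A standalone C++ sketch (illustrative only, not part of the patch) confirms this on i8:

#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned I = 1; I < 256; ++I) { // x == 0 is covered by the zero rules
    uint8_t X = (uint8_t)I;
    uint8_t NegX = (uint8_t)(0u - X);
    assert(std::countr_zero(NegX) == std::countr_zero(X));
    assert(std::countr_zero((uint8_t)(X & NegX)) == std::countr_zero(X));
  }
  return 0;
}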
+ auto *IT = cast<IntegerType>(Ty->getScalarType()); unsigned MinCount = Known.countMinPopulation(); unsigned MaxCount = Known.countMaxPopulation(); if (IT->getBitWidth() != 1 && !II.getMetadata(LLVMContext::MD_range)) { @@ -830,10 +818,204 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { return nullptr; } +static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE; +} + +static bool inputDenormalIsDAZ(const Function &F, const Type *Ty) { + Ty = Ty->getScalarType(); + return F.getDenormalMode(Ty->getFltSemantics()).inputsAreZero(); +} + +/// \returns the compare predicate type if the test performed by +/// llvm.is.fpclass(x, \p Mask) is equivalent to fcmp o__ x, 0.0 with the +/// floating-point environment assumed for \p F for type \p Ty +static FCmpInst::Predicate fpclassTestIsFCmp0(FPClassTest Mask, + const Function &F, Type *Ty) { + switch (static_cast<unsigned>(Mask)) { + case fcZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OEQ; + break; + case fcZero | fcSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OEQ; + break; + case fcPositive | fcNegZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OGE; + break; + case fcPositive | fcNegZero | fcNegSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OGE; + break; + case fcPosSubnormal | fcPosNormal | fcPosInf: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OGT; + break; + case fcNegative | fcPosZero: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OLE; + break; + case fcNegative | fcPosZero | fcPosSubnormal: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OLE; + break; + case fcNegSubnormal | fcNegNormal | fcNegInf: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_OLT; + break; + case fcPosNormal | fcPosInf: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OGT; + break; + case fcNegNormal | fcNegInf: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_OLT; + break; + case ~fcZero & ~fcNan: + if (inputDenormalIsIEEE(F, Ty)) + return FCmpInst::FCMP_ONE; + break; + case ~(fcZero | fcSubnormal) & ~fcNan: + if (inputDenormalIsDAZ(F, Ty)) + return FCmpInst::FCMP_ONE; + break; + default: + break; + } + + return FCmpInst::BAD_FCMP_PREDICATE; +} + +Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { + Value *Src0 = II.getArgOperand(0); + Value *Src1 = II.getArgOperand(1); + const ConstantInt *CMask = cast<ConstantInt>(Src1); + FPClassTest Mask = static_cast<FPClassTest>(CMask->getZExtValue()); + const bool IsUnordered = (Mask & fcNan) == fcNan; + const bool IsOrdered = (Mask & fcNan) == fcNone; + const FPClassTest OrderedMask = Mask & ~fcNan; + const FPClassTest OrderedInvertedMask = ~OrderedMask & ~fcNan; + + const bool IsStrict = II.isStrictFP(); + + Value *FNegSrc; + if (match(Src0, m_FNeg(m_Value(FNegSrc)))) { + // is.fpclass (fneg x), mask -> is.fpclass x, (fneg mask) + + II.setArgOperand(1, ConstantInt::get(Src1->getType(), fneg(Mask))); + return replaceOperand(II, 0, FNegSrc); + } + + Value *FAbsSrc; + if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) { + II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask))); + return replaceOperand(II, 0, FAbsSrc); + } + + // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf + + if ((OrderedMask == fcPosInf || OrderedMask == fcNegInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, fcPosInf) -> fcmp oeq x, +inf + 
// is.fpclass(x, fcNegInf) -> fcmp oeq x, -inf + // is.fpclass(x, fcPosInf|fcNan) -> fcmp ueq x, +inf + // is.fpclass(x, fcNegInf|fcNan) -> fcmp ueq x, -inf + Constant *Inf = + ConstantFP::getInfinity(Src0->getType(), OrderedMask == fcNegInf); + Value *EqInf = IsUnordered ? Builder.CreateFCmpUEQ(Src0, Inf) + : Builder.CreateFCmpOEQ(Src0, Inf); + + EqInf->takeName(&II); + return replaceInstUsesWith(II, EqInf); + } + + if ((OrderedInvertedMask == fcPosInf || OrderedInvertedMask == fcNegInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, ~fcPosInf) -> fcmp one x, +inf + // is.fpclass(x, ~fcNegInf) -> fcmp one x, -inf + // is.fpclass(x, ~fcPosInf|fcNan) -> fcmp une x, +inf + // is.fpclass(x, ~fcNegInf|fcNan) -> fcmp une x, -inf + Constant *Inf = ConstantFP::getInfinity(Src0->getType(), + OrderedInvertedMask == fcNegInf); + Value *NeInf = IsUnordered ? Builder.CreateFCmpUNE(Src0, Inf) + : Builder.CreateFCmpONE(Src0, Inf); + NeInf->takeName(&II); + return replaceInstUsesWith(II, NeInf); + } + + if (Mask == fcNan && !IsStrict) { + // Equivalent of isnan. Replace with standard fcmp if we don't care about FP + // exceptions. + Value *IsNan = + Builder.CreateFCmpUNO(Src0, ConstantFP::getZero(Src0->getType())); + IsNan->takeName(&II); + return replaceInstUsesWith(II, IsNan); + } + + if (Mask == (~fcNan & fcAllFlags) && !IsStrict) { + // Equivalent of !isnan. Replace with standard fcmp. + Value *FCmp = + Builder.CreateFCmpORD(Src0, ConstantFP::getZero(Src0->getType())); + FCmp->takeName(&II); + return replaceInstUsesWith(II, FCmp); + } + + FCmpInst::Predicate PredType = FCmpInst::BAD_FCMP_PREDICATE; + + // Try to replace with an fcmp with 0 + // + // is.fpclass(x, fcZero) -> fcmp oeq x, 0.0 + // is.fpclass(x, fcZero | fcNan) -> fcmp ueq x, 0.0 + // is.fpclass(x, ~fcZero & ~fcNan) -> fcmp one x, 0.0 + // is.fpclass(x, ~fcZero) -> fcmp une x, 0.0 + // + // is.fpclass(x, fcPosSubnormal | fcPosNormal | fcPosInf) -> fcmp ogt x, 0.0 + // is.fpclass(x, fcPositive | fcNegZero) -> fcmp oge x, 0.0 + // + // is.fpclass(x, fcNegSubnormal | fcNegNormal | fcNegInf) -> fcmp olt x, 0.0 + // is.fpclass(x, fcNegative | fcPosZero) -> fcmp ole x, 0.0 + // + if (!IsStrict && (IsOrdered || IsUnordered) && + (PredType = fpclassTestIsFCmp0(OrderedMask, *II.getFunction(), + Src0->getType())) != + FCmpInst::BAD_FCMP_PREDICATE) { + Constant *Zero = ConstantFP::getZero(Src0->getType()); + // Equivalent of == 0. + Value *FCmp = Builder.CreateFCmp( + IsUnordered ? FCmpInst::getUnorderedPredicate(PredType) : PredType, + Src0, Zero); + + FCmp->takeName(&II); + return replaceInstUsesWith(II, FCmp); + } + + KnownFPClass Known = computeKnownFPClass( + Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT); + + // Clear test bits we know must be false from the source value. + // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other + // fp_class (ninf x), ninf|pinf|other -> fp_class (ninf x), other + if ((Mask & Known.KnownFPClasses) != Mask) { + II.setArgOperand( + 1, ConstantInt::get(Src1->getType(), Mask & Known.KnownFPClasses)); + return &II; + } + + // If none of the tests which can return false are possible, fold to true. 
+ // fp_class (nnan x), ~(qnan|snan) -> true + // fp_class (ninf x), ~(ninf|pinf) -> true + if (Mask == Known.KnownFPClasses) + return replaceInstUsesWith(II, ConstantInt::get(II.getType(), true)); + + return nullptr; +} + static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, - const DataLayout &DL, - AssumptionCache *AC, - DominatorTree *DT) { + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); if (Known.isNonNegative()) return false; @@ -848,6 +1030,19 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +/// Return true if two values \p Op0 and \p Op1 are known to have the same sign. +static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI, + const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { + std::optional<bool> Known1 = getKnownSign(Op1, CxtI, DL, AC, DT); + if (!Known1) + return false; + std::optional<bool> Known0 = getKnownSign(Op0, CxtI, DL, AC, DT); + if (!Known0) + return false; + return *Known0 == *Known1; +} + /// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This /// can trigger other combines. static Instruction *moveAddAfterMinMax(IntrinsicInst *II, @@ -991,7 +1186,8 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II, /// If this min/max has a constant operand and an operand that is a matching /// min/max with a constant operand, constant-fold the 2 constant operands. -static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) { +static Value *reassociateMinMaxWithConstants(IntrinsicInst *II, + IRBuilderBase &Builder) { Intrinsic::ID MinMaxID = II->getIntrinsicID(); auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0)); if (!LHS || LHS->getIntrinsicID() != MinMaxID) @@ -1004,12 +1200,10 @@ static Instruction *reassociateMinMaxWithConstants(IntrinsicInst *II) { // max (max X, C0), C1 --> max X, (max C0, C1) --> max X, NewC ICmpInst::Predicate Pred = MinMaxIntrinsic::getPredicate(MinMaxID); - Constant *CondC = ConstantExpr::getICmp(Pred, C0, C1); - Constant *NewC = ConstantExpr::getSelect(CondC, C0, C1); - - Module *Mod = II->getModule(); - Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType()); - return CallInst::Create(MinMax, {LHS->getArgOperand(0), NewC}); + Value *CondC = Builder.CreateICmp(Pred, C0, C1); + Value *NewC = Builder.CreateSelect(CondC, C0, C1); + return Builder.CreateIntrinsic(MinMaxID, II->getType(), + {LHS->getArgOperand(0), NewC}); } /// If this min/max has a matching min/max operand with a constant, try to push @@ -1149,15 +1343,60 @@ foldShuffledIntrinsicOperands(IntrinsicInst *II, return new ShuffleVectorInst(NewIntrinsic, Mask); } +/// Fold the following cases and accepts bswap and bitreverse intrinsics: +/// bswap(logic_op(bswap(x), y)) --> logic_op(x, bswap(y)) +/// bswap(logic_op(bswap(x), bswap(y))) --> logic_op(x, y) (ignores multiuse) +template <Intrinsic::ID IntrID> +static Instruction *foldBitOrderCrossLogicOp(Value *V, + InstCombiner::BuilderTy &Builder) { + static_assert(IntrID == Intrinsic::bswap || IntrID == Intrinsic::bitreverse, + "This helper only supports BSWAP and BITREVERSE intrinsics"); + + Value *X, *Y; + // Find bitwise logic op. Check that it is a BinaryOperator explicitly so we + // don't match ConstantExpr that aren't meaningful for this transform. 
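The helper introduced here relies on byte and bit reversal being pure lane permutations, which therefore distribute over bitwise logic: bswap(x op y) == bswap(x) op bswap(y), and hence bswap(bswap(x) op y) == x op bswap(y). A standalone C++ sketch (illustrative only, not part of the patch; bswap32 is a hand-rolled byte swap, not an LLVM API):

#include <cassert>
#include <cstdint>

static uint32_t bswap32(uint32_t V) {
  return (V >> 24) | ((V >> 8) & 0xFF00u) | ((V << 8) & 0xFF0000u) |
         (V << 24);
}

int main() {
  uint32_t Vals[] = {0u, 1u, 0x12345678u, 0x80000001u, 0xDEADBEEFu};
  for (uint32_t X : Vals)
    for (uint32_t Y : Vals) {
      assert(bswap32(bswap32(X) & Y) == (X & bswap32(Y)));
      assert(bswap32(bswap32(X) | bswap32(Y)) == (X | Y));
      assert(bswap32(bswap32(X) ^ Y) == (X ^ bswap32(Y)));
    }
  return 0;
}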
+ if (match(V, m_OneUse(m_BitwiseLogic(m_Value(X), m_Value(Y)))) && + isa<BinaryOperator>(V)) { + Value *OldReorderX, *OldReorderY; + BinaryOperator::BinaryOps Op = cast<BinaryOperator>(V)->getOpcode(); + + // If both X and Y are bswap/bitreverse, the transform reduces the number + // of instructions even if there's multiuse. + // If only one operand is bswap/bitreverse, we need to ensure the operand + // has only one use. + if (match(X, m_Intrinsic<IntrID>(m_Value(OldReorderX))) && + match(Y, m_Intrinsic<IntrID>(m_Value(OldReorderY)))) { + return BinaryOperator::Create(Op, OldReorderX, OldReorderY); + } + + if (match(X, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderX))))) { + Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, Y); + return BinaryOperator::Create(Op, OldReorderX, NewReorder); + } + + if (match(Y, m_OneUse(m_Intrinsic<IntrID>(m_Value(OldReorderY))))) { + Value *NewReorder = Builder.CreateUnaryIntrinsic(IntrID, X); + return BinaryOperator::Create(Op, NewReorder, OldReorderY); + } + } + return nullptr; +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Don't try to simplify calls without uses. It will not do anything useful, // but will result in the following folds being skipped. - if (!CI.use_empty()) - if (Value *V = simplifyCall(&CI, SQ.getWithInstruction(&CI))) + if (!CI.use_empty()) { + SmallVector<Value *, 4> Args; + Args.reserve(CI.arg_size()); + for (Value *Op : CI.args()) + Args.push_back(Op); + if (Value *V = simplifyCall(&CI, CI.getCalledOperand(), Args, + SQ.getWithInstruction(&CI))) return replaceInstUsesWith(CI, V); + } if (Value *FreedOp = getFreedOperand(&CI, &TLI)) return visitFree(CI, FreedOp); @@ -1176,7 +1415,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // not a multiple of element size then behavior is undefined.
if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II)) if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength())) - if (NumBytes->getSExtValue() < 0 || + if (NumBytes->isNegative() || (NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) { CreateNonTerminatorUnreachable(AMI); assert(AMI->getType()->isVoidTy() && @@ -1267,10 +1506,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Intrinsic::ID IID = II->getIntrinsicID(); switch (IID) { - case Intrinsic::objectsize: - if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false)) + case Intrinsic::objectsize: { + SmallVector<Instruction *> InsertedInstructions; + if (Value *V = lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/false, + &InsertedInstructions)) { + for (Instruction *Inserted : InsertedInstructions) + Worklist.add(Inserted); return replaceInstUsesWith(CI, V); + } return nullptr; + } case Intrinsic::abs: { Value *IIOperand = II->getArgOperand(0); bool IntMinIsPoison = cast<Constant>(II->getArgOperand(1))->isOneValue(); @@ -1377,6 +1622,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // (umax X, (xor X, Pow2)) + // -> (or X, Pow2) + // (umin X, (xor X, Pow2)) + // -> (and X, ~Pow2) + // (smax X, (xor X, Pos_Pow2)) + // -> (or X, Pos_Pow2) + // (smin X, (xor X, Pos_Pow2)) + // -> (and X, ~Pos_Pow2) + // (smax X, (xor X, Neg_Pow2)) + // -> (and X, ~Neg_Pow2) + // (smin X, (xor X, Neg_Pow2)) + // -> (or X, Neg_Pow2) + if ((match(I0, m_c_Xor(m_Specific(I1), m_Value(X))) || + match(I1, m_c_Xor(m_Specific(I0), m_Value(X)))) && + isKnownToBeAPowerOfTwo(X, /* OrZero */ true)) { + bool UseOr = IID == Intrinsic::smax || IID == Intrinsic::umax; + bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin; + + if (IID == Intrinsic::smax || IID == Intrinsic::smin) { + auto KnownSign = getKnownSign(X, II, DL, &AC, &DT); + if (KnownSign == std::nullopt) { + UseOr = false; + UseAndN = false; + } else if (*KnownSign /* true is Signed. */) { + UseOr ^= true; + UseAndN ^= true; + Type *Ty = I0->getType(); + // Negative power of 2 must be IntMin. It's possible to be able to + // prove negative / power of 2 without actually having known bits, so + // just get the value by hand. + X = Constant::getIntegerValue( + Ty, APInt::getSignedMinValue(Ty->getScalarSizeInBits())); + } + } + if (UseOr) + return BinaryOperator::CreateOr(I0, X); + else if (UseAndN) + return BinaryOperator::CreateAnd(I0, Builder.CreateNot(X)); + } + // If we can eliminate ~A and Y is free to invert: // max ~A, Y --> ~(min A, ~Y) // @@ -1436,13 +1721,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *SAdd = matchSAddSubSat(*II)) return SAdd; - if (match(I1, m_ImmConstant())) - if (auto *Sel = dyn_cast<SelectInst>(I0)) - if (Instruction *R = FoldOpIntoSelect(*II, Sel)) - return R; - - if (Instruction *NewMinMax = reassociateMinMaxWithConstants(II)) - return NewMinMax; + if (Value *NewMinMax = reassociateMinMaxWithConstants(II, Builder)) + return replaceInstUsesWith(*II, NewMinMax); if (Instruction *R = reassociateMinMaxWithConstantInOperand(II, Builder)) return R; @@ -1453,15 +1733,21 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::bitreverse: { + Value *IIOperand = II->getArgOperand(0); // bitrev (zext i1 X to ?) --> X ? 
SignBitC : 0 Value *X; - if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) && + if (match(IIOperand, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { Type *Ty = II->getType(); APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits()); return SelectInst::Create(X, ConstantInt::get(Ty, SignBit), ConstantInt::getNullValue(Ty)); } + + if (Instruction *crossLogicOpFold = + foldBitOrderCrossLogicOp<Intrinsic::bitreverse>(IIOperand, Builder)) + return crossLogicOpFold; + break; } case Intrinsic::bswap: { @@ -1511,6 +1797,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *V = Builder.CreateLShr(X, CV); return new TruncInst(V, IIOperand->getType()); } + + if (Instruction *crossLogicOpFold = + foldBitOrderCrossLogicOp<Intrinsic::bswap>(IIOperand, Builder)) { + return crossLogicOpFold; + } + break; } case Intrinsic::masked_load: @@ -1616,6 +1908,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Function *Bswap = Intrinsic::getDeclaration(Mod, Intrinsic::bswap, Ty); return CallInst::Create(Bswap, { Op0 }); } + if (Instruction *BitOp = + matchBSwapOrBitReverse(*II, /*MatchBSwaps*/ true, + /*MatchBitReversals*/ true)) + return BitOp; } // Left or right might be masked. @@ -1983,7 +2279,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::copysign: { Value *Mag = II->getArgOperand(0), *Sign = II->getArgOperand(1); - if (SignBitMustBeZero(Sign, &TLI)) { + if (SignBitMustBeZero(Sign, DL, &TLI)) { // If we know that the sign argument is positive, reduce to FABS: // copysign Mag, +Sign --> fabs Mag Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Mag, II); @@ -2079,6 +2375,42 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::ldexp: { + // ldexp(ldexp(x, a), b) -> ldexp(x, a + b) + // + // The danger is if the first ldexp would overflow to infinity or underflow + // to zero, but the combined exponent avoids it. We ignore this with + // reassoc. + // + // It's also safe to fold if we know both exponents are >= 0 or <= 0 since + // it would just double down on the overflow/underflow which would occur + // anyway. + // + // TODO: Could do better if we had range tracking for the input value + // exponent. Also could broaden sign check to cover == 0 case. + Value *Src = II->getArgOperand(0); + Value *Exp = II->getArgOperand(1); + Value *InnerSrc; + Value *InnerExp; + if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ldexp>( + m_Value(InnerSrc), m_Value(InnerExp)))) && + Exp->getType() == InnerExp->getType()) { + FastMathFlags FMF = II->getFastMathFlags(); + FastMathFlags InnerFlags = cast<FPMathOperator>(Src)->getFastMathFlags(); + + if ((FMF.allowReassoc() && InnerFlags.allowReassoc()) || + signBitMustBeTheSame(Exp, InnerExp, II, DL, &AC, &DT)) { + // TODO: Add nsw/nuw probably safe if integer type exceeds exponent + // width. + Value *NewExp = Builder.CreateAdd(InnerExp, Exp); + II->setArgOperand(1, NewExp); + II->setFastMathFlags(InnerFlags); // Or the inner flags. 
+ return replaceOperand(*II, 0, InnerSrc); + } + } + + break; + } case Intrinsic::ptrauth_auth: case Intrinsic::ptrauth_resign: { // (sign|resign) + (auth|resign) can be folded by omitting the middle @@ -2380,12 +2712,34 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { isValidAssumeForContext(II, LHS, &DT)) { MDNode *MD = MDNode::get(II->getContext(), std::nullopt); LHS->setMetadata(LLVMContext::MD_nonnull, MD); + LHS->setMetadata(LLVMContext::MD_noundef, MD); return RemoveConditionFromAssume(II); // TODO: apply nonnull return attributes to calls and invokes // TODO: apply range metadata for range check patterns? } + // Separate storage assumptions apply to the underlying allocations, not any + // particular pointer within them. When evaluating the hints for AA purposes + // we getUnderlyingObject them; by precomputing the answers here we can + // avoid having to do so repeatedly there. + for (unsigned Idx = 0; Idx < II->getNumOperandBundles(); Idx++) { + OperandBundleUse OBU = II->getOperandBundleAt(Idx); + if (OBU.getTagName() == "separate_storage") { + assert(OBU.Inputs.size() == 2); + auto MaybeSimplifyHint = [&](const Use &U) { + Value *Hint = U.get(); + // Not having a limit is safe because InstCombine removes unreachable + // code. + Value *UnderlyingObject = getUnderlyingObject(Hint, /*MaxLookup*/ 0); + if (Hint != UnderlyingObject) + replaceUse(const_cast<Use &>(U), UnderlyingObject); + }; + MaybeSimplifyHint(OBU.Inputs[0]); + MaybeSimplifyHint(OBU.Inputs[1]); + } + } + // Convert nonnull assume like: // %A = icmp ne i32* %PTR, null // call void @llvm.assume(i1 %A) @@ -2479,6 +2833,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Known.isAllOnes() && isAssumeWithEmptyBundle(cast<AssumeInst>(*II))) return eraseInstFromFunction(*II); + // assume(false) is unreachable. + if (match(IIOperand, m_CombineOr(m_Zero(), m_Undef()))) { + CreateNonTerminatorUnreachable(II); + return eraseInstFromFunction(*II); + } + // Update the cache of affected values for this assumption (we might be // here because we just simplified the condition). AC.updateAffectedValues(cast<AssumeInst>(II)); @@ -2545,7 +2905,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { for (i = 0; i != SubVecNumElts; ++i) WidenMask.push_back(i); for (; i != VecNumElts; ++i) - WidenMask.push_back(UndefMaskElem); + WidenMask.push_back(PoisonMaskElem); Value *WidenShuffle = Builder.CreateShuffleVector(SubVec, WidenMask); @@ -2840,7 +3200,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { int Sz = Mask.size(); SmallBitVector UsedIndices(Sz); for (int Idx : Mask) { - if (Idx == UndefMaskElem || UsedIndices.test(Idx)) + if (Idx == PoisonMaskElem || UsedIndices.test(Idx)) break; UsedIndices.set(Idx); } @@ -2852,6 +3212,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::is_fpclass: { + if (Instruction *I = foldIntrinsicIsFPClass(*II)) + return I; + break; + } default: { // Handle target specific intrinsics std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); @@ -2861,6 +3226,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } + // Try to fold intrinsic into select operands. This is legal if: + // * The intrinsic is speculatable. + // * The select condition is not a vector, or the intrinsic does not + // perform cross-lane operations. 
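// E.g. with a scalar condition and a constant arm: // umin (select i1 %c, i8 3, i8 %y), i8 5 // --> select i1 %c, i8 3, (umin i8 %y, i8 5) // since umin(3, 5) folds outright. A vector condition is also safe for the intrinsics listed below, because they all operate lane-wise.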
+ switch (IID) { + case Intrinsic::ctlz: + case Intrinsic::cttz: + case Intrinsic::ctpop: + case Intrinsic::umin: + case Intrinsic::umax: + case Intrinsic::smin: + case Intrinsic::smax: + case Intrinsic::usub_sat: + case Intrinsic::uadd_sat: + case Intrinsic::ssub_sat: + case Intrinsic::sadd_sat: + for (Value *Op : II->args()) + if (auto *Sel = dyn_cast<SelectInst>(Op)) + if (Instruction *R = FoldOpIntoSelect(*II, Sel)) + return R; + [[fallthrough]]; + default: + break; + } + if (Instruction *Shuf = foldShuffledIntrinsicOperands(II, Builder)) return Shuf; @@ -2907,49 +3297,6 @@ Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) { return visitCallBase(CBI); } -/// If this cast does not affect the value passed through the varargs area, we -/// can eliminate the use of the cast. -static bool isSafeToEliminateVarargsCast(const CallBase &Call, - const DataLayout &DL, - const CastInst *const CI, - const int ix) { - if (!CI->isLosslessCast()) - return false; - - // If this is a GC intrinsic, avoid munging types. We need types for - // statepoint reconstruction in SelectionDAG. - // TODO: This is probably something which should be expanded to all - // intrinsics since the entire point of intrinsics is that - // they are understandable by the optimizer. - if (isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || - isa<GCResultInst>(Call)) - return false; - - // Opaque pointers are compatible with any byval types. - PointerType *SrcTy = cast<PointerType>(CI->getOperand(0)->getType()); - if (SrcTy->isOpaque()) - return true; - - // The size of ByVal or InAlloca arguments is derived from the type, so we - // can't change to a type with a different size. If the size were - // passed explicitly we could avoid this check. - if (!Call.isPassPointeeByValueArgument(ix)) - return true; - - // The transform currently only handles type replacement for byval, not other - // type-carrying attributes. - if (!Call.isByValArgument(ix)) - return false; - - Type *SrcElemTy = SrcTy->getNonOpaquePointerElementType(); - Type *DstElemTy = Call.getParamByValType(ix); - if (!SrcElemTy->isSized() || !DstElemTy->isSized()) - return false; - if (DL.getTypeAllocSize(SrcElemTy) != DL.getTypeAllocSize(DstElemTy)) - return false; - return true; -} - Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; @@ -2965,7 +3312,7 @@ Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { auto InstCombineErase = [this](Instruction *I) { eraseInstFromFunction(*I); }; - LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW, + LibCallSimplifier Simplifier(DL, &TLI, &AC, ORE, BFI, PSI, InstCombineRAUW, InstCombineErase); if (Value *With = Simplifier.optimizeCall(CI, Builder)) { ++NumSimplified; @@ -3198,32 +3545,6 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { if (IntrinsicInst *II = findInitTrampoline(Callee)) return transformCallThroughTrampoline(Call, *II); - // TODO: Drop this transform once opaque pointer transition is done. - FunctionType *FTy = Call.getFunctionType(); - if (FTy->isVarArg()) { - int ix = FTy->getNumParams(); - // See if we can optimize any arguments passed through the varargs area of - // the call. - for (auto I = Call.arg_begin() + FTy->getNumParams(), E = Call.arg_end(); - I != E; ++I, ++ix) { - CastInst *CI = dyn_cast<CastInst>(*I); - if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) { - replaceUse(*I, CI->getOperand(0)); - - // Update the byval type to match the pointer type. 
- // Not necessary for opaque pointers. - PointerType *NewTy = cast<PointerType>(CI->getOperand(0)->getType()); - if (!NewTy->isOpaque() && Call.isByValArgument(ix)) { - Call.removeParamAttr(ix, Attribute::ByVal); - Call.addParamAttr(ix, Attribute::getWithByValType( - Call.getContext(), - NewTy->getNonOpaquePointerElementType())); - } - Changed = true; - } - } - } - if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) { InlineAsm *IA = cast<InlineAsm>(Callee); if (!IA->canThrow()) { @@ -3381,13 +3702,17 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { } /// If the callee is a constexpr cast of a function, attempt to move the cast to -/// the arguments of the call/callbr/invoke. +/// the arguments of the call/invoke. +/// CallBrInst is not supported. bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { auto *Callee = dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts()); if (!Callee) return false; + assert(!isa<CallBrInst>(Call) && + "CallBr's don't have a single point after a def to insert at"); + // If this is a call to a thunk function, don't remove the cast. Thunks are // used to transparently forward all incoming parameters and outgoing return // values, so it's important to leave the cast in place. @@ -3433,7 +3758,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { return false; // Attribute not compatible with transformed value. } - // If the callbase is an invoke/callbr instruction, and the return value is + // If the callbase is an invoke instruction, and the return value is // used by a PHI node in a successor, we cannot change the return type of // the call because there is no place to put the cast instruction (without // breaking the critical edge). Bail out in this case. @@ -3441,8 +3766,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { BasicBlock *PhisNotSupportedBlock = nullptr; if (auto *II = dyn_cast<InvokeInst>(Caller)) PhisNotSupportedBlock = II->getNormalDest(); - if (auto *CB = dyn_cast<CallBrInst>(Caller)) - PhisNotSupportedBlock = CB->getDefaultDest(); if (PhisNotSupportedBlock) for (User *U : Caller->users()) if (PHINode *PN = dyn_cast<PHINode>(U)) @@ -3490,24 +3813,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (CallerPAL.hasParamAttr(i, Attribute::ByVal) != Callee->getAttributes().hasParamAttr(i, Attribute::ByVal)) return false; // Cannot transform to or from byval. - - // If the parameter is passed as a byval argument, then we have to have a - // sized type and the sized type has to have the same size as the old type. - if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { - PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); - if (!ParamPTy) - return false; - - if (!ParamPTy->isOpaque()) { - Type *ParamElTy = ParamPTy->getNonOpaquePointerElementType(); - if (!ParamElTy->isSized()) - return false; - - Type *CurElTy = Call.getParamByValType(i); - if (DL.getTypeAllocSize(CurElTy) != DL.getTypeAllocSize(ParamElTy)) - return false; - } - } } if (Callee->isDeclaration()) { @@ -3568,16 +3873,8 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // type. Note that we made sure all incompatible ones are safe to drop. 
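// (For instance, nonnull becomes type-incompatible once the corresponding parameter is no longer a pointer; dropping it merely discards information, which is why it sits in the ASK_SAFE_TO_DROP set.)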
AttributeMask IncompatibleAttrs = AttributeFuncs::typeIncompatible( ParamTy, AttributeFuncs::ASK_SAFE_TO_DROP); - if (CallerPAL.hasParamAttr(i, Attribute::ByVal) && - !ParamTy->isOpaquePointerTy()) { - AttrBuilder AB(Ctx, CallerPAL.getParamAttrs(i).removeAttributes( - Ctx, IncompatibleAttrs)); - AB.addByValAttr(ParamTy->getNonOpaquePointerElementType()); - ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); - } else { - ArgAttrs.push_back( - CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs)); - } + ArgAttrs.push_back( + CallerPAL.getParamAttrs(i).removeAttributes(Ctx, IncompatibleAttrs)); } // If the function takes more arguments than the call was taking, add them @@ -3626,9 +3923,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { NewCall = Builder.CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(), Args, OpBundles); - } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { - NewCall = Builder.CreateCallBr(Callee, CBI->getDefaultDest(), - CBI->getIndirectDests(), Args, OpBundles); } else { NewCall = Builder.CreateCall(Callee, Args, OpBundles); cast<CallInst>(NewCall)->setTailCallKind( diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 3f851a2b2182..5c84f666616d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -25,166 +25,6 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// Analyze 'Val', seeing if it is a simple linear expression. -/// If so, decompose it, returning some value X, such that Val is -/// X*Scale+Offset. -/// -static Value *decomposeSimpleLinearExpr(Value *Val, unsigned &Scale, - uint64_t &Offset) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) { - Offset = CI->getZExtValue(); - Scale = 0; - return ConstantInt::get(Val->getType(), 0); - } - - if (BinaryOperator *I = dyn_cast<BinaryOperator>(Val)) { - // Cannot look past anything that might overflow. - // We specifically require nuw because we store the Scale in an unsigned - // and perform an unsigned divide on it. - OverflowingBinaryOperator *OBI = dyn_cast<OverflowingBinaryOperator>(Val); - if (OBI && !OBI->hasNoUnsignedWrap()) { - Scale = 1; - Offset = 0; - return Val; - } - - if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) { - if (I->getOpcode() == Instruction::Shl) { - // This is a value scaled by '1 << the shift amt'. - Scale = UINT64_C(1) << RHS->getZExtValue(); - Offset = 0; - return I->getOperand(0); - } - - if (I->getOpcode() == Instruction::Mul) { - // This value is scaled by 'RHS'. - Scale = RHS->getZExtValue(); - Offset = 0; - return I->getOperand(0); - } - - if (I->getOpcode() == Instruction::Add) { - // We have X+C. Check to see if we really have (X*C2)+C1, - // where C1 is divisible by C2. - unsigned SubScale; - Value *SubVal = - decomposeSimpleLinearExpr(I->getOperand(0), SubScale, Offset); - Offset += RHS->getZExtValue(); - Scale = SubScale; - return SubVal; - } - } - } - - // Otherwise, we can't look past this. - Scale = 1; - Offset = 0; - return Val; -} - -/// If we find a cast of an allocation instruction, try to eliminate the cast by -/// moving the type information into the alloc. -Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, - AllocaInst &AI) { - PointerType *PTy = cast<PointerType>(CI.getType()); - // Opaque pointers don't have an element type we could replace with. 
- if (PTy->isOpaque()) - return nullptr; - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(&AI); - - // Get the type really allocated and the type casted to. - Type *AllocElTy = AI.getAllocatedType(); - Type *CastElTy = PTy->getNonOpaquePointerElementType(); - if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; - - // This optimisation does not work for cases where the cast type - // is scalable and the allocated type is not. This because we need to - // know how many times the casted type fits into the allocated type. - // For the opposite case where the allocated type is scalable and the - // cast type is not this leads to poor code quality due to the - // introduction of 'vscale' into the calculations. It seems better to - // bail out for this case too until we've done a proper cost-benefit - // analysis. - bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy); - bool CastIsScalable = isa<ScalableVectorType>(CastElTy); - if (AllocIsScalable != CastIsScalable) return nullptr; - - Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); - Align CastElTyAlign = DL.getABITypeAlign(CastElTy); - if (CastElTyAlign < AllocElTyAlign) return nullptr; - - // If the allocation has multiple uses, only promote it if we are strictly - // increasing the alignment of the resultant allocation. If we keep it the - // same, we open the door to infinite loops of various kinds. - if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; - - // The alloc and cast types should be either both fixed or both scalable. - uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue(); - uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue(); - if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; - - // If the allocation has multiple uses, only promote it if we're not - // shrinking the amount of memory being allocated. - uint64_t AllocElTyStoreSize = - DL.getTypeStoreSize(AllocElTy).getKnownMinValue(); - uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue(); - if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; - - // See if we can satisfy the modulus by pulling a scale out of the array - // size argument. - unsigned ArraySizeScale; - uint64_t ArrayOffset; - Value *NumElements = // See if the array size is a decomposable linear expr. - decomposeSimpleLinearExpr(AI.getOperand(0), ArraySizeScale, ArrayOffset); - - // If we can now satisfy the modulus, by using a non-1 scale, we really can - // do the xform. - if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 || - (AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr; - - // We don't currently support arrays of scalable types. - assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0)); - - unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize; - Value *Amt = nullptr; - if (Scale == 1) { - Amt = NumElements; - } else { - Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale); - // Insert before the alloca, not before the cast. 
- Amt = Builder.CreateMul(Amt, NumElements); - } - - if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { - Value *Off = ConstantInt::get(AI.getArraySize()->getType(), - Offset, true); - Amt = Builder.CreateAdd(Amt, Off); - } - - AllocaInst *New = Builder.CreateAlloca(CastElTy, AI.getAddressSpace(), Amt); - New->setAlignment(AI.getAlign()); - New->takeName(&AI); - New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); - New->setMetadata(LLVMContext::MD_DIAssignID, - AI.getMetadata(LLVMContext::MD_DIAssignID)); - - replaceAllDbgUsesWith(AI, *New, *New, DT); - - // If the allocation has multiple real uses, insert a cast and change all - // things that used it to use the new cast. This will also hack on CI, but it - // will die soon. - if (!AI.hasOneUse()) { - // New is the allocation instruction, pointer typed. AI is the original - // allocation instruction, also pointer typed. Thus, cast to use is BitCast. - Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast"); - replaceInstUsesWith(AI, NewCast); - eraseInstFromFunction(AI); - } - return replaceInstUsesWith(CI, New); -} - /// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns /// true for, actually insert the code to evaluate the expression. Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, @@ -252,6 +92,20 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, Res = CastInst::Create( static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty); break; + case Instruction::Call: + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Unsupported call!"); + case Intrinsic::vscale: { + Function *Fn = + Intrinsic::getDeclaration(I->getModule(), Intrinsic::vscale, {Ty}); + Res = CallInst::Create(Fn->getFunctionType(), Fn); + break; + } + } + } + break; default: // TODO: Can handle more cases here. llvm_unreachable("Unreachable!"); @@ -294,6 +148,10 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); Type *Ty = CI.getType(); + if (auto *SrcC = dyn_cast<Constant>(Src)) + if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL)) + return replaceInstUsesWith(CI, Res); + // Try to eliminate a cast of a cast. if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { @@ -501,16 +359,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, // If the integer type can hold the max FP value, it is safe to cast // directly to that type. Otherwise, we may create poison via overflow // that did not exist in the original code. - // - // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need - // at least one more bit than the MaxExponent to hold the max FP value. Type *InputTy = I->getOperand(0)->getType()->getScalarType(); const fltSemantics &Semantics = InputTy->getFltSemantics(); - uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics); - // Extra sign bit needed. - if (I->getOpcode() == Instruction::FPToSI) - ++MinBitWidth; - return Ty->getScalarSizeInBits() > MinBitWidth; + uint32_t MinBitWidth = + APFloatBase::semanticsIntSizeInBits(Semantics, + I->getOpcode() == Instruction::FPToSI); + return Ty->getScalarSizeInBits() >= MinBitWidth; } default: // TODO: Can handle more cases here. 
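// Worked example for the FPToUI/FPToSI case above (editorial illustration): the largest finite IEEE half value is 65504 < 2^16, so semanticsIntSizeInBits gives 16 (unsigned) or 17 (signed); hence // trunc (fptoui half %h to i32) to i16 --> fptoui half %h to i16 // is safe, while the signed variant needs a destination of at least i17.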
@@ -881,13 +735,12 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } - if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)), + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_ImmConstant(C)), m_Deferred(X))))) { // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); Constant *MaskC = ConstantExpr::getShl(One, C); - MaskC = ConstantExpr::getOr(MaskC, One); - Value *And = Builder.CreateAnd(X, MaskC); + Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One)); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } } @@ -904,11 +757,18 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // removed by the trunc. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { + auto GetNewShAmt = [&](unsigned Width) { + Constant *MaxAmt = ConstantInt::get(SrcTy, Width - 1, false); + Constant *Cmp = + ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, C, MaxAmt, DL); + Constant *ShAmt = ConstantFoldSelectInstruction(Cmp, C, MaxAmt); + return ConstantFoldCastOperand(Instruction::Trunc, ShAmt, A->getType(), + DL); + }; + // trunc (lshr (sext A), C) --> ashr A, C if (A->getType() == DestTy) { - Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); - Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); - ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Constant *ShAmt = GetNewShAmt(DestWidth); ShAmt = Constant::mergeUndefsWith(ShAmt, C); return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) : BinaryOperator::CreateAShr(A, ShAmt); @@ -916,9 +776,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // The types are mismatched, so create a cast after shifting: // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) if (Src->hasOneUse()) { - Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false); - Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); - ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + Constant *ShAmt = GetNewShAmt(AWidth); Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact); return CastInst::CreateIntegerCast(Shift, DestTy, true); } @@ -998,7 +856,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { } } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Trunc.getFunction() && Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -1217,6 +1075,13 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return false; return true; } + case Instruction::Call: + // llvm.vscale() can always be executed in a larger type, because the + // value is automatically zero-extended. + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::vscale) + return true; + return false; default: // TODO: Can handle more cases here. return false; @@ -1226,7 +1091,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // If this zero extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this zext. - if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back())) + if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()) && + !isa<Constant>(Zext.getOperand(0))) return nullptr; // If one of the common conversions will work, do it.
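// (Constant operands are exempt from the deferral above: a zext of a constant is folded immediately by the ConstantFoldCastOperand call in commonCastTransforms, so there is nothing to wait for.)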
@@ -1340,7 +1206,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { return BinaryOperator::CreateAnd(X, ZextC); } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Zext.getFunction() && Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -1402,7 +1268,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { // sext ((x & 2^n) == 0) -> (x >> n) - 1 // sext ((x & 2^n) != 2^n) -> (x >> n) - 1 - unsigned ShiftAmt = KnownZeroMask.countTrailingZeros(); + unsigned ShiftAmt = KnownZeroMask.countr_zero(); // Perform a right shift to place the desired bit in the LSB. if (ShiftAmt) In = Builder.CreateLShr(In, @@ -1416,7 +1282,7 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, } else { // sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1 // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1 - unsigned ShiftAmt = KnownZeroMask.countLeadingZeros(); + unsigned ShiftAmt = KnownZeroMask.countl_zero(); // Perform a left shift to place the desired bit in the MSB. if (ShiftAmt) In = Builder.CreateShl(In, @@ -1611,7 +1477,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { } } - if (match(Src, m_VScale(DL))) { + if (match(Src, m_VScale())) { if (Sext.getFunction() && Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = @@ -2687,57 +2553,6 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, return RetVal; } -static Instruction *convertBitCastToGEP(BitCastInst &CI, IRBuilderBase &Builder, - const DataLayout &DL) { - Value *Src = CI.getOperand(0); - PointerType *SrcPTy = cast<PointerType>(Src->getType()); - PointerType *DstPTy = cast<PointerType>(CI.getType()); - - // Bitcasts involving opaque pointers cannot be converted into a GEP. - if (SrcPTy->isOpaque() || DstPTy->isOpaque()) - return nullptr; - - Type *DstElTy = DstPTy->getNonOpaquePointerElementType(); - Type *SrcElTy = SrcPTy->getNonOpaquePointerElementType(); - - // When the type pointed to is not sized the cast cannot be - // turned into a gep. - if (!SrcElTy->isSized()) - return nullptr; - - // If the source and destination are pointers, and this cast is equivalent - // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. - // This can enhance SROA and other transforms that want type-safe pointers. - unsigned NumZeros = 0; - while (SrcElTy && SrcElTy != DstElTy) { - SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); - ++NumZeros; - } - - // If we found a path from the src to dest, create the getelementptr now. - if (SrcElTy == DstElTy) { - SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0)); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - SrcPTy->getNonOpaquePointerElementType(), Src, Idxs); - - // If the source pointer is dereferenceable, then assume it points to an - // allocated object and apply "inbounds" to the GEP. - bool CanBeNull, CanBeFreed; - if (Src->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed)) { - // In a non-default address space (not 0), a null pointer can not be - // assumed inbounds, so ignore that case (dereferenceable_or_null). - // The reason is that 'null' is not treated differently in these address - // spaces, and we consequently ignore the 'gep inbounds' special case - // for 'null' which allows 'inbounds' on 'null' if the indices are - // zeros. 
- if (SrcPTy->getAddressSpace() == 0 || !CanBeNull) - GEP->setIsInBounds(); - } - return GEP; - } - return nullptr; -} - Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { // If the operands are integer typed then apply the integer transforms, // otherwise just apply the common ones. @@ -2750,19 +2565,6 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (DestTy == Src->getType()) return replaceInstUsesWith(CI, Src); - if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) { - // If we are casting a alloca to a pointer to a type of the same - // size, rewrite the allocation instruction to allocate the "right" type. - // There is no need to modify malloc calls because it is their bitcast that - // needs to be cleaned up. - if (AllocaInst *AI = dyn_cast<AllocaInst>(Src)) - if (Instruction *V = PromoteCastOfAllocation(CI, *AI)) - return V; - - if (Instruction *I = convertBitCastToGEP(CI, Builder, DL)) - return I; - } - if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) { // Beware: messing with this target-specific oddity may cause trouble. if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { @@ -2905,23 +2707,5 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { } Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { - // If the destination pointer element type is not the same as the source's - // first do a bitcast to the destination type, and then the addrspacecast. - // This allows the cast to be exposed to other transforms. - Value *Src = CI.getOperand(0); - PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType()); - PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType()); - - if (!SrcTy->hasSameElementTypeAs(DestTy)) { - Type *MidTy = - PointerType::getWithSamePointeeType(DestTy, SrcTy->getAddressSpace()); - // Handle vectors of pointers. - if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) - MidTy = VectorType::get(MidTy, VT->getElementCount()); - - Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); - return new AddrSpaceCastInst(NewBitCast, CI.getType()); - } - return commonPointerCastTransforms(CI); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 1480a0ff9e2f..656f04370e17 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -14,6 +14,7 @@ #include "llvm/ADT/APSInt.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -198,7 +199,11 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( } // If the element is masked, handle it. - if (AndCst) Elt = ConstantExpr::getAnd(Elt, AndCst); + if (AndCst) { + Elt = ConstantFoldBinaryOpOperands(Instruction::And, Elt, AndCst, DL); + if (!Elt) + return nullptr; + } // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, @@ -276,14 +281,14 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // order the state machines in complexity of the generated code. Value *Idx = GEP->getOperand(2); - // If the index is larger than the pointer size of the target, truncate the - // index down like the GEP would do implicitly. 
We don't have to do this for - // an inbounds GEP because the index can't be out of range. + // If the index is larger than the pointer offset size of the target, truncate + // the index down like the GEP would do implicitly. We don't have to do this + // for an inbounds GEP because the index can't be out of range. if (!GEP->isInBounds()) { - Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); - unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize) - Idx = Builder.CreateTrunc(Idx, IntPtrTy); + Type *PtrIdxTy = DL.getIndexType(GEP->getType()); + unsigned OffsetSize = PtrIdxTy->getIntegerBitWidth(); + if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > OffsetSize) + Idx = Builder.CreateTrunc(Idx, PtrIdxTy); } // If inbounds keyword is not present, Idx * ElementSize can overflow. @@ -295,10 +300,10 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx. unsigned ElementSize = DL.getTypeAllocSize(Init->getType()->getArrayElementType()); - auto MaskIdx = [&](Value* Idx){ - if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) { + auto MaskIdx = [&](Value *Idx) { + if (!GEP->isInBounds() && llvm::countr_zero(ElementSize) != 0) { Value *Mask = ConstantInt::get(Idx->getType(), -1); - Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize)); + Mask = Builder.CreateLShr(Mask, llvm::countr_zero(ElementSize)); Idx = Builder.CreateAnd(Idx, Mask); } return Idx; @@ -533,7 +538,8 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, /// pointer. static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, const DataLayout &DL, - SetVector<Value *> &Explored) { + SetVector<Value *> &Explored, + InstCombiner &IC) { // Perform all the substitutions. This is a bit tricky because we can // have cycles in our use-def chains. // 1. Create the PHI nodes without any incoming values. @@ -562,7 +568,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Create all the other instructions. for (Value *Val : Explored) { - if (NewInsts.find(Val) != NewInsts.end()) + if (NewInsts.contains(Val)) continue; if (auto *CI = dyn_cast<CastInst>(Val)) { @@ -610,7 +616,7 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, for (unsigned I = 0, E = PHI->getNumIncomingValues(); I < E; ++I) { Value *NewIncoming = PHI->getIncomingValue(I); - if (NewInsts.find(NewIncoming) != NewInsts.end()) + if (NewInsts.contains(NewIncoming)) NewIncoming = NewInsts[NewIncoming]; NewPhi->addIncoming(NewIncoming, PHI->getIncomingBlock(I)); @@ -635,7 +641,10 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, Val->getName() + ".ptr"); NewVal = Builder.CreateBitOrPointerCast( NewVal, Val->getType(), Val->getName() + ".conv"); - Val->replaceAllUsesWith(NewVal); + IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal); + // Add old instruction to worklist for DCE. We don't directly remove it + // here because the original compare is one of the users. + IC.addToWorklist(cast<Instruction>(Val)); } return NewInsts[Start]; @@ -688,7 +697,8 @@ getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { /// between GEPLHS and RHS. static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, - const DataLayout &DL) { + const DataLayout &DL, + InstCombiner &IC) { // FIXME: Support vector of pointers. 
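// Illustrative shape of the rewrite (editor's sketch; byte-sized elements, so the offsets equal the indices): // %p = getelementptr inbounds i8, ptr %base, i64 %i // %q = getelementptr inbounds i8, ptr %base, i64 %j // icmp ult ptr %p, %q --> icmp slt i64 %i, %j // with the offset compare made signed via getSignedPredicate.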
if (GEPLHS->getType()->isVectorTy()) return nullptr; @@ -712,7 +722,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, // can't have overflow on either side. We can therefore re-write // this as: // OFFSET1 cmp OFFSET2 - Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes); + Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC); // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written // GEP having PtrBase as the pointer base, and has returned in NewRHS the @@ -740,7 +750,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - if (PtrBase == RHS && GEPLHS->isInBounds()) { + if (PtrBase == RHS && (GEPLHS->isInBounds() || ICmpInst::isEquality(Cond))) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). Value *Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, @@ -831,7 +841,7 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Otherwise, the base pointers are different and the indices are // different. Try to convert this to an indexed compare by looking through // PHIs/casts. - return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } // If one of the GEPs has all zero indices, recurse. @@ -883,7 +893,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! - if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && + if ((GEPsInBounds || CmpInst::isEquality(Cond)) && + (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); @@ -894,13 +905,10 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Try to convert this to an indexed compare by looking through PHIs/casts as a // last resort. - return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); + return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } -Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI, - const AllocaInst *Alloca) { - assert(ICI.isEquality() && "Cannot fold non-equality comparison."); - +bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) { // It would be tempting to fold away comparisons between allocas and any // pointer not based on that alloca (e.g. an argument). However, even // though such pointers cannot alias, they can still compare equal. // // But LLVM doesn't specify where allocas get their memory, so if the alloca // doesn't escape we can argue that it's impossible to guess its value, and we // can therefore act as if any such guesses are wrong. // - // The code below checks that the alloca doesn't escape, and that it's only - // used in a comparison once (the current instruction). The - // single-comparison-use condition ensures that we're trivially folding all - // comparisons against the alloca consistently, and avoids the risk of - // erroneously folding a comparison of the pointer with itself. - - unsigned MaxIter = 32; // Break cycles and bound to constant-time.
+ // However, we need to ensure that this folding is consistent: We can't fold + // one comparison to false, and then leave a different comparison against the + // same value alone (as it might evaluate to true at runtime, leading to a + // contradiction). As such, this code ensures that all comparisons are folded + // at the same time, and there are no other escapes. + + struct CmpCaptureTracker : public CaptureTracker { + AllocaInst *Alloca; + bool Captured = false; + /// The value of the map is a bit mask of which icmp operands the alloca is + /// used in. + SmallMapVector<ICmpInst *, unsigned, 4> ICmps; + + CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {} + + void tooManyUses() override { Captured = true; } + + bool captured(const Use *U) override { + auto *ICmp = dyn_cast<ICmpInst>(U->getUser()); + // We need to check that U is based *only* on the alloca, and doesn't + // have other contributions from a select/phi operand. + // TODO: We could check whether getUnderlyingObjects() reduces to one + // object, which would allow looking through phi nodes. + if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) { + // Collect equality icmps of the alloca, and don't treat them as + // captures. + auto Res = ICmps.insert({ICmp, 0}); + Res.first->second |= 1u << U->getOperandNo(); + return false; + } - SmallVector<const Use *, 32> Worklist; - for (const Use &U : Alloca->uses()) { - if (Worklist.size() >= MaxIter) - return nullptr; - Worklist.push_back(&U); - } + Captured = true; + return true; + } + }; - unsigned NumCmps = 0; - while (!Worklist.empty()) { - assert(Worklist.size() <= MaxIter); - const Use *U = Worklist.pop_back_val(); - const Value *V = U->getUser(); - --MaxIter; + CmpCaptureTracker Tracker(Alloca); + PointerMayBeCaptured(Alloca, &Tracker); + if (Tracker.Captured) + return false; - if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || - isa<SelectInst>(V)) { - // Track the uses. - } else if (isa<LoadInst>(V)) { - // Loading from the pointer doesn't escape it. - continue; - } else if (const auto *SI = dyn_cast<StoreInst>(V)) { - // Storing *to* the pointer is fine, but storing the pointer escapes it. - if (SI->getValueOperand() == U->get()) - return nullptr; - continue; - } else if (isa<ICmpInst>(V)) { - if (NumCmps++) - return nullptr; // Found more than one cmp. - continue; - } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { - switch (Intrin->getIntrinsicID()) { - // These intrinsics don't escape or compare the pointer. Memset is safe - // because we don't allow ptrtoint. Memcpy and memmove are safe because - // we don't allow stores, so src cannot point to V. - case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: - case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: - continue; - default: - return nullptr; - } - } else { - return nullptr; + bool Changed = false; + for (auto [ICmp, Operands] : Tracker.ICmps) { + switch (Operands) { + case 1: + case 2: { + // The alloca is only used in one icmp operand. Assume that the + // equality is false. 
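// E.g. for a non-escaping alloca compared against an unrelated pointer: // %a = alloca i32 // %cmp = icmp eq ptr %a, %p --> false (and icmp ne --> true) // No guess about the alloca's address can ever be proven right, so treating every such guess as wrong stays self-consistent.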
+ auto *Res = ConstantInt::get( + ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE); + replaceInstUsesWith(*ICmp, Res); + eraseInstFromFunction(*ICmp); + Changed = true; + break; } - for (const Use &U : V->uses()) { - if (Worklist.size() >= MaxIter) - return nullptr; - Worklist.push_back(&U); + case 3: + // Both icmp operands are based on the alloca, so this is comparing + // pointer offsets, without leaking any information about the address + // of the alloca. Ignore such comparisons. + break; + default: + llvm_unreachable("Cannot happen"); } } - auto *Res = ConstantInt::get(ICI.getType(), - !CmpInst::isTrueWhenEqual(ICI.getPredicate())); - return replaceInstUsesWith(ICI, Res); + return Changed; } /// Fold "icmp pred (X+C), X". @@ -1058,9 +1071,9 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, int Shift; if (IsAShr && AP1.isNegative()) - Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + Shift = AP1.countl_one() - AP2.countl_one(); else - Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + Shift = AP1.countl_zero() - AP2.countl_zero(); if (Shift > 0) { if (IsAShr && AP1 == AP2.ashr(Shift)) { @@ -1097,7 +1110,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, if (AP2.isZero()) return nullptr; - unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + unsigned AP2TrailingZeros = AP2.countr_zero(); if (!AP1 && AP2TrailingZeros != 0) return getICmp( @@ -1108,7 +1121,7 @@ return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); // Get the distance between the lowest bits that are set. - int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + int Shift = AP1.countr_zero() - AP2TrailingZeros; if (Shift > 0 && AP2.shl(Shift) == AP1) return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); @@ -1143,7 +1156,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. if (!CI2->getValue().isPowerOf2()) return nullptr; - unsigned NewWidth = CI2->getValue().countTrailingZeros(); + unsigned NewWidth = CI2->getValue().countr_zero(); if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; @@ -1295,6 +1308,48 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { return new ICmpInst(Pred, X, Cmp.getOperand(1)); } + // (icmp eq/ne (mul X, Y)) -> (icmp eq/ne X/Y) if we know X or Y is odd, or + // is known non-zero and the multiply cannot overflow. + if (match(Cmp.getOperand(0), m_Mul(m_Value(X), m_Value(Y))) && + ICmpInst::isEquality(Pred)) { + + KnownBits XKnown = computeKnownBits(X, 0, &Cmp); + // if X % 2 != 0 + // (icmp eq/ne Y) + if (XKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, Y, Cmp.getOperand(1)); + + KnownBits YKnown = computeKnownBits(Y, 0, &Cmp); + // if Y % 2 != 0 + // (icmp eq/ne X) + if (YKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + auto *BO0 = cast<OverflowingBinaryOperator>(Cmp.getOperand(0)); + if (BO0->hasNoUnsignedWrap() || BO0->hasNoSignedWrap()) { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + // `isKnownNonZero` does more analysis than just `!KnownBits.One.isZero()` + // but to avoid unnecessary work, first just check if this is an obvious case.
+ + // if X non-zero and NoOverflow(X * Y) + // (icmp eq/ne Y) + if (!XKnown.One.isZero() || isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(Pred, Y, Cmp.getOperand(1)); + + // if Y non-zero and NoOverflow(X * Y) + // (icmp eq/ne X) + if (!YKnown.One.isZero() || isKnownNonZero(Y, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + } + // Note, we are skipping cases: + // if Y % 2 != 0 AND X % 2 != 0 + // (false/true) + // if X non-zero and Y non-zero and NoOverflow(X * Y) + // (false/true) + // Those can be simplified later as we would have already replaced the (icmp + // eq/ne (mul X, Y)) with (icmp eq/ne X/Y) and if X/Y is known non-zero that + // will fold to a constant elsewhere. + } return nullptr; } @@ -1331,17 +1386,18 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) { if (auto *Phi = dyn_cast<PHINode>(Op0)) if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) { - Type *Ty = Cmp.getType(); - Builder.SetInsertPoint(Phi); - PHINode *NewPhi = - Builder.CreatePHI(Ty, Phi->getNumOperands()); - for (BasicBlock *Predecessor : predecessors(Phi->getParent())) { - auto *Input = - cast<Constant>(Phi->getIncomingValueForBlock(Predecessor)); - auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C); - NewPhi->addIncoming(BoolInput, Predecessor); + SmallVector<Constant *> Ops; + for (Value *V : Phi->incoming_values()) { + Constant *Res = + ConstantFoldCompareInstOperands(Pred, cast<Constant>(V), C, DL); + if (!Res) + return nullptr; + Ops.push_back(Res); } - NewPhi->takeName(&Cmp); + Builder.SetInsertPoint(Phi); + PHINode *NewPhi = Builder.CreatePHI(Cmp.getType(), Phi->getNumOperands()); + for (auto [V, Pred] : zip(Ops, Phi->blocks())) + NewPhi->addIncoming(V, Pred); return replaceInstUsesWith(Cmp, NewPhi); } @@ -1369,11 +1425,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { if (TrueBB == FalseBB) return nullptr; - // Try to simplify this compare to T/F based on the dominating condition. - std::optional<bool> Imp = - isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); - if (Imp) - return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp)); + // We already checked simple implication in InstSimplify, only handle complex + // cases here. CmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1); @@ -1475,7 +1528,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, KnownBits Known = computeKnownBits(X, 0, &Cmp); // If all the high bits are known, we can do this xform. - if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { + if ((Known.Zero | Known.One).countl_one() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); @@ -1781,17 +1834,12 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, ++UsesRemoved; // Compute A & ((1 << B) | 1) - Value *NewOr = nullptr; - if (auto *C = dyn_cast<Constant>(B)) { - if (UsesRemoved >= 1) - NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { + unsigned RequireUsesRemoved = match(B, m_ImmConstant()) ? 
1 : 3; + if (UsesRemoved >= RequireUsesRemoved) { + Value *NewOr = + Builder.CreateOr(Builder.CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); return replaceOperand(Cmp, 0, NewAnd); } @@ -1819,6 +1867,15 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); } + // (X & -X) < 0 --> X == MinSignedC + // (X & -X) > -1 --> X != MinSignedC + if (match(And, m_c_And(m_Neg(m_Value(X)), m_Deferred(X)))) { + Constant *MinSignedC = ConstantInt::get( + X->getType(), + APInt::getSignedMinValue(X->getType()->getScalarSizeInBits())); + auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; + return new ICmpInst(NewPred, X, MinSignedC); + } } // TODO: These all require that Y is constant too, so refactor with the above. @@ -1846,6 +1903,30 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); } + // If we are testing the intersection of 2 select-of-nonzero-constants with no + // common bits set, it's the same as checking if exactly one select condition + // is set: + // ((A ? TC : FC) & (B ? TC : FC)) == 0 --> xor A, B + // ((A ? TC : FC) & (B ? TC : FC)) != 0 --> not(xor A, B) + // TODO: Generalize for non-constant values. + // TODO: Handle signed/unsigned predicates. + // TODO: Handle other bitwise logic connectors. + // TODO: Extend to handle a non-zero compare constant. + if (C.isZero() && (Pred == CmpInst::ICMP_EQ || And->hasOneUse())) { + assert(Cmp.isEquality() && "Not expecting non-equality predicates"); + Value *A, *B; + const APInt *TC, *FC; + if (match(X, m_Select(m_Value(A), m_APInt(TC), m_APInt(FC))) && + match(Y, + m_Select(m_Value(B), m_SpecificInt(*TC), m_SpecificInt(*FC))) && + !TC->isZero() && !FC->isZero() && !TC->intersects(*FC)) { + Value *R = Builder.CreateXor(A, B); + if (Pred == CmpInst::ICMP_NE) + R = Builder.CreateNot(R); + return replaceInstUsesWith(Cmp, R); + } + } + // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) @@ -1863,6 +1944,59 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return nullptr; } +/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0. +static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, + InstCombiner::BuilderTy &Builder) { + // Are we using xors to bitwise check for a pair or pairs of (in)equalities? + // Convert to a shorter form that has more potential to be folded even + // further.
+ // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) + // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) + // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 --> + // (X1 == X2) && (X3 == X4) && (X5 == X6) + // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 --> + // (X1 != X2) || (X3 != X4) || (X5 != X6) + // TODO: Implement for sub + SmallVector<std::pair<Value *, Value *>, 2> CmpValues; + SmallVector<Value *, 16> WorkList(1, Or); + + while (!WorkList.empty()) { + auto MatchOrOperatorArgument = [&](Value *OrOperatorArgument) { + Value *Lhs, *Rhs; + + if (match(OrOperatorArgument, + m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) { + CmpValues.emplace_back(Lhs, Rhs); + } else { + WorkList.push_back(OrOperatorArgument); + } + }; + + Value *CurrentValue = WorkList.pop_back_val(); + Value *OrOperatorLhs, *OrOperatorRhs; + + if (!match(CurrentValue, + m_Or(m_Value(OrOperatorLhs), m_Value(OrOperatorRhs)))) { + return nullptr; + } + + MatchOrOperatorArgument(OrOperatorRhs); + MatchOrOperatorArgument(OrOperatorLhs); + } + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + auto BOpc = Pred == CmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; + Value *LhsCmp = Builder.CreateICmp(Pred, CmpValues.rbegin()->first, + CmpValues.rbegin()->second); + + for (auto It = CmpValues.rbegin() + 1; It != CmpValues.rend(); ++It) { + Value *RhsCmp = Builder.CreateICmp(Pred, It->first, It->second); + LhsCmp = Builder.CreateBinOp(BOpc, LhsCmp, RhsCmp); + } + + return LhsCmp; +} + /// Fold icmp (or X, Y), C. Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, @@ -1909,6 +2043,30 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, NewC); } + const APInt *OrC; + // icmp(X | OrC, C) --> icmp(X, 0) + if (C.isNonNegative() && match(Or, m_Or(m_Value(X), m_APInt(OrC)))) { + switch (Pred) { + // X | OrC s< C --> X s< 0 iff OrC s>= C s>= 0 + case ICmpInst::ICMP_SLT: + // X | OrC s>= C --> X s>= 0 iff OrC s>= C s>= 0 + case ICmpInst::ICMP_SGE: + if (OrC->sge(C)) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + break; + // X | OrC s<= C --> X s< 0 iff OrC s> C s>= 0 + case ICmpInst::ICMP_SLE: + // X | OrC s> C --> X s>= 0 iff OrC s> C s>= 0 + case ICmpInst::ICMP_SGT: + if (OrC->sgt(C)) + return new ICmpInst(ICmpInst::getFlippedStrictnessPredicate(Pred), X, + ConstantInt::getNullValue(X->getType())); + break; + default: + break; + } + } + if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse()) return nullptr; @@ -1924,18 +2082,8 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return BinaryOperator::Create(BOpc, CmpP, CmpQ); } - // Are we using xors to bitwise check for a pair of (in)equalities? Convert to - // a shorter form that has more potential to be folded even further. - Value *X1, *X2, *X3, *X4; - if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) && - match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) { - // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) - // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) - Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2); - Value *Cmp34 = Builder.CreateICmp(Pred, X3, X4); - auto BOpc = Pred == CmpInst::ICMP_EQ ? 
Instruction::And : Instruction::Or; - return BinaryOperator::Create(BOpc, Cmp12, Cmp34); - } + if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder)) + return replaceInstUsesWith(Cmp, V); return nullptr; } @@ -1969,21 +2117,29 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) + if (MulC->isZero()) return nullptr; - // If the multiply does not wrap, try to divide the compare constant by the - // multiplication factor. + // If the multiply does not wrap or the constant is odd, try to divide the + // compare constant by the multiplication factor. if (Cmp.isEquality()) { - // (mul nsw X, MulC) == C --> X == C /s MulC + // (mul nsw X, MulC) eq/ne C --> X eq/ne C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); return new ICmpInst(Pred, X, NewC); } - // (mul nuw X, MulC) == C --> X == C /u MulC - if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); - return new ICmpInst(Pred, X, NewC); + + // C % MulC == 0 is weaker than we could use if MulC is odd, because it is + // correct to transform whenever MulC * N == C, including with overflow. I.e. + // with i8, (icmp eq (mul X, 5), 101) -> (icmp eq X, 225), but since 101 % 5 != 0, we + // miss that case. + if (C.urem(*MulC).isZero()) { + // (mul nuw X, MulC) eq/ne C --> X eq/ne C /u MulC + // (mul X, OddC) eq/ne N * C --> X eq/ne N + if ((*MulC & 1).isOne() || Mul->hasNoUnsignedWrap()) { + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); + } } } @@ -1992,27 +2148,32 @@ // (X * MulC) > C --> X > (C / MulC) // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE? Constant *NewC = nullptr; - if (Mul->hasNoSignedWrap()) { + if (Mul->hasNoSignedWrap() && ICmpInst::isSigned(Pred)) { // MININT / -1 --> overflow. if (C.isMinSignedValue() && MulC->isAllOnes()) return nullptr; if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { NewC = ConstantInt::get( MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); - if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) + } else { + assert((Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) && + "Unexpected predicate"); NewC = ConstantInt::get( MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); - } else { - assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw"); - if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) + } + } else if (Mul->hasNoUnsignedWrap() && ICmpInst::isUnsigned(Pred)) { + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) { NewC = ConstantInt::get( MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); - if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) + } else { + assert((Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) && + "Unexpected predicate"); NewC = ConstantInt::get( MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); + } } return NewC ? 
new ICmpInst(Pred, X, NewC) : nullptr; @@ -2070,6 +2231,32 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp, if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + // (icmp pred (shl nuw&nsw X, Y), Csle0) + // -> (icmp pred X, Csle0) + // + // The idea is that nuw/nsw essentially freeze the sign bit for the shift op, + // so the result's sign bit must be X's, and the compare can use X directly. + if (C.sle(0) && Shl->hasNoUnsignedWrap() && Shl->hasNoSignedWrap()) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + + // (icmp eq/ne (shl nuw|nsw X, Y), 0) + // -> (icmp eq/ne X, 0) + if (ICmpInst::isEquality(Pred) && C.isZero() && + (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap())) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + + // (icmp slt (shl nsw X, Y), 0/1) + // -> (icmp slt X, 0/1) + // (icmp sgt (shl nsw X, Y), 0/-1) + // -> (icmp sgt X, 0/-1) + // + // NB: sge/sle with a constant will canonicalize to sgt/slt. + if (Shl->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) + if (C.isZero() || (Pred == ICmpInst::ICMP_SGT ? C.isAllOnes() : C.isOne())) + return new ICmpInst(Pred, Shl->getOperand(0), Cmp.getOperand(1)); + const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) return foldICmpShlOne(Cmp, Shl, C); @@ -2080,7 +2267,6 @@ if (ShiftAmt->uge(TypeBits)) return nullptr; - ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Shl->getOperand(0); Type *ShType = Shl->getType(); @@ -2107,11 +2293,6 @@ APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead; isSignTest may change 'Pred', so only - // do that if we're sure to not continue on in this function. - if (isSignTest(Pred, C)) - return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); } // NUW guarantees that we are only shifting out zero bits from the high bits, @@ -2189,7 +2370,7 @@ // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && + if (Shl->hasOneUse() && Amt != 0 && C.countr_zero() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (auto *ShVTy = dyn_cast<VectorType>(ShType)) @@ -2237,9 +2418,8 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, assert(ShiftValC->uge(C) && "Expected simplify of compare"); assert((IsUGT || !C.isZero()) && "Expected X u< 0 to simplify"); - unsigned CmpLZ = - IsUGT ? C.countLeadingZeros() : (C - 1).countLeadingZeros(); - unsigned ShiftLZ = ShiftValC->countLeadingZeros(); + unsigned CmpLZ = IsUGT ? C.countl_zero() : (C - 1).countl_zero(); + unsigned ShiftLZ = ShiftValC->countl_zero(); Constant *NewC = ConstantInt::get(Shr->getType(), CmpLZ - ShiftLZ); auto NewPred = IsUGT ? 
CmpInst::ICMP_ULT : CmpInst::ICMP_UGE; return new ICmpInst(NewPred, Shr->getOperand(1), NewC); @@ -3184,18 +3364,30 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( } break; } - case Instruction::And: { - const APInt *BOC; - if (match(BOp1, m_APInt(BOC))) { - // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (C == *BOC && C.isPowerOf2()) - return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - BO, Constant::getNullValue(RHS->getType())); - } - break; - } case Instruction::UDiv: - if (C.isZero()) { + case Instruction::SDiv: + if (BO->isExact()) { + // div exact X, Y eq/ne 0 -> X eq/ne 0 + // div exact X, Y eq/ne 1 -> X eq/ne Y + // div exact X, Y eq/ne C -> + // if Y * C never overflows && OneUse: + // -> Y * C eq/ne X + if (C.isZero()) + return new ICmpInst(Pred, BOp0, Constant::getNullValue(BO->getType())); + else if (C.isOne()) + return new ICmpInst(Pred, BOp0, BOp1); + else if (BO->hasOneUse()) { + OverflowResult OR = computeOverflow( + Instruction::Mul, BO->getOpcode() == Instruction::SDiv, BOp1, + Cmp.getOperand(1), BO); + if (OR == OverflowResult::NeverOverflows) { + Value *YC = + Builder.CreateMul(BOp1, ConstantInt::get(BO->getType(), C)); + return new ICmpInst(Pred, YC, BOp0); + } + } + } + if (BO->getOpcode() == Instruction::UDiv && C.isZero()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -3207,6 +3399,44 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( return nullptr; } +static Instruction *foldCtpopPow2Test(ICmpInst &I, IntrinsicInst *CtpopLhs, + const APInt &CRhs, + InstCombiner::BuilderTy &Builder, + const SimplifyQuery &Q) { + assert(CtpopLhs->getIntrinsicID() == Intrinsic::ctpop && + "Non-ctpop intrin in ctpop fold"); + if (!CtpopLhs->hasOneUse()) + return nullptr; + + // Power of 2 test: + // isPow2OrZero : ctpop(X) u< 2 + // isPow2 : ctpop(X) == 1 + // NotPow2OrZero: ctpop(X) u> 1 + // NotPow2 : ctpop(X) != 1 + // If we know some bit of X is set, this can be folded to: + // IsPow2 : X & (~Bit) == 0 + // NotPow2 : X & (~Bit) != 0 + const ICmpInst::Predicate Pred = I.getPredicate(); + if (((I.isEquality() || Pred == ICmpInst::ICMP_UGT) && CRhs == 1) || + (Pred == ICmpInst::ICMP_ULT && CRhs == 2)) { + Value *Op = CtpopLhs->getArgOperand(0); + KnownBits OpKnown = computeKnownBits(Op, Q.DL, + /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT); + // No need to check for count > 1; that should already be constant folded. + if (OpKnown.countMinPopulation() == 1) { + Value *And = Builder.CreateAnd( + Op, Constant::getIntegerValue(Op->getType(), ~(OpKnown.One))); + return new ICmpInst( + (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_ULT) + ? ICmpInst::ICMP_EQ + : ICmpInst::ICMP_NE, + And, Constant::getNullValue(Op->getType())); + } + } + + return nullptr; +} + /// Fold an equality icmp with LLVM intrinsic and constant operand.
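// [Editorial aside, not part of the patch.] A minimal standalone sketch that
// exhaustively verifies the (icmp eq (udiv A, B), 0) -> (icmp ugt B, A) fold
// shown above over i8; the harness below is ours, for illustration only.
#include <cassert>
int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 1; B < 256; ++B) // udiv by 0 is UB, so skip B == 0.
      assert(((A / B) == 0) == (B > A));
}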
Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { @@ -3227,6 +3457,11 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C.byteSwap())); + case Intrinsic::bitreverse: + // bitreverse(A) == C -> A == bitreverse(C) + return new ICmpInst(Pred, II->getArgOperand(0), + ConstantInt::get(Ty, C.reverseBits())); + case Intrinsic::ctlz: case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != @@ -3277,15 +3512,22 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( } break; + case Intrinsic::umax: case Intrinsic::uadd_sat: { // uadd.sat(a, b) == 0 -> (a | b) == 0 - if (C.isZero()) { + // umax(a, b) == 0 -> (a | b) == 0 + if (C.isZero() && II->hasOneUse()) { Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); return new ICmpInst(Pred, Or, Constant::getNullValue(Ty)); } break; } + case Intrinsic::ssub_sat: + // ssub.sat(a, b) == 0 -> a == b + if (C.isZero()) + return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1)); + break; case Intrinsic::usub_sat: { // usub.sat(a, b) == 0 -> a <= b if (C.isZero()) { @@ -3303,7 +3545,9 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( } /// Fold an icmp with LLVM intrinsics -static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { +static Instruction * +foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { assert(Cmp.isEquality()); ICmpInst::Predicate Pred = Cmp.getPredicate(); @@ -3321,16 +3565,32 @@ static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { // original values. return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); case Intrinsic::fshl: - case Intrinsic::fshr: + case Intrinsic::fshr: { // If both operands are rotated by same amount, just compare the // original values. if (IIOp0->getOperand(0) != IIOp0->getOperand(1)) break; if (IIOp1->getOperand(0) != IIOp1->getOperand(1)) break; - if (IIOp0->getOperand(2) != IIOp1->getOperand(2)) - break; - return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + if (IIOp0->getOperand(2) == IIOp1->getOperand(2)) + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + + // rotate(X, AmtX) == rotate(Y, AmtY) + // -> rotate(X, AmtX - AmtY) == Y + // Do this if either both rotates have one use or if only one has one use + // and AmtX/AmtY are constants. + unsigned OneUses = IIOp0->hasOneUse() + IIOp1->hasOneUse(); + if (OneUses == 2 || + (OneUses == 1 && match(IIOp0->getOperand(2), m_ImmConstant()) && + match(IIOp1->getOperand(2), m_ImmConstant()))) { + Value *SubAmt = + Builder.CreateSub(IIOp0->getOperand(2), IIOp1->getOperand(2)); + Value *CombinedRotate = Builder.CreateIntrinsic( + Op0->getType(), IIOp0->getIntrinsicID(), + {IIOp0->getOperand(0), IIOp0->getOperand(0), SubAmt}); + return new ICmpInst(Pred, IIOp1->getOperand(0), CombinedRotate); + } + } break; default: break; } @@ -3421,16 +3681,119 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, return foldICmpBinOpEqualityWithConstant(Cmp, BO, C); } +static Instruction * +foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred, + SaturatingInst *II, const APInt &C, + InstCombiner::BuilderTy &Builder) { + // This transform may end up producing more than one instruction for the + // intrinsic, so limit it to one user of the intrinsic. 
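// [Editorial aside, not part of the patch.] A standalone C++20 sketch of the
// rotate fold above: rot(X, AmtX) == rot(Y, AmtY) <=> rot(X, AmtX - AmtY) == Y.
// std::rotl stands in for llvm.fshl with both data operands equal.
#include <bit>
#include <cassert>
#include <cstdint>
int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y)
      for (int AmtX = 0; AmtX < 8; ++AmtX)
        for (int AmtY = 0; AmtY < 8; ++AmtY)
          assert((std::rotl((uint8_t)X, AmtX) == std::rotl((uint8_t)Y, AmtY)) ==
                 (std::rotl((uint8_t)X, AmtX - AmtY) == (uint8_t)Y));
}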
+ if (!II->hasOneUse()) + return nullptr; + + // Let Y = [add/sub]_sat(X, C) pred C2 + // SatVal = The saturating value for the operation + // WillWrap = Whether or not the operation will underflow / overflow + // => Y = (WillWrap ? SatVal : (X binop C)) pred C2 + // => Y = WillWrap ? (SatVal pred C2) : ((X binop C) pred C2) + // + // When (SatVal pred C2) is true, then + // Y = WillWrap ? true : ((X binop C) pred C2) + // => Y = WillWrap || ((X binop C) pred C2) + // else + // Y = WillWrap ? false : ((X binop C) pred C2) + // => Y = !WillWrap ? ((X binop C) pred C2) : false + // => Y = !WillWrap && ((X binop C) pred C2) + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + + const APInt *COp1; + // This transform only works when the intrinsic has an integral constant or + // splat vector as the second operand. + if (!match(Op1, m_APInt(COp1))) + return nullptr; + + APInt SatVal; + switch (II->getIntrinsicID()) { + default: + llvm_unreachable( + "This function only works with usub_sat and uadd_sat for now!"); + case Intrinsic::uadd_sat: + SatVal = APInt::getAllOnes(C.getBitWidth()); + break; + case Intrinsic::usub_sat: + SatVal = APInt::getZero(C.getBitWidth()); + break; + } + + // Check (SatVal pred C2) + bool SatValCheck = ICmpInst::compare(SatVal, C, Pred); + + // !WillWrap. + ConstantRange C1 = ConstantRange::makeExactNoWrapRegion( + II->getBinaryOp(), *COp1, II->getNoWrapKind()); + + // WillWrap. + if (SatValCheck) + C1 = C1.inverse(); + + ConstantRange C2 = ConstantRange::makeExactICmpRegion(Pred, C); + if (II->getBinaryOp() == Instruction::Add) + C2 = C2.sub(*COp1); + else + C2 = C2.add(*COp1); + + Instruction::BinaryOps CombiningOp = + SatValCheck ? Instruction::BinaryOps::Or : Instruction::BinaryOps::And; + + std::optional<ConstantRange> Combination; + if (CombiningOp == Instruction::BinaryOps::Or) + Combination = C1.exactUnionWith(C2); + else /* CombiningOp == Instruction::BinaryOps::And */ + Combination = C1.exactIntersectWith(C2); + + if (!Combination) + return nullptr; + + CmpInst::Predicate EquivPred; + APInt EquivInt; + APInt EquivOffset; + + Combination->getEquivalentICmp(EquivPred, EquivInt, EquivOffset); + + return new ICmpInst( + EquivPred, + Builder.CreateAdd(Op0, ConstantInt::get(Op1->getType(), EquivOffset)), + ConstantInt::get(Op1->getType(), EquivInt)); +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + + // Handle folds that apply for any kind of icmp. 
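// [Editorial aside, not part of the patch.] One concrete instance of the range
// machinery above, checked exhaustively on i8 with an assumed constant pair
// (C = 10, C2 = 100, Pred = ult): uadd.sat(X, 10) u< 100 <=> X u< 90.
#include <cassert>
#include <cstdint>
static uint8_t uadd_sat_u8(uint8_t A, uint8_t B) {
  unsigned Sum = unsigned(A) + unsigned(B);
  return Sum > 255 ? 255 : uint8_t(Sum); // SatVal for uadd.sat is all-ones.
}
int main() {
  for (unsigned X = 0; X < 256; ++X)
    assert((uadd_sat_u8(uint8_t(X), 10) < 100) == (X < 90));
}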
+ switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::uadd_sat: + case Intrinsic::usub_sat: + if (auto *Folded = foldICmpUSubSatOrUAddSatWithConstant( + Pred, cast<SaturatingInst>(II), C, Builder)) + return Folded; + break; + case Intrinsic::ctpop: { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + if (Instruction *R = foldCtpopPow2Test(Cmp, II, C, Builder, Q)) + return R; + } break; + } + if (Cmp.isEquality()) return foldICmpEqIntrinsicWithConstant(Cmp, II, C); Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); - ICmpInst::Predicate Pred = Cmp.getPredicate(); switch (II->getIntrinsicID()) { case Intrinsic::ctpop: { // (ctpop X > BitWidth - 1) --> X == -1 @@ -3484,6 +3847,21 @@ Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, } break; } + case Intrinsic::ssub_sat: + // ssub.sat(a, b) spred 0 -> a spred b + if (ICmpInst::isSigned(Pred)) { + if (C.isZero()) + return new ICmpInst(Pred, II->getArgOperand(0), II->getArgOperand(1)); + // X s<= 0 is canonicalized to X s< 1 + if (Pred == ICmpInst::ICMP_SLT && C.isOne()) + return new ICmpInst(ICmpInst::ICMP_SLE, II->getArgOperand(0), + II->getArgOperand(1)); + // X s>= 0 is canonicalized to X s> -1 + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) + return new ICmpInst(ICmpInst::ICMP_SGE, II->getArgOperand(0), + II->getArgOperand(1)); + } + break; default: break; } @@ -4014,20 +4392,60 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) { return Res; } -static Instruction *foldICmpXNegX(ICmpInst &I) { +static Instruction *foldICmpXNegX(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { CmpInst::Predicate Pred; Value *X; - if (!match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) - return nullptr; + if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) { + + if (ICmpInst::isSigned(Pred)) + Pred = ICmpInst::getSwappedPredicate(Pred); + else if (ICmpInst::isUnsigned(Pred)) + Pred = ICmpInst::getSignedPredicate(Pred); + // else for equality-comparisons just keep the predicate. + + return ICmpInst::Create(Instruction::ICmp, Pred, X, + Constant::getNullValue(X->getType()), I.getName()); + } + + // A value is not equal to its negation unless that value is 0 or + // MinSignedValue, i.e.: a != -a --> (a & MaxSignedVal) != 0 + if (match(&I, m_c_ICmp(Pred, m_OneUse(m_Neg(m_Value(X))), m_Deferred(X))) && + ICmpInst::isEquality(Pred)) { + Type *Ty = X->getType(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + Constant *MaxSignedVal = + ConstantInt::get(Ty, APInt::getSignedMaxValue(BitWidth)); + Value *And = Builder.CreateAnd(X, MaxSignedVal); + Constant *Zero = Constant::getNullValue(Ty); + return CmpInst::Create(Instruction::ICmp, Pred, And, Zero); + } + + return nullptr; +} - if (ICmpInst::isSigned(Pred)) +static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + // Normalize xor operand as operand 0. + CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_Xor(m_Specific(Op0), m_Value()))) { + std::swap(Op0, Op1); Pred = ICmpInst::getSwappedPredicate(Pred); - else if (ICmpInst::isUnsigned(Pred)) - Pred = ICmpInst::getSignedPredicate(Pred); - // else for equality-comparisons just keep the predicate.
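// [Editorial aside, not part of the patch.] An exhaustive i8 check of the
// negation fold above: a value equals its (wrapping) negation only for 0 and
// the minimum signed value, so X == -X <=> (X & MaxSignedVal) == 0.
#include <cassert>
#include <cstdint>
int main() {
  for (unsigned I = 0; I < 256; ++I) {
    uint8_t X = uint8_t(I), NegX = uint8_t(0u - I); // wrapping negation
    assert((X == NegX) == ((X & 0x7f) == 0));       // 0x7f = i8 signed max
  }
}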
+ } + if (!match(Op0, m_c_Xor(m_Specific(Op1), m_Value(A)))) + return nullptr; - return ICmpInst::Create(Instruction::ICmp, Pred, X, - Constant::getNullValue(X->getType()), I.getName()); + // icmp (X ^ Y_NonZero) u>= X --> icmp (X ^ Y_NonZero) u> X + // icmp (X ^ Y_NonZero) u<= X --> icmp (X ^ Y_NonZero) u< X + // icmp (X ^ Y_NonZero) s>= X --> icmp (X ^ Y_NonZero) s> X + // icmp (X ^ Y_NonZero) s<= X --> icmp (X ^ Y_NonZero) s< X + CmpInst::Predicate PredOut = CmpInst::getStrictPredicate(Pred); + if (PredOut != Pred && + isKnownNonZero(A, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(PredOut, Op0, Op1); + + return nullptr; } /// Try to fold icmp (binop), X or icmp X, (binop). @@ -4045,7 +4463,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (!BO0 && !BO1) return nullptr; - if (Instruction *NewICmp = foldICmpXNegX(I)) + if (Instruction *NewICmp = foldICmpXNegX(I, Builder)) return NewICmp; const CmpInst::Predicate Pred = I.getPredicate(); @@ -4326,17 +4744,41 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, ConstantExpr::getNeg(RHSC)); } + if (Instruction * R = foldICmpXorXX(I, Q, *this)) + return R; + { - // Try to remove shared constant multiplier from equality comparison: - // X * C == Y * C (with no overflowing/aliasing) --> X == Y - Value *X, *Y; - const APInt *C; - if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && - match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) - if (!C->countTrailingZeros() || - (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) - return new ICmpInst(Pred, X, Y); + // Try to remove shared multiplier from comparison: + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z + Value *X, *Y, *Z; + if (Pred == ICmpInst::getUnsignedPredicate(Pred) && + ((match(Op0, m_Mul(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))) || + (match(Op0, m_Mul(m_Value(Z), m_Value(X))) && + match(Op1, m_c_Mul(m_Specific(Z), m_Value(Y)))))) { + bool NonZero; + if (ICmpInst::isEquality(Pred)) { + KnownBits ZKnown = computeKnownBits(Z, 0, &I); + // if Z % 2 != 0 + // X * Z eq/ne Y * Z -> X eq/ne Y + if (ZKnown.countMaxTrailingZeros() == 0) + return new ICmpInst(Pred, X, Y); + NonZero = !ZKnown.One.isZero() || + isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + // if Z != 0 and nsw(X * Z) and nsw(Y * Z) + // X * Z eq/ne Y * Z -> X eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoSignedWrap() && + BO1->hasNoSignedWrap()) + return new ICmpInst(Pred, X, Y); + } else + NonZero = isKnownNonZero(Z, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + + // If Z != 0 and nuw(X * Z) and nuw(Y * Z) + // X * Z u{lt/le/gt/ge}/eq/ne Y * Z -> X u{lt/le/gt/ge}/eq/ne Y + if (NonZero && BO0 && BO1 && BO0->hasNoUnsignedWrap() && + BO1->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, Y); + } } BinaryOperator *SRem = nullptr; @@ -4405,7 +4847,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, !C->isOne()) { // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) // Mask = -1 >> count-trailing-zeros(C). 
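// [Editorial aside, not part of the patch.] An exhaustive i8 check that an odd
// shared multiplier cancels from an equality even with wrapping, because odd
// values are invertible mod 2^n; this is the "Z % 2 != 0" case above.
#include <cassert>
#include <cstdint>
int main() {
  for (unsigned Z = 1; Z < 256; Z += 2) // odd multipliers only
    for (unsigned X = 0; X < 256; ++X)
      for (unsigned Y = 0; Y < 256; ++Y)
        assert((uint8_t(X * Z) == uint8_t(Y * Z)) == (X == Y));
}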
- if (unsigned TZs = C->countTrailingZeros()) { + if (unsigned TZs = C->countr_zero()) { Constant *Mask = ConstantInt::get( BO0->getType(), APInt::getLowBitsSet(C->getBitWidth(), C->getBitWidth() - TZs)); @@ -4569,6 +5011,59 @@ static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { return nullptr; } +// Canonicalize checking for a power-of-2-or-zero value: +static Instruction *foldICmpPow2Test(ICmpInst &I, + InstCombiner::BuilderTy &Builder) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const CmpInst::Predicate Pred = I.getPredicate(); + Value *A = nullptr; + bool CheckIs; + if (I.isEquality()) { + // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) + // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) + if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), + m_Deferred(A)))) || + !match(Op1, m_ZeroInt())) + A = nullptr; + + // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) + // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) + if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) + A = Op1; + else if (match(Op1, + m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) + A = Op0; + + CheckIs = Pred == ICmpInst::ICMP_EQ; + } else if (ICmpInst::isUnsigned(Pred)) { + // (A ^ (A-1)) u>= A --> ctpop(A) < 2 (two commuted variants) + // ((A-1) ^ A) u< A --> ctpop(A) > 1 (two commuted variants) + + if ((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) && + match(Op0, m_OneUse(m_c_Xor(m_Add(m_Specific(Op1), m_AllOnes()), + m_Specific(Op1))))) { + A = Op1; + CheckIs = Pred == ICmpInst::ICMP_UGE; + } else if ((Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE) && + match(Op1, m_OneUse(m_c_Xor(m_Add(m_Specific(Op0), m_AllOnes()), + m_Specific(Op0))))) { + A = Op0; + CheckIs = Pred == ICmpInst::ICMP_ULE; + } + } + + if (A) { + Type *Ty = A->getType(); + CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); + return CheckIs ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, + ConstantInt::get(Ty, 2)) + : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, + ConstantInt::get(Ty, 1)); + } + + return nullptr; +} + Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { if (!I.isEquality()) return nullptr; @@ -4604,6 +5099,21 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } + // canonicalize: + // (icmp eq/ne (and X, C), X) + // -> (icmp eq/ne (and X, ~C), 0) + { + Constant *CMask; + A = nullptr; + if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_ImmConstant(CMask))))) + A = Op1; + else if (match(Op1, m_OneUse(m_And(m_Specific(Op0), m_ImmConstant(CMask))))) + A = Op0; + if (A) + return new ICmpInst(Pred, Builder.CreateAnd(A, Builder.CreateNot(CMask)), + Constant::getNullValue(A->getType())); + } + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { // A == (A^B) -> B == 0 Value *OtherVal = A == Op0 ?
B : A; @@ -4659,22 +5169,36 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { // (B & (Pow2C-1)) != zext A --> A != trunc B const APInt *MaskC; if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) && - MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits()) + MaskC->countr_one() == A->getType()->getScalarSizeInBits()) return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); + } - // Test if 2 values have different or same signbits: - // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 - // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + // Test if 2 values have different or same signbits: + // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 + // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0 + // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1 + Instruction *ExtI; + if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Instruction *ShiftI; Value *X, *Y; ICmpInst::Predicate Pred2; - if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) && + if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), + m_Shr(m_Value(X), + m_SpecificIntAllowUndef(OpWidth - 1)))) && match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { - Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); - Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) : - Builder.CreateIsNotNeg(Xor); - return replaceInstUsesWith(I, R); + unsigned ExtOpc = ExtI->getOpcode(); + unsigned ShiftOpc = ShiftI->getOpcode(); + if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || + (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { + Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); + Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) + : Builder.CreateIsNotNeg(Xor); + return replaceInstUsesWith(I, R); + } } } @@ -4737,33 +5261,9 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I)) + if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I, Builder)) return ICmp; - // Canonicalize checking for a power-of-2-or-zero value: - // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants) - // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants) - if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()), - m_Deferred(A)))) || - !match(Op1, m_ZeroInt())) - A = nullptr; - - // (A & -A) == A --> ctpop(A) < 2 (four commuted variants) - // (-A & A) != A --> ctpop(A) > 1 (four commuted variants) - if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1))))) - A = Op1; - else if (match(Op1, - m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0))))) - A = Op0; - - if (A) { - Type *Ty = A->getType(); - CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A); - return Pred == ICmpInst::ICMP_EQ - ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2)) - : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1)); - } - // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the // top BW/2 + 1 bits are all the same. 
Create "A >=s INT_MIN && A <=s INT_MAX", // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps @@ -4794,11 +5294,23 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1, ConstantInt::getNullValue(Op1->getType())); + // Canonicalize: + // icmp eq/ne X, OneUse(rotate-right(X)) + // -> icmp eq/ne X, rotate-left(X) + // We generally try to convert rotate-right -> rotate-left, this just + // canonicalizes another case. + CmpInst::Predicate PredUnused = Pred; + if (match(&I, m_c_ICmp(PredUnused, m_Value(A), + m_OneUse(m_Intrinsic<Intrinsic::fshr>( + m_Deferred(A), m_Deferred(A), m_Value(B)))))) + return new ICmpInst( + Pred, A, + Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B})); + return nullptr; } -static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { ICmpInst::Predicate Pred = ICmp.getPredicate(); Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1); @@ -4836,6 +5348,25 @@ static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); } + if (auto *II = dyn_cast<IntrinsicInst>(X)) { + if (II->getIntrinsicID() == Intrinsic::cttz || + II->getIntrinsicID() == Intrinsic::ctlz) { + unsigned MaxRet = SrcBits; + // If the "is_zero_poison" argument is set, then we know at least + // one bit is set in the input, so the result is always at least one + // less than the full bitwidth of that input. + if (match(II->getArgOperand(1), m_One())) + MaxRet--; + + // Make sure the destination is wide enough to hold the largest output of + // the intrinsic. + if (llvm::Log2_32(MaxRet) + 1 <= Op0->getType()->getScalarSizeInBits()) + if (Instruction *I = + foldICmpIntrinsicWithConstant(ICmp, II, C->zext(SrcBits))) + return I; + } + } + return nullptr; } @@ -4855,10 +5386,19 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0)); bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1)); - // If we have mismatched casts, treat the zext of a non-negative source as - // a sext to simulate matching casts. Otherwise, we are done. - // TODO: Can we handle some predicates (equality) without non-negative? if (IsZext0 != IsZext1) { + // If X and Y and both i1 + // (icmp eq/ne (zext X) (sext Y)) + // eq -> (icmp eq (or X, Y), 0) + // ne -> (icmp ne (or X, Y), 0) + if (ICmp.isEquality() && X->getType()->isIntOrIntVectorTy(1) && + Y->getType()->isIntOrIntVectorTy(1)) + return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y), + Constant::getNullValue(X->getType())); + + // If we have mismatched casts, treat the zext of a non-negative source as + // a sext to simulate matching casts. Otherwise, we are done. + // TODO: Can we handle some predicates (equality) without non-negative? 
if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) IsSignedExt = true; @@ -4993,7 +5533,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); } - if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) + if (Instruction *R = foldICmpWithTrunc(ICmp)) return R; return foldICmpWithZextOrSext(ICmp); @@ -5153,7 +5693,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, return nullptr; if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { const APInt &CVal = CI->getValue(); - if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth) + if (CVal.getBitWidth() - CVal.countl_zero() > MulWidth) return nullptr; } else { // In this case we could have the operand of the binary operation @@ -5334,44 +5874,18 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { // bits doesn't impact the outcome of the comparison, because any value // greater than the RHS must differ in a bit higher than these due to carry. case ICmpInst::ICMP_UGT: - return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes()); + return APInt::getBitsSetFrom(BitWidth, RHS->countr_one()); // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. case ICmpInst::ICMP_ULT: - return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); + return APInt::getBitsSetFrom(BitWidth, RHS->countr_zero()); default: return APInt::getAllOnes(BitWidth); } } -/// Check if the order of \p Op0 and \p Op1 as operands in an ICmpInst -/// should be swapped. -/// The decision is based on how many times these two operands are reused -/// as subtract operands and their positions in those instructions. -/// The rationale is that several architectures use the same instruction for -/// both subtract and cmp. Thus, it is better if the order of those operands -/// match. -/// \return true if Op0 and Op1 should be swapped. -static bool swapMayExposeCSEOpportunities(const Value *Op0, const Value *Op1) { - // Filter out pointer values as those cannot appear directly in subtract. - // FIXME: we may want to go through inttoptrs or bitcasts. - if (Op0->getType()->isPointerTy()) - return false; - // If a subtract already has the same operands as a compare, swapping would be - // bad. If a subtract has the same operands as a compare but in reverse order, - // then swapping is good. - int GoodToSwap = 0; - for (const User *U : Op0->users()) { - if (match(U, m_Sub(m_Specific(Op1), m_Specific(Op0)))) - GoodToSwap++; - else if (match(U, m_Sub(m_Specific(Op0), m_Specific(Op1)))) - GoodToSwap--; - } - return GoodToSwap > 0; -} - /// Check that one use is in the same block as the definition and all /// other uses are in blocks dominated by a given block. 
/// @@ -5638,14 +6152,14 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { const APInt *C1; if (match(LHS, m_Shl(m_Power2(C1), m_Value(X)))) { Type *XTy = X->getType(); - unsigned Log2C1 = C1->countTrailingZeros(); + unsigned Log2C1 = C1->countr_zero(); APInt C2 = Op0KnownZeroInverted; APInt C2Pow2 = (C2 & ~(*C1 - 1)) + *C1; if (C2Pow2.isPowerOf2()) { // iff (C1 is pow2) & (((C2 & ~(C1-1)) + C1) is pow2): // ((C1 << X) & C2) == 0 -> X >= (Log2(C2+C1) - Log2(C1)) // ((C1 << X) & C2) != 0 -> X < (Log2(C2+C1) - Log2(C1)) - unsigned Log2C2 = C2Pow2.countTrailingZeros(); + unsigned Log2C2 = C2Pow2.countr_zero(); auto *CmpC = ConstantInt::get(XTy, Log2C2 - Log2C1); auto NewPred = Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; @@ -5653,6 +6167,12 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { } } } + + // Op0 eq C_Pow2 -> Op0 ne 0 if Op0 is known to be C_Pow2 or zero. + if (Op1Known.isConstant() && Op1Known.getConstant().isPowerOf2() && + (Op0Known & Op1Known) == Op0Known) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0, + ConstantInt::getNullValue(Op1->getType())); break; } case ICmpInst::ICMP_ULT: { @@ -5733,8 +6253,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { /// If one operand of an icmp is effectively a bool (value range of {0,1}), /// then try to reduce patterns based on that limit. -static Instruction *foldICmpUsingBoolRange(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { Value *X, *Y; ICmpInst::Predicate Pred; @@ -5750,6 +6269,60 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I, Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + const APInt *C; + if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) && + match(I.getOperand(1), m_APInt(C)) && + X->getType()->isIntOrIntVectorTy(1) && + Y->getType()->isIntOrIntVectorTy(1)) { + unsigned BitWidth = C->getBitWidth(); + Pred = I.getPredicate(); + APInt Zero = APInt::getZero(BitWidth); + APInt MinusOne = APInt::getAllOnes(BitWidth); + APInt One(BitWidth, 1); + if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) || + (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT)) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) || + (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT)) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (I.getOperand(0)->hasOneUse()) { + APInt NewC = *C; + // canonicalize predicate to eq/ne + if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) || + (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) { + // x s< 0 in [-1, 1] --> x == -1 + // x u> 1 (or any const != 0, != -1) in [-1, 1] --> x == -1 + NewC = MinusOne; + Pred = ICmpInst::ICMP_EQ; + } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) || + (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) { + // x s> -1 in [-1, 1] --> x != -1 + // x u< -1 (or any const != 0, != 1) in [-1, 1] --> x != -1 + Pred = ICmpInst::ICMP_NE; + } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) { + // x s> 0 in [-1, 1] --> x == 1 + NewC = One; + Pred = ICmpInst::ICMP_EQ; + } else if (*C == One && Pred == ICmpInst::ICMP_SLT) { + // x s< 1 in [-1, 1] --> x != 1 + Pred = ICmpInst::ICMP_NE; + } + + if (NewC == MinusOne) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y); + if (Pred ==
ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(X, Builder.CreateNot(Y)); + } else if (NewC == One) { + if (Pred == ICmpInst::ICMP_EQ) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y)); + if (Pred == ICmpInst::ICMP_NE) + return BinaryOperator::CreateOr(Builder.CreateNot(X), Y); + } + } + } + return nullptr; } @@ -6162,8 +6735,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { /// Orders the operands of the compare so that they are listed from most /// complex to least complex. This puts constants before unary operators, /// before binary operators. - if (Op0Cplxity < Op1Cplxity || - (Op0Cplxity == Op1Cplxity && swapMayExposeCSEOpportunities(Op0, Op1))) { + if (Op0Cplxity < Op1Cplxity) { I.swapOperands(); std::swap(Op0, Op1); Changed = true; @@ -6205,7 +6777,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; - if (Instruction *Res = foldICmpUsingBoolRange(I, Builder)) + if (Instruction *Res = foldICmpUsingBoolRange(I)) return Res; if (Instruction *Res = foldICmpUsingKnownBits(I)) @@ -6288,15 +6860,46 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I)) return NI; + // In case of a comparison with two select instructions having the same + // condition, check whether one of the resulting branches can be simplified. + // If so, just compare the other branch and select the appropriate result. + // For example: + // %tmp1 = select i1 %cmp, i32 %y, i32 %x + // %tmp2 = select i1 %cmp, i32 %z, i32 %x + // %cmp2 = icmp slt i32 %tmp2, %tmp1 + // Here the icmp is false for the false arms of the selects (%x vs %x), and + // its result depends on comparing the true arms when %cmp is true. Thus, + // transform this into: + // %cmp3 = icmp slt i32 %z, %y + // %sel = select i1 %cmp, i1 %cmp3, i1 false + // The symmetric case, where the comparison of the false arms simplifies, + // is handled as well. + { + Value *Cond, *A, *B, *C, *D; + if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) && + match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // Check whether comparison of TrueValues can be simplified + if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) { + Value *NewICMP = Builder.CreateICmp(Pred, B, D); + return SelectInst::Create(Cond, Res, NewICMP); + } + // Check whether comparison of FalseValues can be simplified + if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) { + Value *NewICMP = Builder.CreateICmp(Pred, A, C); + return SelectInst::Create(Cond, NewICMP, Res); + } + } + } + + // Try to optimize equality comparisons against alloca-based pointers.
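// [Editorial aside, not part of the patch.] The bool-range add fold above on
// i8: V = zext(X) + sext(Y) only takes values in {-1, 0, 1}, and
// V == -1 <=> ~X & Y, while V == 1 <=> X & ~Y.
#include <cassert>
#include <cstdint>
int main() {
  for (int X = 0; X <= 1; ++X)
    for (int Y = 0; Y <= 1; ++Y) {
      uint8_t V = uint8_t(X) + (Y ? 0xff : 0x00); // zext(X) + sext(Y), mod 256
      assert((V == 0xff) == (!X && Y)); // compare against -1
      assert((V == 0x01) == (X && !Y)); // compare against 1
    }
}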
if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0))) - if (Instruction *New = foldAllocaCmp(I, Alloca)) - return New; + if (foldAllocaCmp(Alloca)) + return nullptr; if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1))) - if (Instruction *New = foldAllocaCmp(I, Alloca)) - return New; + if (foldAllocaCmp(Alloca)) + return nullptr; } if (Instruction *Res = foldICmpBitCast(I)) @@ -6363,6 +6966,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpEquality(I)) return Res; + if (Instruction *Res = foldICmpPow2Test(I, Builder)) + return Res; + if (Instruction *Res = foldICmpOfUAddOv(I)) return Res; @@ -6717,7 +7323,7 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { Mode.Input == DenormalMode::PositiveZero) { auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { - Constant *Zero = ConstantFP::getNullValue(X->getType()); + Constant *Zero = ConstantFP::getZero(X->getType()); return new FCmpInst(P, X, Zero, "", I); }; @@ -6813,7 +7419,7 @@ static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) { // Replace the negated operand with 0.0: // fcmp Pred Op0, -Op0 --> fcmp Pred Op0, 0.0 - Constant *Zero = ConstantFP::getNullValue(Op0->getType()); + Constant *Zero = ConstantFP::getZero(Op0->getType()); return new FCmpInst(Pred, Op0, Zero, "", &I); } @@ -6863,11 +7469,13 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) - return replaceOperand(I, 0, ConstantFP::getNullValue(OpType)); + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, DL, &TLI, 0, + &AC, &I, &DT)) + return replaceOperand(I, 0, ConstantFP::getZero(OpType)); - if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) - return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + if (!match(Op1, m_PosZeroFP()) && + isKnownNeverNaN(Op1, DL, &TLI, 0, &AC, &I, &DT)) + return replaceOperand(I, 1, ConstantFP::getZero(OpType)); } // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y @@ -6896,7 +7504,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0: // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0 if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) - return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + return replaceOperand(I, 1, ConstantFP::getZero(OpType)); // Ignore signbit of bitcasted int when comparing equality to FP 0.0: // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 @@ -6985,11 +7593,11 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_ONE: // X is ordered and not equal to an impossible constant --> ordered return new FCmpInst(FCmpInst::FCMP_ORD, X, - ConstantFP::getNullValue(X->getType())); + ConstantFP::getZero(X->getType())); case FCmpInst::FCMP_UEQ: // X is unordered or equal to an impossible constant --> unordered return new FCmpInst(FCmpInst::FCMP_UNO, X, - ConstantFP::getNullValue(X->getType())); + ConstantFP::getZero(X->getType())); case FCmpInst::FCMP_UNE: // X is unordered or not equal to an impossible constant --> true return replaceInstUsesWith(I, 
ConstantInt::getTrue(I.getType())); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index f4e88b122383..701579e1de48 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -150,7 +150,6 @@ public: Instruction *visitPHINode(PHINode &PN); Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP); Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src); - Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP); Instruction *visitAllocaInst(AllocaInst &AI); Instruction *visitAllocSite(Instruction &FI); Instruction *visitFree(CallInst &FI, Value *FreedOp); @@ -330,8 +329,7 @@ private: Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); Instruction *matchSAddSubSat(IntrinsicInst &MinMax1); Instruction *foldNot(BinaryOperator &I); - - void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); + Instruction *foldBinOpOfDisplacedShifts(BinaryOperator &I); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -378,6 +376,7 @@ private: Instruction *foldLShrOverflowBit(BinaryOperator &I); Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); + Instruction *foldIntrinsicIsFPClass(IntrinsicInst &II); Instruction *foldFPSignBitOps(BinaryOperator &I); Instruction *foldFDivConstantDivisor(BinaryOperator &I); @@ -393,12 +392,12 @@ public: /// without having to rewrite the CFG from within InstCombine. void CreateNonTerminatorUnreachable(Instruction *InsertAt) { auto &Ctx = InsertAt->getContext(); - new StoreInst(ConstantInt::getTrue(Ctx), - PoisonValue::get(Type::getInt1PtrTy(Ctx)), - InsertAt); + auto *SI = new StoreInst(ConstantInt::getTrue(Ctx), + PoisonValue::get(Type::getInt1PtrTy(Ctx)), + /*isVolatile*/ false, Align(1)); + InsertNewInstBefore(SI, *InsertAt); } - /// Combiner aware instruction erasure. /// /// When dealing with an instruction that has side effects or produces a void @@ -411,12 +410,11 @@ public: // Make sure that we reprocess all operands now that we reduced their // use counts. - for (Use &Operand : I.operands()) - if (auto *Inst = dyn_cast<Instruction>(Operand)) - Worklist.add(Inst); - + SmallVector<Value *> Ops(I.operands()); Worklist.remove(&I); I.eraseFromParent(); + for (Value *Op : Ops) + Worklist.handleUseCountDecrement(Op); MadeIRChange = true; return nullptr; // Don't do anything with FI } @@ -450,6 +448,18 @@ public: Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, Value *RHS); + // (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C)) + // -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C) + // (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt)) + // -> (BinOp (logic_shift (BinOp X, Y)), Mask) + Instruction *foldBinOpShiftWithShift(BinaryOperator &I); + + /// Tries to simplify binops of select and cast of the select condition. + /// + /// (Binop (cast C), (select C, T, F)) + /// -> (select C, C0, C1) + Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I); + /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). 
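// [Editorial aside, not part of the patch.] The factorization named in the
// comment above is unconditionally valid in wrapping arithmetic (the subtle
// part the helper must handle is preserving nuw/nsw flags); exhaustive i8
// check of (A*B)+(A*C) == A*(B+C):
#include <cassert>
#include <cstdint>
int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B)
      for (unsigned C = 0; C < 256; ++C)
        assert(uint8_t(A * B + A * C) == uint8_t(A * uint8_t(B + C)));
}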
Value *tryFactorizationFolds(BinaryOperator &I); @@ -549,7 +559,7 @@ public: ICmpInst::Predicate Cond, Instruction &I); Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI, Value *RHS, const ICmpInst &I); - Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca); + bool foldAllocaCmp(AllocaInst *Alloca); Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI, GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, @@ -564,6 +574,7 @@ public: Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpUsingBoolRange(ICmpInst &I); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp, @@ -623,6 +634,7 @@ public: Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, const APInt &C); Instruction *foldICmpBitCast(ICmpInst &Cmp); + Instruction *foldICmpWithTrunc(ICmpInst &Cmp); // Helpers of visitSelectInst(). Instruction *foldSelectOfBools(SelectInst &SI); @@ -634,10 +646,11 @@ public: SelectPatternFlavor SPF2, Value *C); Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI); + bool replaceInInstruction(Value *V, Value *Old, Value *New, + unsigned Depth = 0); Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside); - Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); bool mergeStoreIntoSuccessor(StoreInst &SI); /// Given an initial instruction, check to see if it is the root of a @@ -651,10 +664,12 @@ public: Value *EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned); - /// Returns a value X such that Val = X * Scale, or null if none. - /// - /// If the multiplication is known not to overflow then NoSignedWrap is set. - Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); + bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock); + + bool removeInstructionsBeforeUnreachable(Instruction &I); + bool handleUnreachableFrom(Instruction *I); + bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc); + void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); }; class Negator final { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 41bc65620ff6..6aa20ee26b9a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -32,7 +32,7 @@ STATISTIC(NumDeadStore, "Number of dead stores eliminated"); STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); static cl::opt<unsigned> MaxCopiedFromConstantUsers( - "instcombine-max-copied-from-constant-users", cl::init(128), + "instcombine-max-copied-from-constant-users", cl::init(300), cl::desc("Maximum users to visit in copy from constant transform"), cl::Hidden); @@ -219,7 +219,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, // Now that I is pointing to the first non-allocation-inst in the block, // insert our getelementptr instruction... 
// - Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); + Type *IdxTy = IC.getDataLayout().getIndexType(AI.getType()); Value *NullIdx = Constant::getNullValue(IdxTy); Value *Idx[2] = {NullIdx, NullIdx}; Instruction *GEP = GetElementPtrInst::CreateInBounds( @@ -235,11 +235,12 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, if (isa<UndefValue>(AI.getArraySize())) return IC.replaceInstUsesWith(AI, Constant::getNullValue(AI.getType())); - // Ensure that the alloca array size argument has type intptr_t, so that - // any casting is exposed early. - Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType()); - if (AI.getArraySize()->getType() != IntPtrTy) { - Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false); + // Ensure that the alloca array size argument has type equal to the offset + // size of the alloca() pointer, which, in the typical case, is intptr_t, + // so that any casting is exposed early. + Type *PtrIdxTy = IC.getDataLayout().getIndexType(AI.getType()); + if (AI.getArraySize()->getType() != PtrIdxTy) { + Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), PtrIdxTy, false); return IC.replaceOperand(AI, 0, V); } @@ -259,8 +260,8 @@ namespace { // instruction. class PointerReplacer { public: - PointerReplacer(InstCombinerImpl &IC, Instruction &Root) - : IC(IC), Root(Root) {} + PointerReplacer(InstCombinerImpl &IC, Instruction &Root, unsigned SrcAS) + : IC(IC), Root(Root), FromAS(SrcAS) {} bool collectUsers(); void replacePointer(Value *V); @@ -273,11 +274,21 @@ private: return I == &Root || Worklist.contains(I); } + bool isEqualOrValidAddrSpaceCast(const Instruction *I, + unsigned FromAS) const { + const auto *ASC = dyn_cast<AddrSpaceCastInst>(I); + if (!ASC) + return false; + unsigned ToAS = ASC->getDestAddressSpace(); + return (FromAS == ToAS) || IC.isValidAddrSpaceCast(FromAS, ToAS); + } + SmallPtrSet<Instruction *, 32> ValuesToRevisit; SmallSetVector<Instruction *, 4> Worklist; MapVector<Value *, Value *> WorkMap; InstCombinerImpl &IC; Instruction &Root; + unsigned FromAS; }; } // end anonymous namespace @@ -341,6 +352,8 @@ bool PointerReplacer::collectUsersRecursive(Instruction &I) { if (MI->isVolatile()) return false; Worklist.insert(Inst); + } else if (isEqualOrValidAddrSpaceCast(Inst, FromAS)) { + Worklist.insert(Inst); } else if (Inst->isLifetimeStartOrEnd()) { continue; } else { @@ -391,9 +404,8 @@ void PointerReplacer::replace(Instruction *I) { } else if (auto *BC = dyn_cast<BitCastInst>(I)) { auto *V = getReplacement(BC->getOperand(0)); assert(V && "Operand not replaced"); - auto *NewT = PointerType::getWithSamePointeeType( - cast<PointerType>(BC->getType()), - V->getType()->getPointerAddressSpace()); + auto *NewT = PointerType::get(BC->getType()->getContext(), + V->getType()->getPointerAddressSpace()); auto *NewI = new BitCastInst(V, NewT); IC.InsertNewInstWith(NewI, *BC); NewI->takeName(BC); @@ -426,6 +438,22 @@ void PointerReplacer::replace(Instruction *I) { IC.eraseInstFromFunction(*MemCpy); WorkMap[MemCpy] = NewI; + } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I)) { + auto *V = getReplacement(ASC->getPointerOperand()); + assert(V && "Operand not replaced"); + assert(isEqualOrValidAddrSpaceCast( + ASC, V->getType()->getPointerAddressSpace()) && + "Invalid address space cast!"); + auto *NewV = V; + if (V->getType()->getPointerAddressSpace() != + ASC->getType()->getPointerAddressSpace()) { + auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); + NewI->takeName(ASC); + IC.InsertNewInstWith(NewI, *ASC);
+ NewV = NewI; + } + IC.replaceInstUsesWith(*ASC, NewV); + IC.eraseInstFromFunction(*ASC); } else { llvm_unreachable("should never reach here"); } @@ -435,7 +463,7 @@ void PointerReplacer::replacePointer(Value *V) { #ifndef NDEBUG auto *PT = cast<PointerType>(Root.getType()); auto *NT = cast<PointerType>(V->getType()); - assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage"); + assert(PT != NT && "Invalid usage"); #endif WorkMap[&Root] = V; @@ -518,7 +546,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { return NewI; } - PointerReplacer PtrReplacer(*this, AI); + PointerReplacer PtrReplacer(*this, AI, SrcAddrSpace); if (PtrReplacer.collectUsers()) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); @@ -739,6 +767,11 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { // the knowledge that padding exists for the rest of the pipeline. const DataLayout &DL = IC.getDataLayout(); auto *SL = DL.getStructLayout(ST); + + // Don't unpack a structure that contains a scalable vector. + if (SL->getSizeInBits().isScalable()) + return nullptr; + if (SL->hasPadding()) return nullptr; @@ -979,17 +1012,15 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, // If we're indexing into an object with a variable index for the memory // access, but the object has only one element, we can assume that the index // will always be zero. If we replace the GEP, return it. -template <typename T> static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, - T &MemI) { + Instruction &MemI) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) { unsigned Idx; if (canReplaceGEPIdxWithZero(IC, GEPI, &MemI, Idx)) { Instruction *NewGEPI = GEPI->clone(); NewGEPI->setOperand(Idx, ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0)); - NewGEPI->insertBefore(GEPI); - MemI.setOperand(MemI.getPointerOperandIndex(), NewGEPI); + IC.InsertNewInstBefore(NewGEPI, *GEPI); return NewGEPI; } } @@ -1024,6 +1055,8 @@ static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { Value *Op = LI.getOperand(0); + if (Value *Res = simplifyLoadInst(&LI, Op, SQ.getWithInstruction(&LI))) + return replaceInstUsesWith(LI, Res); // Try to canonicalize the loaded type. if (Instruction *Res = combineLoadToOperationType(*this, LI)) @@ -1036,10 +1069,8 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { LI.setAlignment(KnownAlign); // Replace GEP indices if possible. - if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { - Worklist.push(NewGEPI); - return &LI; - } + if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) + return replaceOperand(LI, 0, NewGEPI); if (Instruction *Res = unpackLoadToAggregate(*this, LI)) return Res; @@ -1065,13 +1096,7 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { // load null/undef -> unreachable // TODO: Consider a target hook for valid address spaces for this xforms. if (canSimplifyNullLoadOrGEP(LI, Op)) { - // Insert a new store to null instruction before the load to indicate - // that this code is not reachable. We do this instead of inserting - // an unreachable instruction directly because we cannot modify the - // CFG.
- StoreInst *SI = new StoreInst(PoisonValue::get(LI.getType()), - Constant::getNullValue(Op->getType()), &LI); - SI->setDebugLoc(LI.getDebugLoc()); + CreateNonTerminatorUnreachable(&LI); return replaceInstUsesWith(LI, PoisonValue::get(LI.getType())); } @@ -1261,6 +1286,11 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { // the knowledge that padding exists for the rest of the pipeline. const DataLayout &DL = IC.getDataLayout(); auto *SL = DL.getStructLayout(ST); + + // Don't unpack a structure that contains a scalable vector. + if (SL->getSizeInBits().isScalable()) + return false; + if (SL->hasPadding()) return false; @@ -1443,10 +1473,8 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { return eraseInstFromFunction(SI); // Replace GEP indices if possible. - if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { - Worklist.push(NewGEPI); - return &SI; - } + if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) + return replaceOperand(SI, 1, NewGEPI); // Don't hack volatile/ordered stores. // FIXME: Some bits are legal for ordered atomic stores; needs refactoring. @@ -1530,6 +1558,16 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { return nullptr; // Do not modify these! } + // This is a non-terminator unreachable marker. Don't remove it. + if (isa<UndefValue>(Ptr)) { + // Remove all instructions after the marker and guaranteed-to-transfer + // instructions before the marker. + if (handleUnreachableFrom(SI.getNextNode()) || + removeInstructionsBeforeUnreachable(SI)) + return &SI; + return nullptr; + } + // store undef, Ptr -> noop // FIXME: This is technically incorrect because it might overwrite a poison // value. Change to PoisonValue once #52930 is resolved. @@ -1571,6 +1609,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { if (!OtherBr || BBI == OtherBB->begin()) return false; + auto OtherStoreIsMergeable = [&](StoreInst *OtherStore) -> bool { + if (!OtherStore || + OtherStore->getPointerOperand() != SI.getPointerOperand()) + return false; + + auto *SIVTy = SI.getValueOperand()->getType(); + auto *OSVTy = OtherStore->getValueOperand()->getType(); + return CastInst::isBitOrNoopPointerCastable(OSVTy, SIVTy, DL) && + SI.hasSameSpecialState(OtherStore); + }; + // If the other block ends in an unconditional branch, check for the 'if then // else' case. There is an instruction before the branch. StoreInst *OtherStore = nullptr; @@ -1586,8 +1635,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { // If this isn't a store, isn't a store to the same location, or is not the // right kind of store, bail out. OtherStore = dyn_cast<StoreInst>(BBI); - if (!OtherStore || OtherStore->getOperand(1) != SI.getOperand(1) || - !SI.isSameOperationAs(OtherStore)) + if (!OtherStoreIsMergeable(OtherStore)) return false; } else { // Otherwise, the other block ended with a conditional branch. If one of the @@ -1601,12 +1649,10 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { // lives in OtherBB. for (;; --BBI) { // Check to see if we find the matching store. - if ((OtherStore = dyn_cast<StoreInst>(BBI))) { - if (OtherStore->getOperand(1) != SI.getOperand(1) || - !SI.isSameOperationAs(OtherStore)) - return false; + OtherStore = dyn_cast<StoreInst>(BBI); + if (OtherStoreIsMergeable(OtherStore)) break; - } + // If we find something that may be using or overwriting the stored // value, or if we run out of instructions, we can't do the transform.
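// [Editorial aside, not part of the patch.] A source-level analogue of what
// mergeStoreIntoSuccessor achieves for the if/then/else pattern; the function
// names are ours, for illustration only.
void beforeMerge(bool Cond, int A, int B, int *P) {
  if (Cond)
    *P = A; // store in the "then" block
  else
    *P = B; // matching store in the "else" block
}
void afterMerge(bool Cond, int A, int B, int *P) {
  *P = Cond ? A : B; // a single store of the phi-merged value
}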
if (BBI->mayReadFromMemory() || BBI->mayThrow() || @@ -1624,14 +1670,17 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { } // Insert a PHI node now if we need it. - Value *MergedVal = OtherStore->getOperand(0); + Value *MergedVal = OtherStore->getValueOperand(); // The debug locations of the original instructions might differ. Merge them. DebugLoc MergedLoc = DILocation::getMergedLocation(SI.getDebugLoc(), OtherStore->getDebugLoc()); - if (MergedVal != SI.getOperand(0)) { - PHINode *PN = PHINode::Create(MergedVal->getType(), 2, "storemerge"); - PN->addIncoming(SI.getOperand(0), SI.getParent()); - PN->addIncoming(OtherStore->getOperand(0), OtherBB); + if (MergedVal != SI.getValueOperand()) { + PHINode *PN = + PHINode::Create(SI.getValueOperand()->getType(), 2, "storemerge"); + PN->addIncoming(SI.getValueOperand(), SI.getParent()); + Builder.SetInsertPoint(OtherStore); + PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()), + OtherBB); MergedVal = InsertNewInstBefore(PN, DestBB->front()); PN->setDebugLoc(MergedLoc); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 97f129e200de..50458e2773e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, return nullptr; } +static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, + bool AssumeNonZero, bool DoFold); + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = @@ -270,7 +273,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (match(Op0, m_ZExtOrSExt(m_Value(X))) && match(Op1, m_APIntAllowUndef(NegPow2C))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - unsigned ShiftAmt = NegPow2C->countTrailingZeros(); + unsigned ShiftAmt = NegPow2C->countr_zero(); if (ShiftAmt >= BitWidth - SrcWidth) { Value *N = Builder.CreateNeg(X, X->getName() + ".neg"); Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z"); @@ -471,6 +474,40 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + + // min(X, Y) * max(X, Y) => X * Y. + if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)), + m_c_SMin(m_Deferred(X), m_Deferred(Y))), + m_c_Mul(m_UMax(m_Value(X), m_Value(Y)), + m_c_UMin(m_Deferred(X), m_Deferred(Y)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I); + + // (mul Op0 Op1): + // if Log2(Op0) folds away -> + // (shl Op1, Log2(Op0)) + // if Log2(Op1) folds away -> + // (shl Op0, Log2(Op1)) + if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); + // We can only propagate the nuw flag. + Shl->setHasNoUnsignedWrap(HasNUW); + return Shl; + } + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); + // We can only propagate the nuw flag.
+    Shl->setHasNoUnsignedWrap(HasNUW);
+    return Shl;
+  }
+
   bool Changed = false;
   if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
     Changed = true;
@@ -765,6 +802,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
       I.hasNoSignedZeros() && match(Start, m_Zero()))
     return replaceInstUsesWith(I, Start);
 
+  // minimum(X, Y) * maximum(X, Y) => X * Y.
+  if (match(&I,
+            m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+                     m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+                                                       m_Deferred(Y))))) {
+    BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I);
+    // We cannot preserve ninf if the nnan flag is not set.
+    // If X is NaN and Y is Inf then the original program computed NaN * NaN,
+    // while the optimized version computes NaN * Inf, which is poison when
+    // the ninf flag is set.
+    if (!Result->hasNoNaNs())
+      Result->setHasNoInfs(false);
+    return Result;
+  }
+
   return nullptr;
 }
 
@@ -976,9 +1027,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
                                             ConstantInt::get(Ty, Product));
     }
 
+    APInt Quotient(C2->getBitWidth(), /*val=*/0ULL, IsSigned);
     if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
         (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) {
-      APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
 
       // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
       if (isMultiple(*C2, *C1, Quotient, IsSigned)) {
@@ -1003,7 +1054,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
          C1->ult(C1->getBitWidth() - 1)) ||
         (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) &&
          C1->ult(C1->getBitWidth()))) {
-      APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
       APInt C1Shifted = APInt::getOneBitSet(
           C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue()));
 
@@ -1026,6 +1076,23 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
       }
     }
 
+    // Distribute div over add to eliminate a matching div/mul pair:
+    // ((X * C2) + C1) / C2 --> X + C1/C2
+    // We need a multiple of the divisor for a signed add constant, but
+    // unsigned is fine with any constant pair.
+    if (IsSigned &&
+        match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)),
+                            m_APInt(C1))) &&
+        isMultiple(*C1, *C2, Quotient, IsSigned)) {
+      return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient));
+    }
+    if (!IsSigned &&
+        match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)),
+                            m_APInt(C1)))) {
+      return BinaryOperator::CreateNUWAdd(X,
+                                          ConstantInt::get(Ty, C1->udiv(*C2)));
+    }
+
     if (!C2->isZero()) // avoid X udiv 0
       if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I))
         return FoldedDiv;
@@ -1121,7 +1188,7 @@ static const unsigned MaxDepth = 6;
 // actual instructions, otherwise return a non-null dummy value. Return nullptr
 // on failure.
 static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
-                       bool DoFold) {
+                       bool AssumeNonZero, bool DoFold) {
   auto IfFold = [DoFold](function_ref<Value *()> Fn) {
     if (!DoFold)
       return reinterpret_cast<Value *>(-1);
@@ -1147,14 +1214,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // FIXME: Require one use?
   Value *X, *Y;
   if (match(Op, m_ZExt(m_Value(X))))
-    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+    if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
       return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
 
   // log2(X << Y) -> log2(X) + Y
   // FIXME: Require one use unless X is 1?
-  if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
-    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
-      return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+  if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+    auto *BO = cast<OverflowingBinaryOperator>(Op);
+    // nuw will be set if the `shl` is trivially non-zero.
+    if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+      if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+        return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+  }
 
   // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
   // FIXME: missed optimization: if one of the hands of select is/contains
@@ -1162,8 +1233,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // FIXME: can both hands contain undef?
   // FIXME: Require one use?
   if (SelectInst *SI = dyn_cast<SelectInst>(Op))
-    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
-      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
+                               AssumeNonZero, DoFold))
+      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
+                                 AssumeNonZero, DoFold))
         return IfFold([&]() {
           return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
         });
@@ -1171,13 +1244,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
   // log2(umin(X, Y)) -> umin(log2(X), log2(Y))
   // log2(umax(X, Y)) -> umax(log2(X), log2(Y))
   auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
-  if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
-    if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
-      if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
+  if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+    // Use AssumeNonZero as false here. Otherwise we can hit a case where
+    // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because of overflow).
+    if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+                               /*AssumeNonZero*/ false, DoFold))
+      if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+                                 /*AssumeNonZero*/ false, DoFold))
         return IfFold([&]() {
-          return Builder.CreateBinaryIntrinsic(
-              MinMax->getIntrinsicID(), LogX, LogY);
+          return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+                                               LogY);
         });
+  }
 
   return nullptr;
 }
@@ -1297,8 +1375,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
   }
 
   // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
-  if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
-    Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+  if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+               /*DoFold*/ false)) {
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+                          /*AssumeNonZero*/ true, /*DoFold*/ true);
     return replaceInstUsesWith(
         I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
   }
@@ -1359,7 +1439,8 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
     // (sext X) sdiv C --> sext (X sdiv C)
     Value *Op0Src;
     if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
-        Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
+        Op0Src->getType()->getScalarSizeInBits() >=
+            Op1C->getSignificantBits()) {
 
       // In the general case, we need to make sure that the dividend is not the
      // minimum signed value because dividing that by -1 is UB. But here, we
@@ -1402,7 +1483,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
   KnownBits KnownDividend = computeKnownBits(Op0, 0, &I);
   if (!I.isExact() &&
       (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) &&
-      KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) {
+      KnownDividend.countMinTrailingZeros() >= Op1C->countr_zero()) {
     I.setIsExact();
     return &I;
   }
@@ -1681,6 +1762,111 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
   return nullptr;
 }
 
+// A variety of transforms for:
+//   (urem/srem (mul X, Y), (mul X, Z))
+//   (urem/srem (shl X, Y), (shl X, Z))
+//   (urem/srem (shl Y, X), (shl Z, X))
+// NB: The shift cases are really just extensions of the mul case. We treat
+// shift as Val * (1 << Amt).
+static Instruction *simplifyIRemMulShl(BinaryOperator &I,
+                                       InstCombinerImpl &IC) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
+  APInt Y, Z;
+  bool ShiftByX = false;
+
+  // If V is not nullptr, it will be matched using m_Specific.
+  auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool {
+    const APInt *Tmp = nullptr;
+    if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
+        (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
+      C = *Tmp;
+    else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
+             (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
+      C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
+    if (Tmp != nullptr)
+      return true;
+
+    // Reset `V` so we don't start with a specific value on the next match
+    // attempt.
+    V = nullptr;
+    return false;
+  };
+
+  auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool {
+    const APInt *Tmp = nullptr;
+    if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) ||
+        (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) {
+      C = *Tmp;
+      return true;
+    }
+
+    // Reset `V` so we don't start with a specific value on the next match
+    // attempt.
+    V = nullptr;
+    return false;
+  };
+
+  if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) {
+    // pass
+  } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) {
+    ShiftByX = true;
+  } else {
+    return nullptr;
+  }
+
+  bool IsSRem = I.getOpcode() == Instruction::SRem;
+
+  OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0);
+  // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >=
+  // Z or Z >= Y.
+  bool BO0HasNSW = BO0->hasNoSignedWrap();
+  bool BO0HasNUW = BO0->hasNoUnsignedWrap();
+  bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
+
+  APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z);
+  // (rem (mul nuw/nsw X, Y), (mul X, Z))
+  //  if (rem Y, Z) == 0
+  //      -> 0
+  if (RemYZ.isZero() && BO0NoWrap)
+    return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+
+  // Helper function to emit either (RemSimplificationC << X) or
+  // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as
+  // (shl V, X) or (mul V, X) respectively.
+  auto CreateMulOrShift =
+      [&](const APInt &RemSimplificationC) -> BinaryOperator * {
+    Value *RemSimplification =
+        ConstantInt::get(I.getType(), RemSimplificationC);
+    return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X)
+                    : BinaryOperator::CreateMul(X, RemSimplification);
+  };
+
+  OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
+  bool BO1HasNSW = BO1->hasNoSignedWrap();
+  bool BO1HasNUW = BO1->hasNoUnsignedWrap();
+  bool BO1NoWrap = IsSRem ?
BO1HasNSW : BO1HasNUW; + // (rem (mul X, Y), (mul nuw/nsw X, Z)) + // if (rem Y, Z) == Y + // -> (mul nuw/nsw X, Y) + if (RemYZ == Y && BO1NoWrap) { + BinaryOperator *BO = CreateMulOrShift(Y); + // Copy any overflow flags from Op0. + BO->setHasNoSignedWrap(IsSRem || BO0HasNSW); + BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW); + return BO; + } + + // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) + // if Y >= Z + // -> (mul {nuw} nsw X, (rem Y, Z)) + if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { + BinaryOperator *BO = CreateMulOrShift(RemYZ); + BO->setHasNoSignedWrap(); + BO->setHasNoUnsignedWrap(BO0HasNUW); + return BO; + } + + return nullptr; +} + /// This function implements the transforms common to both integer remainder /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. @@ -1733,6 +1919,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { } } + if (Instruction *R = simplifyIRemMulShl(I, *this)) + return R; + return nullptr; } @@ -1782,8 +1971,21 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { // urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0 Value *X; if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); - return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0); + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = + Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty)); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } + + // For "(X + 1) % Op1" and if (X u< Op1) => (X + 1) == Op1 ? 0 : X + 1 . + if (match(Op0, m_Add(m_Value(X), m_One()))) { + Value *Val = + simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I)); + if (Val && match(Val, m_One())) { + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } } return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 7f59729f0085..2f6aa85062a5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -316,7 +316,7 @@ Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) { if (auto *NewOp = simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) { - PN.setIncomingValue(OpNum, NewOp); + replaceOperand(PN, OpNum, NewOp); OperandWithRoundTripCast = true; } } @@ -745,6 +745,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_access_group, + LLVMContext::MD_noundef, }; for (unsigned ID : KnownIDs) @@ -1388,11 +1389,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // If all PHI operands are the same operation, pull them through the PHI, // reducing code size. 
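+  // For example (illustrative IR, not from this patch):
+  //   %p = phi i32 [ %a1, %bb1 ], [ %a2, %bb2 ]  ; %a1 = add i32 %x, 1
+  //                                              ; %a2 = add i32 %y, 1
+  // can become:
+  //   %m = phi i32 [ %x, %bb1 ], [ %y, %bb2 ]
+  //   %p = add i32 %m, 1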
- if (isa<Instruction>(PN.getIncomingValue(0)) && - isa<Instruction>(PN.getIncomingValue(1)) && - cast<Instruction>(PN.getIncomingValue(0))->getOpcode() == - cast<Instruction>(PN.getIncomingValue(1))->getOpcode() && - PN.getIncomingValue(0)->hasOneUser()) + auto *Inst0 = dyn_cast<Instruction>(PN.getIncomingValue(0)); + auto *Inst1 = dyn_cast<Instruction>(PN.getIncomingValue(1)); + if (Inst0 && Inst1 && Inst0->getOpcode() == Inst1->getOpcode() && + Inst0->hasOneUser()) if (Instruction *Result = foldPHIArgOpIntoPHI(PN)) return Result; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index e7d8208f94fd..661c50062223 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -98,7 +98,8 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, // +0.0 compares equal to -0.0, and so it does not behave as required for this // transform. Bail out if we can not exclude that possibility. if (isa<FPMathOperator>(BO)) - if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) + if (!BO->hasNoSignedZeros() && + !cannotBeNegativeZero(Y, IC.getDataLayout(), &TLI)) return nullptr; // BO = binop Y, X @@ -386,6 +387,32 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp}); } } + + // select c, (ldexp v, e0), (ldexp v, e1) -> ldexp v, (select c, e0, e1) + // select c, (ldexp v0, e), (ldexp v1, e) -> ldexp (select c, v0, v1), e + // + // select c, (ldexp v0, e0), (ldexp v1, e1) -> + // ldexp (select c, v0, v1), (select c, e0, e1) + if (TII->getIntrinsicID() == Intrinsic::ldexp) { + Value *LdexpVal0 = TII->getArgOperand(0); + Value *LdexpExp0 = TII->getArgOperand(1); + Value *LdexpVal1 = FII->getArgOperand(0); + Value *LdexpExp1 = FII->getArgOperand(1); + if (LdexpExp0->getType() == LdexpExp1->getType()) { + FPMathOperator *SelectFPOp = cast<FPMathOperator>(&SI); + FastMathFlags FMF = cast<FPMathOperator>(TII)->getFastMathFlags(); + FMF &= cast<FPMathOperator>(FII)->getFastMathFlags(); + FMF |= SelectFPOp->getFastMathFlags(); + + Value *SelectVal = Builder.CreateSelect(Cond, LdexpVal0, LdexpVal1); + Value *SelectExp = Builder.CreateSelect(Cond, LdexpExp0, LdexpExp1); + + CallInst *NewLdexp = Builder.CreateIntrinsic( + TII->getType(), Intrinsic::ldexp, {SelectVal, SelectExp}); + NewLdexp->setFastMathFlags(FMF); + return replaceInstUsesWith(SI, NewLdexp); + } + } } // icmp with a common operand also can have the common operand @@ -429,6 +456,21 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, !OtherOpF->getType()->isVectorTy())) return nullptr; + // If we are sinking div/rem after a select, we may need to freeze the + // condition because div/rem may induce immediate UB with a poison operand. + // For example, the following transform is not safe if Cond can ever be poison + // because we can replace poison with zero and then we have div-by-zero that + // didn't exist in the original code: + // Cond ? x/y : x/z --> x / (Cond ? y : z) + auto *BO = dyn_cast<BinaryOperator>(TI); + if (BO && BO->isIntDivRem() && !isGuaranteedNotToBePoison(Cond)) { + // A udiv/urem with a common divisor is safe because UB can only occur with + // div-by-zero, and that would be present in the original code. 
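+    // Illustrative example (hypothetical IR, not from this patch's tests):
+    //   %r = select i1 %c, i32 %d1, i32 %d2  ; %d1 = sdiv i32 %x, %y
+    //                                        ;  %d2 = sdiv i32 %x, %z
+    // sinks to:
+    //   %c.fr = freeze i1 %c
+    //   %den = select i1 %c.fr, i32 %y, i32 %z
+    //   %r = sdiv i32 %x, %den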
+ if (BO->getOpcode() == Instruction::SDiv || + BO->getOpcode() == Instruction::SRem || MatchIsOpZero) + Cond = Builder.CreateFreeze(Cond); + } + // If we reach here, they do have operations in common. Value *NewSI = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, SI.getName() + ".v", &SI); @@ -461,7 +503,7 @@ static bool isSelect01(const APInt &C1I, const APInt &C2I) { /// optimization. Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { - // See the comment above GetSelectFoldableOperands for a description of the + // See the comment above getSelectFoldableOperands for a description of the // transformation we are doing here. auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, Value *FalseVal, @@ -496,7 +538,7 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, - Swapped ? OOp : C); + Swapped ? OOp : C, "", &SI); if (isa<FPMathOperator>(&SI)) cast<Instruction>(NewSel)->setFastMathFlags(FMF); NewSel->takeName(TVI); @@ -569,6 +611,44 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp, } /// We want to turn: +/// (select (icmp eq (and X, C1), 0), 0, (shl [nsw/nuw] X, C2)); +/// iff C1 is a mask and the number of its leading zeros is equal to C2 +/// into: +/// shl X, C2 +static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal, + Value *FVal, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred; + Value *AndVal; + if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero()))) + return nullptr; + + if (Pred == ICmpInst::ICMP_NE) { + Pred = ICmpInst::ICMP_EQ; + std::swap(TVal, FVal); + } + + Value *X; + const APInt *C2, *C1; + if (Pred != ICmpInst::ICMP_EQ || + !match(AndVal, m_And(m_Value(X), m_APInt(C1))) || + !match(TVal, m_Zero()) || !match(FVal, m_Shl(m_Specific(X), m_APInt(C2)))) + return nullptr; + + if (!C1->isMask() || + C1->countLeadingZeros() != static_cast<unsigned>(C2->getZExtValue())) + return nullptr; + + auto *FI = dyn_cast<Instruction>(FVal); + if (!FI) + return nullptr; + + FI->setHasNoSignedWrap(false); + FI->setHasNoUnsignedWrap(false); + return FVal; +} + +/// We want to turn: /// (select (icmp sgt x, C), lshr (X, Y), ashr (X, Y)); iff C s>= -1 /// (select (icmp slt x, C), ashr (X, Y), lshr (X, Y)); iff C s>= 0 /// into: @@ -935,10 +1015,53 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +/// Try to match patterns with select and subtract as absolute difference. +static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + auto *TI = dyn_cast<Instruction>(TVal); + auto *FI = dyn_cast<Instruction>(FVal); + if (!TI || !FI) + return nullptr; + + // Normalize predicate to gt/lt rather than ge/le. + ICmpInst::Predicate Pred = Cmp->getStrictPredicate(); + Value *A = Cmp->getOperand(0); + Value *B = Cmp->getOperand(1); + + // Normalize "A - B" as the true value of the select. + if (match(FI, m_Sub(m_Specific(A), m_Specific(B)))) { + std::swap(FI, TI); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + // With any pair of no-wrap subtracts: + // (A > B) ? 
(A - B) : (B - A) --> abs(A - B) + if (Pred == CmpInst::ICMP_SGT && + match(TI, m_Sub(m_Specific(A), m_Specific(B))) && + match(FI, m_Sub(m_Specific(B), m_Specific(A))) && + (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap()) && + (FI->hasNoSignedWrap() || FI->hasNoUnsignedWrap())) { + // The remaining subtract is not "nuw" any more. + // If there's one use of the subtract (no other use than the use we are + // about to replace), then we know that the sub is "nsw" in this context + // even if it was only "nuw" before. If there's another use, then we can't + // add "nsw" to the existing instruction because it may not be safe in the + // other user's context. + TI->setHasNoUnsignedWrap(false); + if (!TI->hasNoSignedWrap()) + TI->setHasNoSignedWrap(TI->hasOneUse()); + return Builder.CreateBinaryIntrinsic(Intrinsic::abs, TI, Builder.getTrue()); + } + + return nullptr; +} + /// Fold the following code sequence: /// \code /// int a = ctlz(x & -x); // x ? 31 - a : a; +// // or +// x ? 31 - a : 32; /// \code /// /// into: @@ -953,15 +1076,19 @@ static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal, if (ICI->getPredicate() == ICmpInst::ICMP_NE) std::swap(TrueVal, FalseVal); + Value *Ctlz; if (!match(FalseVal, - m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1)))) + m_Xor(m_Value(Ctlz), m_SpecificInt(BitWidth - 1)))) return nullptr; - if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>())) + if (!match(Ctlz, m_Intrinsic<Intrinsic::ctlz>())) + return nullptr; + + if (TrueVal != Ctlz && !match(TrueVal, m_SpecificInt(BitWidth))) return nullptr; Value *X = ICI->getOperand(0); - auto *II = cast<IntrinsicInst>(TrueVal); + auto *II = cast<IntrinsicInst>(Ctlz); if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X))))) return nullptr; @@ -1038,99 +1165,6 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return nullptr; } -/// Return true if we find and adjust an icmp+select pattern where the compare -/// is with a constant that can be incremented or decremented to match the -/// minimum or maximum idiom. -static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { - ICmpInst::Predicate Pred = Cmp.getPredicate(); - Value *CmpLHS = Cmp.getOperand(0); - Value *CmpRHS = Cmp.getOperand(1); - Value *TrueVal = Sel.getTrueValue(); - Value *FalseVal = Sel.getFalseValue(); - - // We may move or edit the compare, so make sure the select is the only user. - const APInt *CmpC; - if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) - return false; - - // These transforms only work for selects of integers or vector selects of - // integer vectors. - Type *SelTy = Sel.getType(); - auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); - if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) - return false; - - Constant *AdjustedRHS; - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) - AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); - else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) - AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); - else - return false; - - // X > C ? X : C+1 --> X < C+1 ? C+1 : X - // X < C ? X : C-1 --> X > C-1 ? C-1 : X - if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { - ; // Nothing to do here. Values match without any sign/zero extension. - } - // Types do not match. Instead of calculating this with mixed types, promote - // all to the larger type. 
This enables scalar evolution to analyze this - // expression. - else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { - Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); - - // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X - // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X - // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X - // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X - if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = SextRHS; - } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && - SextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = SextRHS; - } else if (Cmp.isUnsigned()) { - Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); - // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X - // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X - // zext + signed compare cannot be changed: - // 0xff <s 0x00, but 0x00ff >s 0x0000 - if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = ZextRHS; - } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && - ZextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = ZextRHS; - } else { - return false; - } - } else { - return false; - } - } else { - return false; - } - - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - Cmp.setPredicate(Pred); - Cmp.setOperand(0, CmpLHS); - Cmp.setOperand(1, CmpRHS); - Sel.setOperand(1, TrueVal); - Sel.setOperand(2, FalseVal); - Sel.swapProfMetadata(); - - // Move the compare instruction right before the select instruction. Otherwise - // the sext/zext value may be defined after the compare instruction uses it. - Cmp.moveBefore(&Sel); - - return true; -} - static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, InstCombinerImpl &IC) { Value *LHS, *RHS; @@ -1182,8 +1216,8 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, return nullptr; } -static bool replaceInInstruction(Value *V, Value *Old, Value *New, - InstCombiner &IC, unsigned Depth = 0) { +bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New, + unsigned Depth) { // Conservatively limit replacement to two instructions upwards. if (Depth == 2) return false; @@ -1195,10 +1229,11 @@ static bool replaceInInstruction(Value *V, Value *Old, Value *New, bool Changed = false; for (Use &U : I->operands()) { if (U == Old) { - IC.replaceUse(U, New); + replaceUse(U, New); + Worklist.add(I); Changed = true; } else { - Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1); + Changed |= replaceInInstruction(U, Old, New, Depth + 1); } } return Changed; @@ -1254,7 +1289,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // FIXME: Support vectors. 
if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && !Cmp.getType()->isVectorTy()) - if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this)) + if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS)) return &Sel; } if (TrueVal != CmpRHS && @@ -1593,13 +1628,32 @@ static Instruction *foldSelectZeroOrOnes(ICmpInst *Cmp, Value *TVal, return nullptr; } -static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI) { +static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI, + InstCombiner::BuilderTy &Builder) { const APInt *CmpC; Value *V; CmpInst::Predicate Pred; if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC)))) return nullptr; + // Match clamp away from min/max value as a max/min operation. + Value *TVal = SI.getTrueValue(); + Value *FVal = SI.getFalseValue(); + if (Pred == ICmpInst::ICMP_EQ && V == FVal) { + // (V == UMIN) ? UMIN+1 : V --> umax(V, UMIN+1) + if (CmpC->isMinValue() && match(TVal, m_SpecificInt(*CmpC + 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::umax, V, TVal); + // (V == UMAX) ? UMAX-1 : V --> umin(V, UMAX-1) + if (CmpC->isMaxValue() && match(TVal, m_SpecificInt(*CmpC - 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::umin, V, TVal); + // (V == SMIN) ? SMIN+1 : V --> smax(V, SMIN+1) + if (CmpC->isMinSignedValue() && match(TVal, m_SpecificInt(*CmpC + 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::smax, V, TVal); + // (V == SMAX) ? SMAX-1 : V --> smin(V, SMAX-1) + if (CmpC->isMaxSignedValue() && match(TVal, m_SpecificInt(*CmpC - 1))) + return Builder.CreateBinaryIntrinsic(Intrinsic::smin, V, TVal); + } + BinaryOperator *BO; const APInt *C; CmpInst::Predicate CPred; @@ -1632,7 +1686,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewSPF = canonicalizeSPF(SI, *ICI, *this)) return NewSPF; - if (Value *V = foldSelectInstWithICmpConst(SI, ICI)) + if (Value *V = foldSelectInstWithICmpConst(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) @@ -1642,18 +1696,17 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) return NewSel; - bool Changed = adjustMinMax(SI, *ICI); - if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); // NOTE: if we wanted to, this is where to detect integer MIN/MAX + bool Changed = false; Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { + if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS) && !isa<Constant>(CmpLHS)) { if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) { // Transform (X == C) ? X : Y -> (X == C) ? C : Y SI.setOperand(1, CmpRHS); @@ -1683,7 +1736,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring // decomposeBitTestICmp() might help. 
- { + if (TrueVal->getType()->isIntOrIntVectorTy()) { unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType()->getScalarType()); APInt MinSignedValue = APInt::getSignedMinValue(BitWidth); @@ -1735,6 +1788,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) return V; + if (Value *V = foldSelectICmpAndZeroShl(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) return V; @@ -1756,6 +1812,9 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = canonicalizeSaturatedAdd(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); + if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder)) + return replaceInstUsesWith(SI, V); + return Changed ? &SI : nullptr; } @@ -2418,7 +2477,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { // in the case of a shuffle with no undefined mask elements. ArrayRef<int> Mask; if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && - !is_contained(Mask, UndefMaskElem) && + !is_contained(Mask, PoisonMaskElem) && cast<ShuffleVectorInst>(TVal)->isSelect()) { if (X == FVal) { // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X) @@ -2432,7 +2491,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { } } if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && - !is_contained(Mask, UndefMaskElem) && + !is_contained(Mask, PoisonMaskElem) && cast<ShuffleVectorInst>(FVal)->isSelect()) { if (X == TVal) { // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y) @@ -2965,6 +3024,14 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); + // select a, (select ~a, true, b), false -> select a, b, false + if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) && + match(FalseVal, m_Zero())) + return replaceOperand(SI, 1, B); + // select a, true, (select ~a, b, false) -> select a, true, b + if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) && + match(TrueVal, m_One())) + return replaceOperand(SI, 2, B); // ~(A & B) & (A | B) --> A ^ B if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), @@ -3077,6 +3144,134 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { return nullptr; } +// Return true if we can safely remove the select instruction for std::bit_ceil +// pattern. +static bool isSafeToRemoveBitCeilSelect(ICmpInst::Predicate Pred, Value *Cond0, + const APInt *Cond1, Value *CtlzOp, + unsigned BitWidth) { + // The challenge in recognizing std::bit_ceil(X) is that the operand is used + // for the CTLZ proper and select condition, each possibly with some + // operation like add and sub. + // + // Our aim is to make sure that -ctlz & (BitWidth - 1) == 0 even when the + // select instruction would select 1, which allows us to get rid of the select + // instruction. + // + // To see if we can do so, we do some symbolic execution with ConstantRange. + // Specifically, we compute the range of values that Cond0 could take when + // Cond == false. Then we successively transform the range until we obtain + // the range of values that CtlzOp could take. 
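+  //
+  // For example (hypothetical): with the condition (icmp ugt i32 %x, 1), the
+  // select picks 1 exactly when %x u< 2, so Cond0 = %x ranges over [0, 2);
+  // for CtlzOp = %x - 1 that range becomes [-1, 1), every value of which is
+  // 0 or negative (as signed), which is what the check below verifies.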
+  //
+  // Conceptually, we follow the def-use chain backward from Cond0 while
+  // transforming the range for Cond0 until we meet the common ancestor of
+  // Cond0 and CtlzOp. Then we follow the def-use chain forward until we
+  // obtain the range for CtlzOp. That said, we only follow at most one
+  // ancestor from Cond0. Likewise, we only follow at most one ancestor from
+  // CtlzOp.
 
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(
+      CmpInst::getInversePredicate(Pred), *Cond1);
+
+  // Match the operation that's used to compute CtlzOp from CommonAncestor. If
+  // CtlzOp == CommonAncestor, return true as no operation is needed. If a
+  // match is found, execute the operation on CR, update CR, and return true.
+  // Otherwise, return false.
+  auto MatchForward = [&](Value *CommonAncestor) {
+    const APInt *C = nullptr;
+    if (CtlzOp == CommonAncestor)
+      return true;
+    if (match(CtlzOp, m_Add(m_Specific(CommonAncestor), m_APInt(C)))) {
+      CR = CR.add(*C);
+      return true;
+    }
+    if (match(CtlzOp, m_Sub(m_APInt(C), m_Specific(CommonAncestor)))) {
+      CR = ConstantRange(*C).sub(CR);
+      return true;
+    }
+    if (match(CtlzOp, m_Not(m_Specific(CommonAncestor)))) {
+      CR = CR.binaryNot();
+      return true;
+    }
+    return false;
+  };
+
+  const APInt *C = nullptr;
+  Value *CommonAncestor;
+  if (MatchForward(Cond0)) {
+    // Cond0 is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else if (match(Cond0, m_Add(m_Value(CommonAncestor), m_APInt(C)))) {
+    CR = CR.sub(*C);
+    if (!MatchForward(CommonAncestor))
+      return false;
+    // Cond0's parent is either CtlzOp or CtlzOp's parent. CR has been updated.
+  } else {
+    return false;
+  }
+
+  // Return true if all the values in the range are either 0 or negative (if
+  // treated as signed). We do so by evaluating:
+  //
+  //   CR - 1 u>= (1 << (BitWidth - 1)) - 1.
+  APInt IntMax = APInt::getSignMask(BitWidth) - 1;
+  CR = CR.sub(APInt(BitWidth, 1));
+  return CR.icmp(ICmpInst::ICMP_UGE, IntMax);
+}
+
+// Transform the std::bit_ceil(X) pattern like:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %sub = sub i32 32, %ctlz
+//   %shl = shl i32 1, %sub
+//   %ugt = icmp ugt i32 %x, 1
+//   %sel = select i1 %ugt, i32 %shl, i32 1
+//
+// into:
+//
+//   %dec = add i32 %x, -1
+//   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)
+//   %neg = sub i32 0, %ctlz
+//   %masked = and i32 %neg, 31
+//   %shl = shl i32 1, %masked
+//
+// Note that the select is optimized away while the shift count is masked with
+// 31. We handle some variations of the input operand like std::bit_ceil(X +
+// 1).
+static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
+  Type *SelType = SI.getType();
+  unsigned BitWidth = SelType->getScalarSizeInBits();
+
+  Value *FalseVal = SI.getFalseValue();
+  Value *TrueVal = SI.getTrueValue();
+  ICmpInst::Predicate Pred;
+  const APInt *Cond1;
+  Value *Cond0, *Ctlz, *CtlzOp;
+  if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
+    return nullptr;
+
+  if (match(TrueVal, m_One())) {
+    std::swap(FalseVal, TrueVal);
+    Pred = CmpInst::getInversePredicate(Pred);
+  }
+
+  if (!match(FalseVal, m_One()) ||
+      !match(TrueVal,
+             m_OneUse(m_Shl(m_One(), m_OneUse(m_Sub(m_SpecificInt(BitWidth),
+                                                    m_Value(Ctlz)))))) ||
+      !match(Ctlz, m_Intrinsic<Intrinsic::ctlz>(m_Value(CtlzOp), m_Zero())) ||
+      !isSafeToRemoveBitCeilSelect(Pred, Cond0, Cond1, CtlzOp, BitWidth))
+    return nullptr;
+
+  // Build 1 << (-CTLZ & (BitWidth-1)). 
The negation likely corresponds to a + // single hardware instruction as opposed to BitWidth - CTLZ, where BitWidth + // is an integer constant. Masking with BitWidth-1 comes free on some + // hardware as part of the shift instruction. + Value *Neg = Builder.CreateNeg(Ctlz); + Value *Masked = + Builder.CreateAnd(Neg, ConstantInt::get(SelType, BitWidth - 1)); + return BinaryOperator::Create(Instruction::Shl, ConstantInt::get(SelType, 1), + Masked); +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3253,6 +3448,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { std::swap(NewT, NewF); Value *NewSI = Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI); + if (Gep->isInBounds()) + return GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI}); return GetElementPtrInst::Create(ElementType, Ptr, {NewSI}); }; if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal)) @@ -3364,25 +3561,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } - auto canMergeSelectThroughBinop = [](BinaryOperator *BO) { - // The select might be preventing a division by 0. - switch (BO->getOpcode()) { - default: - return true; - case Instruction::SRem: - case Instruction::URem: - case Instruction::SDiv: - case Instruction::UDiv: - return false; - } - }; - // Try to simplify a binop sandwiched between 2 selects with the same - // condition. + // condition. This is not valid for div/rem because the select might be + // preventing a division-by-zero. + // TODO: A div/rem restriction is conservative; use something like + // isSafeToSpeculativelyExecute(). // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) BinaryOperator *TrueBO; - if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && - canMergeSelectThroughBinop(TrueBO)) { + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO))) && !TrueBO->isIntDivRem()) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue()); @@ -3401,8 +3587,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) BinaryOperator *FalseBO; - if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && - canMergeSelectThroughBinop(FalseBO)) { + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO))) && !FalseBO->isIntDivRem()) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue()); @@ -3516,5 +3701,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (sinkNotIntoOtherHandOfLogicalOp(SI)) return &SI; + if (Instruction *I = foldBitCeil(SI, Builder)) + return I; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index ec505381cc86..89dad455f015 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -322,15 +322,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } -/// If we have a shift-by-constant of a bitwise logic op that itself has a -/// shift-by-constant operand with identical opcode, we may be able to convert -/// that into 2 independent shifts followed by the logic op. 
This eliminates a
-/// a use of an intermediate value (reduces dependency chain).
-static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
+/// If we have a shift-by-constant of a bin op (bitwise logic op or add/sub w/
+/// shl) that itself has a shift-by-constant operand with identical opcode, we
+/// may be able to convert that into 2 independent shifts followed by the bin
+/// op. This eliminates a use of an intermediate value (reduces dependency
+/// chain).
+static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I,
                                             InstCombiner::BuilderTy &Builder) {
   assert(I.isShift() && "Expected a shift as input");
-  auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0));
-  if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse())
+  auto *BinInst = dyn_cast<BinaryOperator>(I.getOperand(0));
+  if (!BinInst ||
+      (!BinInst->isBitwiseLogicOp() &&
+       BinInst->getOpcode() != Instruction::Add &&
+       BinInst->getOpcode() != Instruction::Sub) ||
+      !BinInst->hasOneUse())
     return nullptr;
 
   Constant *C0, *C1;
@@ -338,6 +343,12 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
     return nullptr;
 
   Instruction::BinaryOps ShiftOpcode = I.getOpcode();
+  // Transform for add/sub only works with shl.
+  if ((BinInst->getOpcode() == Instruction::Add ||
+       BinInst->getOpcode() == Instruction::Sub) &&
+      ShiftOpcode != Instruction::Shl)
+    return nullptr;
+
   Type *Ty = I.getType();
 
   // Find a matching one-use shift by constant. The fold is not valid if the sum
@@ -352,19 +363,25 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
                           m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
   };
 
-  // Logic ops are commutative, so check each operand for a match.
-  if (matchFirstShift(LogicInst->getOperand(0)))
-    Y = LogicInst->getOperand(1);
-  else if (matchFirstShift(LogicInst->getOperand(1)))
-    Y = LogicInst->getOperand(0);
-  else
+  // Logic ops and Add are commutative, so check each operand for a match. Sub
+  // is not, so we cannot reorder if we match operand(1); we need to keep the
+  // operands in their original positions.
+  bool FirstShiftIsOp1 = false;
+  if (matchFirstShift(BinInst->getOperand(0)))
+    Y = BinInst->getOperand(1);
+  else if (matchFirstShift(BinInst->getOperand(1))) {
+    Y = BinInst->getOperand(0);
+    FirstShiftIsOp1 = BinInst->getOpcode() == Instruction::Sub;
+  } else
    return nullptr;
 
-  // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
+  // shift (binop (shift X, C0), Y), C1 -> binop (shift X, C0+C1), (shift Y, C1)
   Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
   Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
   Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
-  return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
+  Value *Op1 = FirstShiftIsOp1 ? NewShift2 : NewShift1;
+  Value *Op2 = FirstShiftIsOp1 ?
NewShift1 : NewShift2; + return BinaryOperator::Create(BinInst->getOpcode(), Op1, Op2); } Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { @@ -463,9 +480,12 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) { return replaceOperand(I, 1, Rem); } - if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) + if (Instruction *Logic = foldShiftOfShiftedBinOp(I, Builder)) return Logic; + if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1)))) + return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1)); + return nullptr; } @@ -570,8 +590,7 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, const APInt *MulConst; // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`) return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) && - MulConst->isNegatedPowerOf2() && - MulConst->countTrailingZeros() == NumBits; + MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits; } } } @@ -900,8 +919,10 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { // Replace the uses of the original add with a zext of the // NarrowAdd's result. Note that all users at this stage are known to // be ShAmt-sized truncs, or the lshr itself. - if (!Add->hasOneUse()) + if (!Add->hasOneUse()) { replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty)); + eraseInstFromFunction(*AddInst); + } // Replace the LShr with a zext of the overflow check. return new ZExtInst(Overflow, Ty); @@ -1133,6 +1154,14 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { return BinaryOperator::CreateLShr( ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + // Canonicalize "extract lowest set bit" using cttz to and-with-negate: + // 1 << (cttz X) --> -X & X + if (match(Op1, + m_OneUse(m_Intrinsic<Intrinsic::cttz>(m_Value(X), m_Value())))) { + Value *NegX = Builder.CreateNeg(X, "neg"); + return BinaryOperator::CreateAnd(NegX, X); + } + // The only way to shift out the 1 is with an over-shift, so that would // be poison with or without "nuw". Undef is excluded because (undef << X) // is not undef (it is zero). diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 77d675422966..00eece9534b0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -168,7 +168,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care // about the high bits of the operands. auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); @@ -195,7 +195,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown & RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. 
@@ -224,7 +225,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown | RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. @@ -262,7 +264,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - Known = LHSKnown ^ RHSKnown; + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, DL, &AC, CxtI, &DT); // If the client is only demanding bits that we know, return the known // constant. @@ -381,7 +384,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; // Only known if known in both the LHS and RHS. - Known = KnownBits::commonBits(LHSKnown, RHSKnown); + Known = LHSKnown.intersectWith(RHSKnown); break; } case Instruction::Trunc: { @@ -393,7 +396,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The shift amount must be valid (not poison) in the narrow type, and // it must not be greater than the high bits demanded of the result. if (C->ult(VTy->getScalarSizeInBits()) && - C->ule(DemandedMask.countLeadingZeros())) { + C->ule(DemandedMask.countl_zero())) { // trunc (lshr X, C) --> lshr (trunc X), C IRBuilderBase::InsertPointGuard Guard(Builder); Builder.SetInsertPoint(I); @@ -508,7 +511,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) @@ -517,7 +520,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If low order bits are not demanded and known to be zero in one operand, // then we don't need to demand them from the other operand, since they // can't cause overflow into any bits that are demanded in the result. - unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || @@ -539,7 +542,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, case Instruction::Sub: { // Right fill the mask of bits for the operands to demand the most // significant bit and all those below it. 
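+    // e.g. (illustrative) a DemandedMask of 0b0110 right-fills to 0b0111,
+    // since carries or borrows into the demanded bits can propagate from any
+    // lower bit of either operand.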
- unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) @@ -548,7 +551,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If low order bits are not demanded and are known to be zero in RHS, // then we don't need to demand them from LHS, since they can't cause a // borrow from any bits that are demanded in the result. - unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one(); APInt DemandedFromLHS = DemandedFromOps; DemandedFromLHS.clearLowBits(NTZ); if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || @@ -578,10 +581,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1. // If we demand exactly one bit N and we have "X * (C' << N)" where C' is // odd (has LSB set), then the left-shifted low bit of X is the answer. - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && - C->countTrailingZeros() == CTZ) { + if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) { Constant *ShiftC = ConstantInt::get(VTy, CTZ); Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC); return InsertNewInstWith(Shl, *I); @@ -619,7 +621,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); Value *X; Constant *C; - if (DemandedMask.countTrailingZeros() >= ShiftAmt && + if (DemandedMask.countr_zero() >= ShiftAmt && match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) { Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt); Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC); @@ -642,29 +644,15 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - bool SignBitZero = Known.Zero.isSignBitSet(); - bool SignBitOne = Known.One.isSignBitSet(); - Known.Zero <<= ShiftAmt; - Known.One <<= ShiftAmt; - // low bits known zero. - if (ShiftAmt) - Known.Zero.setLowBits(ShiftAmt); - - // If this shift has "nsw" keyword, then the result is either a poison - // value or has the same sign bit as the first operand. - if (IOp->hasNoSignedWrap()) { - if (SignBitZero) - Known.Zero.setSignBit(); - else if (SignBitOne) - Known.One.setSignBit(); - if (Known.hasConflict()) - return UndefValue::get(VTy); - } + Known = KnownBits::shl(Known, + KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)), + /* NUW */ IOp->hasNoUnsignedWrap(), + /* NSW */ IOp->hasNoSignedWrap()); } else { // This is a variable shift, so we can't shift the demand mask by a known // amount. But if we are not demanding high bits, then we are not // demanding those bits from the pre-shifted operand either. - if (unsigned CTLZ = DemandedMask.countLeadingZeros()) { + if (unsigned CTLZ = DemandedMask.countl_zero()) { APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ)); if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) { // We can't guarantee that nsw/nuw hold after simplifying the operand. 
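+        // e.g. (illustrative) if only the low 8 bits of (shl i32 %x, %s) are
+        // demanded, then only the low 8 bits of %x can matter, whatever %s
+        // is: a left shift only moves bits toward the MSB.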
@@ -683,11 +671,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If we are just demanding the shifted sign bit and below, then this can // be treated as an ASHR in disguise. - if (DemandedMask.countLeadingZeros() >= ShiftAmt) { + if (DemandedMask.countl_zero() >= ShiftAmt) { // If we only want bits that already match the signbit then we don't // need to shift. - unsigned NumHiDemandedBits = - BitWidth - DemandedMask.countTrailingZeros(); + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); if (SignBits >= NumHiDemandedBits) @@ -734,7 +721,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If we only want bits that already match the signbit then we don't need // to shift. - unsigned NumHiDemandedBits = BitWidth - DemandedMask.countTrailingZeros(); + unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); if (SignBits >= NumHiDemandedBits) return I->getOperand(0); @@ -757,7 +744,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt)); // If any of the high bits are demanded, we should set the sign bit as // demanded. - if (DemandedMask.countLeadingZeros() <= ShiftAmt) + if (DemandedMask.countl_zero() <= ShiftAmt) DemandedMaskIn.setSignBit(); // If the shift is exact, then it does demand the low bits (and knows that @@ -797,7 +784,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { // TODO: Take the demanded mask of the result into account. - unsigned RHSTrailingZeros = SA->countTrailingZeros(); + unsigned RHSTrailingZeros = SA->countr_zero(); APInt DemandedMaskIn = APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) { @@ -807,9 +794,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I; } - // Increase high zero bits from the input. - Known.Zero.setHighBits(std::min( - BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA), + cast<BinaryOperator>(I)->isExact()); } else { computeKnownBits(I, Known, Depth, CxtI); } @@ -851,25 +837,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } - // The sign bit is the LHS's sign bit, except when the result of the - // remainder is zero. - if (DemandedMask.isSignBitSet()) { - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - // If it's known zero, our sign bit is also zero. 
- if (LHSKnown.isNonNegative()) - Known.makeNonNegative(); - } + computeKnownBits(I, Known, Depth, CxtI); break; } case Instruction::URem: { - KnownBits Known2(BitWidth); APInt AllOnes = APInt::getAllOnes(BitWidth); - if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || - SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) + if (SimplifyDemandedBits(I, 0, AllOnes, LHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, AllOnes, RHSKnown, Depth + 1)) return I; - unsigned Leaders = Known2.countMinLeadingZeros(); - Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; + Known = KnownBits::urem(LHSKnown, RHSKnown); break; } case Instruction::Call: { @@ -897,8 +874,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, case Intrinsic::bswap: { // If the only bits demanded come from one byte of the bswap result, // just shift the input byte into position to eliminate the bswap. - unsigned NLZ = DemandedMask.countLeadingZeros(); - unsigned NTZ = DemandedMask.countTrailingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); + unsigned NTZ = DemandedMask.countr_zero(); // Round NTZ down to the next byte. If we have 11 trailing zeros, then // we need all the bits down to bit 8. Likewise, round NLZ. If we @@ -935,9 +912,28 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); - if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) || - SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) - return I; + if (I->getOperand(0) != I->getOperand(1)) { + if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, + Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) + return I; + } else { // fshl is a rotate + // Avoid converting rotate into funnel shift. + // Only simplify if one operand is constant. + LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I); + if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && + !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) { + replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One)); + return I; + } + + RHSKnown = computeKnownBits(I->getOperand(1), Depth + 1, I); + if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) && + !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) { + replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One)); + return I; + } + } Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | RHSKnown.Zero.lshr(BitWidth - ShiftAmt); @@ -951,7 +947,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // The lowest non-zero bit of DemandMask is higher than the highest // non-zero bit of C. const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); if (match(II->getArgOperand(1), m_APInt(C)) && CTZ >= C->getActiveBits()) return II->getArgOperand(0); @@ -963,9 +959,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // non-one bit of C. // This comes from using DeMorgans on the above umax example. 
const APInt *C; - unsigned CTZ = DemandedMask.countTrailingZeros(); + unsigned CTZ = DemandedMask.countr_zero(); if (match(II->getArgOperand(1), m_APInt(C)) && - CTZ >= C->getBitWidth() - C->countLeadingOnes()) + CTZ >= C->getBitWidth() - C->countl_one()) return II->getArgOperand(0); break; } @@ -1014,6 +1010,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown & RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1033,6 +1030,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown | RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1054,6 +1052,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown ^ RHSKnown; + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1071,7 +1070,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Add: { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); // If an operand adds zeros to every bit below the highest demanded bit, @@ -1084,10 +1083,13 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown); + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::Sub: { - unsigned NLZ = DemandedMask.countLeadingZeros(); + unsigned NLZ = DemandedMask.countl_zero(); APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); // If an operand subtracts zeros from every bit below the highest demanded @@ -1096,6 +1098,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown); + computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::AShr: { @@ -1541,7 +1547,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Found constant vector with single element - convert to insertelement. 
if (Op && Value) { Instruction *New = InsertElementInst::Create( - Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx), + Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), Shuffle->getName()); InsertNewInstWith(New, *Shuffle); return New; @@ -1552,7 +1558,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, SmallVector<int, 16> Elts; for (unsigned i = 0; i < VWidth; ++i) { if (UndefElts[i]) - Elts.push_back(UndefMaskElem); + Elts.push_back(PoisonMaskElem); else Elts.push_back(Shuffle->getMaskValue(i)); } @@ -1653,7 +1659,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // corresponding input elements are undef. for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) { APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio); - if (SubUndef.countPopulation() == Ratio) + if (SubUndef.popcount() == Ratio) UndefElts.setBit(OutIdx); } } else { @@ -1712,6 +1718,54 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // UB/poison potential, but that should be refined. BinaryOperator *BO; if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { + Value *X = BO->getOperand(0); + Value *Y = BO->getOperand(1); + + // Look for an equivalent binop except that one operand has been shuffled. + // If the demand for this binop only includes elements that are the same as + // the other binop, then we may be able to replace this binop with a use of + // the earlier one. + // + // Example: + // %other_bo = bo (shuf X, {0}), Y + // %this_extracted_bo = extelt (bo X, Y), 0 + // --> + // %other_bo = bo (shuf X, {0}), Y + // %this_extracted_bo = extelt %other_bo, 0 + // + // TODO: Handle demand of an arbitrary single element or more than one + // element instead of just element 0. + // TODO: Unlike general demanded elements transforms, this should be safe + // for any (div/rem/shift) opcode too. + if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() && + BO->hasOneUse() ) { + + auto findShufBO = [&](bool MatchShufAsOp0) -> User * { + // Try to use shuffle-of-operand in place of an operand: + // bo X, Y --> bo (shuf X), Y + // bo X, Y --> bo X, (shuf Y) + BinaryOperator::BinaryOps Opcode = BO->getOpcode(); + Value *ShufOp = MatchShufAsOp0 ? X : Y; + Value *OtherOp = MatchShufAsOp0 ? Y : X; + for (User *U : OtherOp->users()) { + auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_ZeroMask()); + if (BO->isCommutative() + ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp))) + : MatchShufAsOp0 + ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp))) + : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf))) + if (DT.dominates(U, I)) + return U; + } + return nullptr; + }; + + if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true)) + return ShufBO; + if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false)) + return ShufBO; + } + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); simplifyAndSetOp(I, 1, DemandedElts, UndefElts2); @@ -1723,7 +1777,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // If we've proven all of the lanes undef, return an undef value. // TODO: Intersect w/demanded lanes if (UndefElts.isAllOnes()) - return UndefValue::get(I->getType());; + return UndefValue::get(I->getType()); return MadeChange ? 
I : nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 61e62adbe327..4a5ffef2b08e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -171,8 +171,11 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, } } - for (auto *E : Extracts) + for (auto *E : Extracts) { replaceInstUsesWith(*E, scalarPHI); + // Add old extract to worklist for DCE. + addToWorklist(E); + } return &EI; } @@ -384,7 +387,7 @@ static APInt findDemandedEltsByAllUsers(Value *V) { /// return it with the canonical type if it isn't already canonical. We /// arbitrarily pick 64 bit as our canonical type. The actual bitwidth doesn't /// matter, we just want a consistent type to simplify CSE. -ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) { +static ConstantInt *getPreferredVectorIndex(ConstantInt *IndexC) { const unsigned IndexBW = IndexC->getType()->getBitWidth(); if (IndexBW == 64 || IndexC->getValue().getActiveBits() > 64) return nullptr; @@ -543,16 +546,16 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { ->getNumElements(); if (SrcIdx < 0) - return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); + return replaceInstUsesWith(EI, PoisonValue::get(EI.getType())); if (SrcIdx < (int)LHSWidth) Src = SVI->getOperand(0); else { SrcIdx -= LHSWidth; Src = SVI->getOperand(1); } - Type *Int32Ty = Type::getInt32Ty(EI.getContext()); + Type *Int64Ty = Type::getInt64Ty(EI.getContext()); return ExtractElementInst::Create( - Src, ConstantInt::get(Int32Ty, SrcIdx, false)); + Src, ConstantInt::get(Int64Ty, SrcIdx, false)); } } else if (auto *CI = dyn_cast<CastInst>(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). @@ -594,6 +597,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { SrcVec, DemandedElts, UndefElts, 0 /* Depth */, true /* AllowMultipleUsers */)) { if (V != SrcVec) { + Worklist.addValue(SrcVec); SrcVec->replaceAllUsesWith(V); return &EI; } @@ -640,11 +644,11 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, return false; unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); - if (isa<UndefValue>(ScalarOp)) { // inserting undef into vector. + if (isa<PoisonValue>(ScalarOp)) { // inserting poison into vector. // We can handle this if the vector we are inserting into is // transitively ok. if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { - // If so, update the mask to reflect the inserted undef. + // If so, update the mask to reflect the inserted poison. Mask[InsertedIdx] = -1; return true; } @@ -680,7 +684,7 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, /// If we have insertion into a vector that is wider than the vector that we /// are extracting from, try to widen the source vector to allow a single /// shufflevector to replace one or more insert/extract pairs. -static void replaceExtractElements(InsertElementInst *InsElt, +static bool replaceExtractElements(InsertElementInst *InsElt, ExtractElementInst *ExtElt, InstCombinerImpl &IC) { auto *InsVecType = cast<FixedVectorType>(InsElt->getType()); @@ -691,7 +695,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // The inserted-to vector must be wider than the extracted-from vector. 
if (InsVecType->getElementType() != ExtVecType->getElementType() || NumExtElts >= NumInsElts) - return; + return false; // Create a shuffle mask to widen the extended-from vector using poison // values. The mask selects all of the values of the original vector followed @@ -719,7 +723,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // that will delete our widening shuffle. This would trigger another attempt // here to create that shuffle, and we spin forever. if (InsertionBlock != InsElt->getParent()) - return; + return false; // TODO: This restriction matches the check in visitInsertElementInst() and // prevents an infinite loop caused by not turning the extract/insert pair @@ -727,7 +731,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, // folds for shufflevectors because we're afraid to generate shuffle masks // that the backend can't handle. if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) - return; + return false; auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask); @@ -747,9 +751,14 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (!OldExt || OldExt->getParent() != WideVec->getParent()) continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); - NewExt->insertAfter(OldExt); + IC.InsertNewInstWith(NewExt, *OldExt); IC.replaceInstUsesWith(*OldExt, NewExt); + // Add the old extracts to the worklist for DCE. We can't remove the + // extracts directly, because they may still be used by the calling code. + IC.addToWorklist(OldExt); } + + return true; } /// We are building a shuffle to create V, which is a sequence of insertelement, @@ -764,7 +773,7 @@ using ShuffleOps = std::pair<Value *, Value *>; static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, Value *PermittedRHS, - InstCombinerImpl &IC) { + InstCombinerImpl &IC, bool &Rerun) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); @@ -795,13 +804,14 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, // otherwise we'd end up with a shuffle of three inputs. if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) { Value *RHS = EI->getOperand(0); - ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC); + ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC, Rerun); assert(LR.second == nullptr || LR.second == RHS); if (LR.first->getType() != RHS->getType()) { // Although we are giving up for now, see if we can create extracts // that match the inserts for another round of combining. - replaceExtractElements(IEI, EI, IC); + if (replaceExtractElements(IEI, EI, IC)) + Rerun = true; // We tried our best, but we can't find anything compatible with RHS // further up the chain. Return a trivial shuffle. 
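To make the widening concrete, here is a minimal IR sketch (function and value names are mine, not from the patch). The insert/extract chain below moves a narrow vector into a wider one; with the old extracts rewritten against the widening shuffle and the Rerun flag set, collectShuffleElements should now reduce it within the same instcombine visit:

define <4 x float> @widen(<2 x float> %src) {
  %e0 = extractelement <2 x float> %src, i64 0
  %e1 = extractelement <2 x float> %src, i64 1
  %i0 = insertelement <4 x float> poison, float %e0, i64 0
  %i1 = insertelement <4 x float> %i0, float %e1, i64 1
  ret <4 x float> %i1
}

The expected reduction, checkable with opt -passes=instcombine, is a single shufflevector of %src into the wider type with the trailing lanes poison.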
@@ -1129,6 +1139,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( /// It should be transformed to: /// %0 = insertvalue { i8, i32 } undef, i8 %y, 0 Instruction *InstCombinerImpl::visitInsertValueInst(InsertValueInst &I) { + if (Value *V = simplifyInsertValueInst( + I.getAggregateOperand(), I.getInsertedValueOperand(), I.getIndices(), + SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + bool IsRedundant = false; ArrayRef<unsigned int> FirstIndices = I.getIndices(); @@ -1235,22 +1250,22 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { if (FirstIE == &InsElt) return nullptr; - // If we are not inserting into an undef vector, make sure we've seen an + // If we are not inserting into a poison vector, make sure we've seen an // insert into every element. // TODO: If the base vector is not undef, it might be better to create a splat // and then a select-shuffle (blend) with the base vector. - if (!match(FirstIE->getOperand(0), m_Undef())) + if (!match(FirstIE->getOperand(0), m_Poison())) if (!ElementPresent.all()) return nullptr; // Create the insert + shuffle. - Type *Int32Ty = Type::getInt32Ty(InsElt.getContext()); + Type *Int64Ty = Type::getInt64Ty(InsElt.getContext()); PoisonValue *PoisonVec = PoisonValue::get(VecTy); - Constant *Zero = ConstantInt::get(Int32Ty, 0); + Constant *Zero = ConstantInt::get(Int64Ty, 0); if (!cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) FirstIE = InsertElementInst::Create(PoisonVec, SplatVal, Zero, "", &InsElt); - // Splat from element 0, but replace absent elements with undef in the mask. + // Splat from element 0, but replace absent elements with poison in the mask. SmallVector<int, 16> Mask(NumElements, 0); for (unsigned i = 0; i != NumElements; ++i) if (!ElementPresent[i]) @@ -1339,7 +1354,7 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { // (demanded elements analysis may unset it later). return nullptr; } else { - assert(OldMask[i] == UndefMaskElem && + assert(OldMask[i] == PoisonMaskElem && "Unexpected shuffle mask element for identity shuffle"); NewMask[i] = IdxC; } @@ -1465,10 +1480,10 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { } ++ValI; } - // Remaining values are filled with 'undef' values. + // Remaining values are filled with 'poison' values. for (unsigned I = 0; I < NumElts; ++I) { if (!Values[I]) { - Values[I] = UndefValue::get(InsElt.getType()->getElementType()); + Values[I] = PoisonValue::get(InsElt.getType()->getElementType()); Mask[I] = I; } } @@ -1676,16 +1691,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { // Try to form a shuffle from a chain of extract-insert ops. if (isShuffleRootCandidate(IE)) { - SmallVector<int, 16> Mask; - ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); - - // The proposed shuffle may be trivial, in which case we shouldn't - // perform the combine. - if (LR.first != &IE && LR.second != &IE) { - // We now have a shuffle of LHS, RHS, Mask. - if (LR.second == nullptr) - LR.second = UndefValue::get(LR.first->getType()); - return new ShuffleVectorInst(LR.first, LR.second, Mask); + bool Rerun = true; + while (Rerun) { + Rerun = false; + + SmallVector<int, 16> Mask; + ShuffleOps LR = + collectShuffleElements(&IE, Mask, nullptr, *this, Rerun); + + // The proposed shuffle may be trivial, in which case we shouldn't + // perform the combine. 
+ if (LR.first != &IE && LR.second != &IE) { + // We now have a shuffle of LHS, RHS, Mask. + if (LR.second == nullptr) + LR.second = PoisonValue::get(LR.first->getType()); + return new ShuffleVectorInst(LR.first, LR.second, Mask); + } } } } @@ -1815,9 +1836,9 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, /// Rebuild a new instruction just like 'I' but with the new operands given. /// In the event of type mismatch, the type of the operands is correct. -static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { - // We don't want to use the IRBuilder here because we want the replacement - // instructions to appear next to 'I', not the builder's insertion point. +static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps, + IRBuilderBase &Builder) { + Builder.SetInsertPoint(I); switch (I->getOpcode()) { case Instruction::Add: case Instruction::FAdd: @@ -1839,28 +1860,29 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { case Instruction::Xor: { BinaryOperator *BO = cast<BinaryOperator>(I); assert(NewOps.size() == 2 && "binary operator with #ops != 2"); - BinaryOperator *New = - BinaryOperator::Create(cast<BinaryOperator>(I)->getOpcode(), - NewOps[0], NewOps[1], "", BO); - if (isa<OverflowingBinaryOperator>(BO)) { - New->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap()); - New->setHasNoSignedWrap(BO->hasNoSignedWrap()); - } - if (isa<PossiblyExactOperator>(BO)) { - New->setIsExact(BO->isExact()); + Value *New = Builder.CreateBinOp(cast<BinaryOperator>(I)->getOpcode(), + NewOps[0], NewOps[1]); + if (auto *NewI = dyn_cast<Instruction>(New)) { + if (isa<OverflowingBinaryOperator>(BO)) { + NewI->setHasNoUnsignedWrap(BO->hasNoUnsignedWrap()); + NewI->setHasNoSignedWrap(BO->hasNoSignedWrap()); + } + if (isa<PossiblyExactOperator>(BO)) { + NewI->setIsExact(BO->isExact()); + } + if (isa<FPMathOperator>(BO)) + NewI->copyFastMathFlags(I); } - if (isa<FPMathOperator>(BO)) - New->copyFastMathFlags(I); return New; } case Instruction::ICmp: assert(NewOps.size() == 2 && "icmp with #ops != 2"); - return new ICmpInst(I, cast<ICmpInst>(I)->getPredicate(), - NewOps[0], NewOps[1]); + return Builder.CreateICmp(cast<ICmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1]); case Instruction::FCmp: assert(NewOps.size() == 2 && "fcmp with #ops != 2"); - return new FCmpInst(I, cast<FCmpInst>(I)->getPredicate(), - NewOps[0], NewOps[1]); + return Builder.CreateFCmp(cast<FCmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1]); case Instruction::Trunc: case Instruction::ZExt: case Instruction::SExt: @@ -1876,27 +1898,26 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { I->getType()->getScalarType(), cast<VectorType>(NewOps[0]->getType())->getElementCount()); assert(NewOps.size() == 1 && "cast with #ops != 1"); - return CastInst::Create(cast<CastInst>(I)->getOpcode(), NewOps[0], DestTy, - "", I); + return Builder.CreateCast(cast<CastInst>(I)->getOpcode(), NewOps[0], + DestTy); } case Instruction::GetElementPtr: { Value *Ptr = NewOps[0]; ArrayRef<Value*> Idx = NewOps.slice(1); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - cast<GetElementPtrInst>(I)->getSourceElementType(), Ptr, Idx, "", I); - GEP->setIsInBounds(cast<GetElementPtrInst>(I)->isInBounds()); - return GEP; + return Builder.CreateGEP(cast<GEPOperator>(I)->getSourceElementType(), + Ptr, Idx, "", + cast<GEPOperator>(I)->isInBounds()); } } llvm_unreachable("failed to rebuild vector instructions"); } -static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { +static Value 
*evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask, + IRBuilderBase &Builder) { // Mask.size() does not need to be equal to the number of vector elements. assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); Type *EltTy = V->getType()->getScalarType(); - Type *I32Ty = IntegerType::getInt32Ty(V->getContext()); if (match(V, m_Undef())) return UndefValue::get(FixedVectorType::get(EltTy, Mask.size())); @@ -1950,15 +1971,14 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // as well. E.g. GetElementPtr may have scalar operands even if the // return value is a vector, so we need to examine the operand type. if (I->getOperand(i)->getType()->isVectorTy()) - V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); + V = evaluateInDifferentElementOrder(I->getOperand(i), Mask, Builder); else V = I->getOperand(i); NewOps.push_back(V); NeedsRebuild |= (V != I->getOperand(i)); } - if (NeedsRebuild) { - return buildNew(I, NewOps); - } + if (NeedsRebuild) + return buildNew(I, NewOps, Builder); return I; } case Instruction::InsertElement: { @@ -1979,11 +1999,12 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // If element is not in Mask, no need to handle the operand 1 (element to // be inserted). Just evaluate values in operand 0 according to Mask. if (!Found) - return evaluateInDifferentElementOrder(I->getOperand(0), Mask); + return evaluateInDifferentElementOrder(I->getOperand(0), Mask, Builder); - Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask); - return InsertElementInst::Create(V, I->getOperand(1), - ConstantInt::get(I32Ty, Index), "", I); + Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask, + Builder); + Builder.SetInsertPoint(I); + return Builder.CreateInsertElement(V, I->getOperand(1), Index); } } llvm_unreachable("failed to reorder elements of vector instruction!"); @@ -2140,7 +2161,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { ConstantExpr::getShuffleVector(IdC, C, Mask); bool MightCreatePoisonOrUB = - is_contained(Mask, UndefMaskElem) && + is_contained(Mask, PoisonMaskElem) && (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); if (MightCreatePoisonOrUB) NewC = InstCombiner::getSafeVectorConstantForBinop(BOpcode, NewC, true); @@ -2154,7 +2175,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { // An undef shuffle mask element may propagate as an undef constant element in // the new binop. That would produce poison where the original code might not. // If we already made a safe constant, then there's no danger. - if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB) NewBO->dropPoisonGeneratingFlags(); return NewBO; } @@ -2178,8 +2199,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // Insert into element 0 of an undef vector. UndefValue *UndefVec = UndefValue::get(Shuf.getType()); - Constant *Zero = Builder.getInt32(0); - Value *NewIns = Builder.CreateInsertElement(UndefVec, X, Zero); + Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0); // Splat from element 0. Any mask element that is undefined remains undefined. 
// For example: @@ -2189,7 +2209,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); for (unsigned i = 0; i != NumMaskElts; ++i) - if (Mask[i] == UndefMaskElem) + if (Mask[i] == PoisonMaskElem) NewMask[i] = Mask[i]; return new ShuffleVectorInst(NewIns, NewMask); @@ -2274,7 +2294,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { // mask element, the result is undefined, but it is not poison or undefined // behavior. That is not necessarily true for div/rem/shift. bool MightCreatePoisonOrUB = - is_contained(Mask, UndefMaskElem) && + is_contained(Mask, PoisonMaskElem) && (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); if (MightCreatePoisonOrUB) NewC = InstCombiner::getSafeVectorConstantForBinop(BOpc, NewC, @@ -2325,7 +2345,7 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { NewI->andIRFlags(B1); if (DropNSW) NewI->setHasNoSignedWrap(false); - if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) + if (is_contained(Mask, PoisonMaskElem) && !MightCreatePoisonOrUB) NewI->dropPoisonGeneratingFlags(); } return replaceInstUsesWith(Shuf, NewBO); @@ -2361,7 +2381,7 @@ static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, SrcType->getScalarSizeInBits() / DestType->getScalarSizeInBits(); ArrayRef<int> Mask = Shuf.getShuffleMask(); for (unsigned i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] == UndefMaskElem) + if (Mask[i] == PoisonMaskElem) continue; uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio; assert(LSBIndex <= INT32_MAX && "Overflowed 32-bits"); @@ -2407,37 +2427,51 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, return SelectInst::Create(NarrowCond, NarrowX, NarrowY); } -/// Canonicalize FP negate after shuffle. -static Instruction *foldFNegShuffle(ShuffleVectorInst &Shuf, - InstCombiner::BuilderTy &Builder) { - Instruction *FNeg0; +/// Canonicalize FP negate/abs after shuffle. +static Instruction *foldShuffleOfUnaryOps(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + auto *S0 = dyn_cast<Instruction>(Shuf.getOperand(0)); Value *X; - if (!match(Shuf.getOperand(0), m_CombineAnd(m_Instruction(FNeg0), - m_FNeg(m_Value(X))))) + if (!S0 || !match(S0, m_CombineOr(m_FNeg(m_Value(X)), m_FAbs(m_Value(X))))) return nullptr; - // shuffle (fneg X), Mask --> fneg (shuffle X, Mask) - if (FNeg0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { + bool IsFNeg = S0->getOpcode() == Instruction::FNeg; + + // Match 1-input (unary) shuffle. + // shuffle (fneg/fabs X), Mask --> fneg/fabs (shuffle X, Mask) + if (S0->hasOneUse() && match(Shuf.getOperand(1), m_Undef())) { Value *NewShuf = Builder.CreateShuffleVector(X, Shuf.getShuffleMask()); - return UnaryOperator::CreateFNegFMF(NewShuf, FNeg0); + if (IsFNeg) + return UnaryOperator::CreateFNegFMF(NewShuf, S0); + + Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(), + Intrinsic::fabs, Shuf.getType()); + CallInst *NewF = CallInst::Create(FAbs, {NewShuf}); + NewF->setFastMathFlags(S0->getFastMathFlags()); + return NewF; } - Instruction *FNeg1; + // Match 2-input (binary) shuffle. 
+ auto *S1 = dyn_cast<Instruction>(Shuf.getOperand(1)); Value *Y; - if (!match(Shuf.getOperand(1), m_CombineAnd(m_Instruction(FNeg1), - m_FNeg(m_Value(Y))))) + if (!S1 || !match(S1, m_CombineOr(m_FNeg(m_Value(Y)), m_FAbs(m_Value(Y)))) || + S0->getOpcode() != S1->getOpcode() || + (!S0->hasOneUse() && !S1->hasOneUse())) return nullptr; - // shuffle (fneg X), (fneg Y), Mask --> fneg (shuffle X, Y, Mask) - if (FNeg0->hasOneUse() || FNeg1->hasOneUse()) { - Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); - Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewShuf); - NewFNeg->copyIRFlags(FNeg0); - NewFNeg->andIRFlags(FNeg1); - return NewFNeg; + // shuf (fneg/fabs X), (fneg/fabs Y), Mask --> fneg/fabs (shuf X, Y, Mask) + Value *NewShuf = Builder.CreateShuffleVector(X, Y, Shuf.getShuffleMask()); + Instruction *NewF; + if (IsFNeg) { + NewF = UnaryOperator::CreateFNeg(NewShuf); + } else { + Function *FAbs = Intrinsic::getDeclaration(Shuf.getModule(), + Intrinsic::fabs, Shuf.getType()); + NewF = CallInst::Create(FAbs, {NewShuf}); } - - return nullptr; + NewF->copyIRFlags(S0); + NewF->andIRFlags(S1); + return NewF; } /// Canonicalize casts after shuffle. @@ -2533,7 +2567,7 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { for (unsigned i = 0; i != NumElts; ++i) { int ExtractMaskElt = Shuf.getMaskValue(i); int MaskElt = Mask[i]; - NewMask[i] = ExtractMaskElt == UndefMaskElem ? ExtractMaskElt : MaskElt; + NewMask[i] = ExtractMaskElt == PoisonMaskElem ? ExtractMaskElt : MaskElt; } return new ShuffleVectorInst(X, Y, NewMask); } @@ -2699,7 +2733,8 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // splatting the first element of the result of the BinOp Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) { if (!match(SVI.getOperand(1), m_Undef()) || - !match(SVI.getShuffleMask(), m_ZeroMask())) + !match(SVI.getShuffleMask(), m_ZeroMask()) || + !SVI.getOperand(0)->hasOneUse()) return nullptr; Value *Op0 = SVI.getOperand(0); @@ -2759,7 +2794,6 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { } ArrayRef<int> Mask = SVI.getShuffleMask(); - Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); // Peek through a bitcasted shuffle operand by scaling the mask. If the // simulated shuffle can simplify, then this shuffle is unnecessary: @@ -2815,7 +2849,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = narrowVectorSelect(SVI, Builder)) return I; - if (Instruction *I = foldFNegShuffle(SVI, Builder)) + if (Instruction *I = foldShuffleOfUnaryOps(SVI, Builder)) return I; if (Instruction *I = foldCastShuffle(SVI, Builder)) @@ -2840,7 +2874,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return I; if (match(RHS, m_Undef()) && canEvaluateShuffled(LHS, Mask)) { - Value *V = evaluateInDifferentElementOrder(LHS, Mask); + Value *V = evaluateInDifferentElementOrder(LHS, Mask, Builder); return replaceInstUsesWith(SVI, V); } @@ -2916,15 +2950,15 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { unsigned SrcElemsPerTgtElem = TgtElemBitWidth / SrcElemBitWidth; assert(SrcElemsPerTgtElem); BegIdx /= SrcElemsPerTgtElem; - bool BCAlreadyExists = NewBCs.find(CastSrcTy) != NewBCs.end(); + bool BCAlreadyExists = NewBCs.contains(CastSrcTy); auto *NewBC = BCAlreadyExists ? 
NewBCs[CastSrcTy]
                      : Builder.CreateBitCast(V, CastSrcTy, SVI.getName() + ".bc");
    if (!BCAlreadyExists)
      NewBCs[CastSrcTy] = NewBC;
-    auto *Ext = Builder.CreateExtractElement(
-        NewBC, ConstantInt::get(Int32Ty, BegIdx), SVI.getName() + ".extract");
+    auto *Ext = Builder.CreateExtractElement(NewBC, BegIdx,
+                                             SVI.getName() + ".extract");
    // The shufflevector isn't being replaced: the bitcast that used it
    // is. InstCombine will visit the newly-created instructions.
    replaceInstUsesWith(*BC, Ext);
@@ -3042,7 +3076,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
  for (unsigned i = 0; i < VWidth; ++i) {
    int eltMask;
    if (Mask[i] < 0) {
-      // This element is an undef value.
+      // This element is a poison value.
      eltMask = -1;
    } else if (Mask[i] < (int)LHSWidth) {
      // This element is from left hand side vector operand.
@@ -3051,27 +3085,27 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
      // new mask value for the element.
      if (newLHS != LHS) {
        eltMask = LHSMask[Mask[i]];
-        // If the value selected is an undef value, explicitly specify it
+        // If the value selected is a poison value, explicitly specify it
        // with a -1 mask value.
-        if (eltMask >= (int)LHSOp0Width && isa<UndefValue>(LHSOp1))
+        if (eltMask >= (int)LHSOp0Width && isa<PoisonValue>(LHSOp1))
          eltMask = -1;
      } else
        eltMask = Mask[i];
    } else {
      // This element is from right hand side vector operand
      //
-      // If the value selected is an undef value, explicitly specify it
+      // If the value selected is a poison value, explicitly specify it
      // with a -1 mask value. (case 1)
-      if (match(RHS, m_Undef()))
+      if (match(RHS, m_Poison()))
        eltMask = -1;
      // If RHS is going to be replaced (case 3 or 4), calculate the
      // new mask value for the element.
      else if (newRHS != RHS) {
        eltMask = RHSMask[Mask[i]-LHSWidth];
-        // If the value selected is an undef value, explicitly specify it
+        // If the value selected is a poison value, explicitly specify it
        // with a -1 mask value.
        if (eltMask >= (int)RHSOp0Width) {
-          assert(match(RHSShuffle->getOperand(1), m_Undef()) &&
+          assert(match(RHSShuffle->getOperand(1), m_Poison()) &&
                 "should have been checked above");
          eltMask = -1;
        }
@@ -3102,7 +3136,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
      // or is a splat, do the replacement.
if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) { if (!newRHS) - newRHS = UndefValue::get(newLHS->getType()); + newRHS = PoisonValue::get(newLHS->getType()); return new ShuffleVectorInst(newLHS, newRHS, newMask); } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index fb6f4f96ea48..afd6e034f46d 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -33,8 +33,6 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/InstCombine.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -47,7 +45,6 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" -#include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyBlockFrequencyInfo.h" @@ -70,6 +67,7 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/EHPersonalities.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" @@ -78,7 +76,6 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" @@ -117,6 +114,11 @@ using namespace llvm::PatternMatch; STATISTIC(NumWorklistIterations, "Number of instruction combining iterations performed"); +STATISTIC(NumOneIteration, "Number of functions with one iteration"); +STATISTIC(NumTwoIterations, "Number of functions with two iterations"); +STATISTIC(NumThreeIterations, "Number of functions with three iterations"); +STATISTIC(NumFourOrMoreIterations, + "Number of functions with four or more iterations"); STATISTIC(NumCombined , "Number of insts combined"); STATISTIC(NumConstProp, "Number of constant folds"); @@ -129,7 +131,6 @@ DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); // FIXME: these limits eventually should be as low as 2. 
-static constexpr unsigned InstCombineDefaultMaxIterations = 1000;
 #ifndef NDEBUG
 static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100;
 #else
@@ -144,11 +145,6 @@ static cl::opt<unsigned> MaxSinkNumUsers(
     "instcombine-max-sink-users", cl::init(32),
     cl::desc("Maximum number of undroppable users for instruction sinking"));

-static cl::opt<unsigned> LimitMaxIterations(
-    "instcombine-max-iterations",
-    cl::desc("Limit the maximum number of instruction combining iterations"),
-    cl::init(InstCombineDefaultMaxIterations));
-
 static cl::opt<unsigned> InfiniteLoopDetectionThreshold(
     "instcombine-infinite-loop-threshold",
     cl::desc("Number of instruction combining iterations considered an "
@@ -203,6 +199,10 @@ std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
   return std::nullopt;
 }

+bool InstCombiner::isValidAddrSpaceCast(unsigned FromAS, unsigned ToAS) const {
+  return TTI.isValidAddrSpaceCast(FromAS, ToAS);
+}
+
 Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
   return llvm::emitGEPOffset(&Builder, DL, GEP);
 }
@@ -360,13 +360,17 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
   // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
   Type *DestTy = C1->getType();
   Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
-  Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2);
+  Constant *FoldedC =
+      ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout());
+  if (!FoldedC)
+    return false;
+
   IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0));
   IC.replaceOperand(*BinOp1, 1, FoldedC);
   return true;
 }

-// Simplifies IntToPtr/PtrToInt RoundTrip Cast To BitCast.
+// Simplifies IntToPtr/PtrToInt RoundTrip Cast.
 // inttoptr ( ptrtoint (x) ) --> x
 Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
   auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
@@ -378,10 +382,8 @@ Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
         CastTy->getPointerAddressSpace() ==
             PtrToInt->getSrcTy()->getPointerAddressSpace() &&
         DL.getTypeSizeInBits(PtrToInt->getSrcTy()) ==
-            DL.getTypeSizeInBits(PtrToInt->getDestTy())) {
-      return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy,
-                                              "", PtrToInt);
-    }
+            DL.getTypeSizeInBits(PtrToInt->getDestTy()))
+      return PtrToInt->getOperand(0);
   }
   return nullptr;
 }
@@ -732,6 +734,207 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
   return RetVal;
 }

+// (Binop1 (Binop2 (logic_shift X, C), C1), (logic_shift Y, C))
+// IFF
+// 1) the logic_shifts match
+// 2) either BinOp1 is `and`, or the binops distribute over the shift
+//    (anything but `add` + `lshr`) and either BinOp2 is `and` or
+//    (logic_shift (inv_logic_shift C1, C), C) == C1
+//
+// -> (logic_shift (Binop1 (Binop2 X, inv_logic_shift(C1, C)), Y), C)
+//
+// (Binop1 (Binop2 (logic_shift X, Amt), Mask), (logic_shift Y, Amt))
+// IFF
+// 1) the logic_shifts match
+// 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`).
+//
+// -> (BinOp (logic_shift (BinOp X, Y)), Mask)
Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
+  auto IsValidBinOpc = [](unsigned Opc) {
+    switch (Opc) {
+    default:
+      return false;
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor:
+    case Instruction::Add:
+      // Skip Sub as we only match constant masks which will canonicalize to use
+      // add.
+      return true;
+    }
+  };
+
+  // Check if we can distribute binop arbitrarily. `add` + `lshr` has extra
+  // constraints.
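// Worked example of the extra `add` + `lshr` constraint (an editor's
// illustration, not part of the patch): for i8 values X = 0x80, Y = 0x80 and
// shift amount 1,
//   (X >> 1) + (Y >> 1) == 0x40 + 0x40 == 0x80, but
//   (X + Y) >> 1 == 0x00 >> 1 == 0x00,
// because the add's carry out of the top bit is lost before the shift. With
// `shl` the distribution is always safe: (X + Y) << 1 == (X << 1) + (Y << 1)
// modulo 2^n, since bits shifted out on the left cannot affect the kept bits.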
+ auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2, + unsigned ShOpc) { + return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) || + ShOpc == Instruction::Shl; + }; + + auto GetInvShift = [](unsigned ShOpc) { + return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr; + }; + + auto CanDistributeBinops = [&](unsigned BinOpc1, unsigned BinOpc2, + unsigned ShOpc, Constant *CMask, + Constant *CShift) { + // If the BinOp1 is `and` we don't need to check the mask. + if (BinOpc1 == Instruction::And) + return true; + + // For all other possible transfers we need complete distributable + // binop/shift (anything but `add` + `lshr`). + if (!IsCompletelyDistributable(BinOpc1, BinOpc2, ShOpc)) + return false; + + // If BinOp2 is `and`, any mask works (this only really helps for non-splat + // vecs, otherwise the mask will be simplified and the following check will + // handle it). + if (BinOpc2 == Instruction::And) + return true; + + // Otherwise, need mask that meets the below requirement. + // (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask + return ConstantExpr::get( + ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift), + CShift) == CMask; + }; + + auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * { + Constant *CMask, *CShift; + Value *X, *Y, *ShiftedX, *Mask, *Shift; + if (!match(I.getOperand(ShOpnum), + m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift))))) + return nullptr; + if (!match(I.getOperand(1 - ShOpnum), + m_BinOp(m_Value(ShiftedX), m_Value(Mask)))) + return nullptr; + + if (!match(ShiftedX, + m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift))))) + return nullptr; + + // Make sure we are matching instruction shifts and not ConstantExpr + auto *IY = dyn_cast<Instruction>(I.getOperand(ShOpnum)); + auto *IX = dyn_cast<Instruction>(ShiftedX); + if (!IY || !IX) + return nullptr; + + // LHS and RHS need same shift opcode + unsigned ShOpc = IY->getOpcode(); + if (ShOpc != IX->getOpcode()) + return nullptr; + + // Make sure binop is real instruction and not ConstantExpr + auto *BO2 = dyn_cast<Instruction>(I.getOperand(1 - ShOpnum)); + if (!BO2) + return nullptr; + + unsigned BinOpc = BO2->getOpcode(); + // Make sure we have valid binops. + if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc)) + return nullptr; + + // If BinOp1 == BinOp2 and it's bitwise or shl with add, then just + // distribute to drop the shift irrelevant of constants. + if (BinOpc == I.getOpcode() && + IsCompletelyDistributable(I.getOpcode(), BinOpc, ShOpc)) { + Value *NewBinOp2 = Builder.CreateBinOp(I.getOpcode(), X, Y); + Value *NewBinOp1 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp2, Shift); + return BinaryOperator::Create(I.getOpcode(), NewBinOp1, Mask); + } + + // Otherwise we can only distribute by constant shifting the mask, so + // ensure we have constants. + if (!match(Shift, m_ImmConstant(CShift))) + return nullptr; + if (!match(Mask, m_ImmConstant(CMask))) + return nullptr; + + // Check if we can distribute the binops. 
+    if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift))
+      return nullptr;
+
+    Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift);
+    Value *NewBinOp2 = Builder.CreateBinOp(
+        static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask);
+    Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2);
+    return BinaryOperator::Create(static_cast<Instruction::BinaryOps>(ShOpc),
+                                  NewBinOp1, CShift);
+  };
+
+  if (Instruction *R = MatchBinOp(0))
+    return R;
+  return MatchBinOp(1);
+}
+
+// (Binop (zext C), (select C, T, F))
+// -> (select C, (binop 1, T), (binop 0, F))
+//
+// (Binop (sext C), (select C, T, F))
+// -> (select C, (binop -1, T), (binop 0, F))
+//
+// Attempt to simplify binary operations into a select with folded args, when
+// one operand of the binop is a select instruction and the other operand is a
+// zext/sext of the select condition.
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+  // TODO: this simplification may be extended to any speculatable instruction,
+  // not just binops, and would possibly be handled better in FoldOpIntoSelect.
+  Instruction::BinaryOps Opc = I.getOpcode();
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *CondVal, *TrueVal, *FalseVal;
+  Value *CastOp;
+
+  auto MatchSelectAndCast = [&](Value *CastOp, Value *SelectOp) {
+    return match(CastOp, m_ZExtOrSExt(m_Value(A))) &&
+           A->getType()->getScalarSizeInBits() == 1 &&
+           match(SelectOp, m_Select(m_Value(CondVal), m_Value(TrueVal),
+                                    m_Value(FalseVal)));
+  };
+
+  // Make sure one side of the binop is a select instruction, and the other is
+  // a zero/sign extension operating on an i1.
+  if (MatchSelectAndCast(LHS, RHS))
+    CastOp = LHS;
+  else if (MatchSelectAndCast(RHS, LHS))
+    CastOp = RHS;
+  else
+    return nullptr;
+
+  auto NewFoldedConst = [&](bool IsTrueArm, Value *V) {
+    bool IsCastOpRHS = (CastOp == RHS);
+    bool IsZExt = isa<ZExtInst>(CastOp);
+    Constant *C;
+
+    if (IsTrueArm) {
+      C = Constant::getNullValue(V->getType());
+    } else if (IsZExt) {
+      unsigned BitWidth = V->getType()->getScalarSizeInBits();
+      C = Constant::getIntegerValue(V->getType(), APInt(BitWidth, 1));
+    } else {
+      C = Constant::getAllOnesValue(V->getType());
+    }
+
+    return IsCastOpRHS ? Builder.CreateBinOp(Opc, V, C)
+                       : Builder.CreateBinOp(Opc, C, V);
+  };
+
+  // If the value used in the zext/sext is the select condition, or the
+  // negation of the select condition, the binop can be simplified.
+  if (CondVal == A)
+    return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal),
+                              NewFoldedConst(true, FalseVal));
+
+  if (match(A, m_Not(m_Specific(CondVal))))
+    return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal),
+                              NewFoldedConst(false, FalseVal));
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
@@ -948,6 +1151,7 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
 /// Freely adapt every user of V as if V was changed to !V.
 /// WARNING: only if canFreelyInvertAllUsersOf() said this can be done.
 void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) {
+  assert(!isa<Constant>(I) && "Shouldn't invert users of constant");
   for (User *U : make_early_inc_range(I->users())) {
     if (U == IgnoredUser)
       continue; // Don't consider this user.
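A minimal IR sketch of the new select/cast fold (function and value names are mine; this assumes an `add` of this shape reaches foldBinOpOfSelectAndCastOfSelectCondition):

define i32 @zext_plus_select(i1 %c, i32 %t, i32 %f) {
  %ext = zext i1 %c to i32
  %sel = select i1 %c, i32 %t, i32 %f
  %r = add i32 %ext, %sel
  ret i32 %r
}

In the true arm the zext contributes 1 and in the false arm it contributes 0, so the add should fold into the arms as select i1 %c, i32 (add %t, 1), i32 %f, with the add of 0 simplified away.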
@@ -1033,63 +1237,39 @@ Instruction *InstCombinerImpl::foldBinopOfSextBoolToSelect(BinaryOperator &BO) { return SelectInst::Create(X, TVal, FVal); } -static Constant *constantFoldOperationIntoSelectOperand( - Instruction &I, SelectInst *SI, Value *SO) { - auto *ConstSO = dyn_cast<Constant>(SO); - if (!ConstSO) - return nullptr; - +static Constant *constantFoldOperationIntoSelectOperand(Instruction &I, + SelectInst *SI, + bool IsTrueArm) { SmallVector<Constant *> ConstOps; for (Value *Op : I.operands()) { - if (Op == SI) - ConstOps.push_back(ConstSO); - else if (auto *C = dyn_cast<Constant>(Op)) - ConstOps.push_back(C); - else - llvm_unreachable("Operands should be select or constant"); - } - return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout()); -} + CmpInst::Predicate Pred; + Constant *C = nullptr; + if (Op == SI) { + C = dyn_cast<Constant>(IsTrueArm ? SI->getTrueValue() + : SI->getFalseValue()); + } else if (match(SI->getCondition(), + m_ICmp(Pred, m_Specific(Op), m_Constant(C))) && + Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) && + isGuaranteedNotToBeUndefOrPoison(C)) { + // Pass + } else { + C = dyn_cast<Constant>(Op); + } + if (C == nullptr) + return nullptr; -static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, - InstCombiner::BuilderTy &Builder) { - if (auto *Cast = dyn_cast<CastInst>(&I)) - return Builder.CreateCast(Cast->getOpcode(), SO, I.getType()); - - if (auto *II = dyn_cast<IntrinsicInst>(&I)) { - assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) && - "Expected constant-foldable intrinsic"); - Intrinsic::ID IID = II->getIntrinsicID(); - if (II->arg_size() == 1) - return Builder.CreateUnaryIntrinsic(IID, SO); - - // This works for real binary ops like min/max (where we always expect the - // constant operand to be canonicalized as op1) and unary ops with a bonus - // constant argument like ctlz/cttz. - // TODO: Handle non-commutative binary intrinsics as below for binops. - assert(II->arg_size() == 2 && "Expected binary intrinsic"); - assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand"); - return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); + ConstOps.push_back(C); } - if (auto *EI = dyn_cast<ExtractElementInst>(&I)) - return Builder.CreateExtractElement(SO, EI->getIndexOperand()); - - assert(I.isBinaryOp() && "Unexpected opcode for select folding"); - - // Figure out if the constant is the left or the right argument. - bool ConstIsRHS = isa<Constant>(I.getOperand(1)); - Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); - - Value *Op0 = SO, *Op1 = ConstOperand; - if (!ConstIsRHS) - std::swap(Op0, Op1); + return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout()); +} - Value *NewBO = Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), Op0, - Op1, SO->getName() + ".op"); - if (auto *NewBOI = dyn_cast<Instruction>(NewBO)) - NewBOI->copyIRFlags(&I); - return NewBO; +static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, + Value *NewOp, InstCombiner &IC) { + Instruction *Clone = I.clone(); + Clone->replaceUsesOfWith(SI, NewOp); + IC.InsertNewInstBefore(Clone, *SI); + return Clone; } Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, @@ -1122,56 +1302,17 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return nullptr; } - // Test if a CmpInst instruction is used exclusively by a select as - // part of a minimum or maximum operation. 
If so, refrain from doing - // any other folding. This helps out other analyses which understand - // non-obfuscated minimum and maximum idioms, such as ScalarEvolution - // and CodeGen. And in this case, at least one of the comparison - // operands has at least one user besides the compare (the select), - // which would often largely negate the benefit of folding anyway. - if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { - if (CI->hasOneUse()) { - Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); - - // FIXME: This is a hack to avoid infinite looping with min/max patterns. - // We have to ensure that vector constants that only differ with - // undef elements are treated as equivalent. - auto areLooselyEqual = [](Value *A, Value *B) { - if (A == B) - return true; - - // Test for vector constants. - Constant *ConstA, *ConstB; - if (!match(A, m_Constant(ConstA)) || !match(B, m_Constant(ConstB))) - return false; - - // TODO: Deal with FP constants? - if (!A->getType()->isIntOrIntVectorTy() || A->getType() != B->getType()) - return false; - - // Compare for equality including undefs as equal. - auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB); - const APInt *C; - return match(Cmp, m_APIntAllowUndef(C)) && C->isOne(); - }; - - if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) || - (areLooselyEqual(FV, Op0) && areLooselyEqual(TV, Op1))) - return nullptr; - } - } - // Make sure that one of the select arms constant folds successfully. - Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV); - Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV); + Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true); + Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false); if (!NewTV && !NewFV) return nullptr; // Create an instruction for the arm that did not fold. if (!NewTV) - NewTV = foldOperationIntoSelectOperand(Op, TV, Builder); + NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this); if (!NewFV) - NewFV = foldOperationIntoSelectOperand(Op, FV, Builder); + NewFV = foldOperationIntoSelectOperand(Op, SI, FV, *this); return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } @@ -1263,6 +1404,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues()); InsertNewInstBefore(NewPN, *PN); NewPN->takeName(PN); + NewPN->setDebugLoc(PN->getDebugLoc()); // If we are going to have to insert a new computation, do so right before the // predecessor's terminator. @@ -1291,6 +1433,10 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { replaceInstUsesWith(*User, NewPN); eraseInstFromFunction(*User); } + + replaceAllDbgUsesWith(const_cast<PHINode &>(*PN), + const_cast<PHINode &>(*NewPN), + const_cast<PHINode &>(*PN), DT); return replaceInstUsesWith(I, NewPN); } @@ -1301,7 +1447,7 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { auto *Phi0 = dyn_cast<PHINode>(BO.getOperand(0)); auto *Phi1 = dyn_cast<PHINode>(BO.getOperand(1)); if (!Phi0 || !Phi1 || !Phi0->hasOneUse() || !Phi1->hasOneUse() || - Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2) + Phi0->getNumOperands() != Phi1->getNumOperands()) return nullptr; // TODO: Remove the restriction for binop being in the same block as the phis. 
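The reworked arm folding can now substitute the constant from the select condition's equality compare into an arm. A hedged IR sketch (names are mine; this assumes FoldOpIntoSelect is reached for a binop whose other operand is the compared value): because the condition is icmp eq i8 %x, 5, the true arm may treat %x as 5, so both operands constant-fold there:

define i8 @mul_select(i8 %x) {
  %cmp = icmp eq i8 %x, 5
  %sel = select i1 %cmp, i8 2, i8 1
  %r = mul i8 %sel, %x
  ret i8 %r
}

The true arm folds to mul 2, 5 == 10; the false arm does not fold to a constant, so the mul is cloned with %sel replaced by 1 (which later simplifies to %x), giving select i1 %cmp, i8 10, i8 %x.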
@@ -1309,6 +1455,51 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
       BO.getParent() != Phi1->getParent())
     return nullptr;

+  // Fold the binop if, for each pair of incoming values from matching blocks,
+  // one of the two values is the identity constant of the binary operator;
+  // each per-block binop is then a no-op, so the whole binop reduces to a phi
+  // of the other incoming values.
+  // For example:
+  // %phi0 = phi i32 [0, %bb0], [%i, %bb1]
+  // %phi1 = phi i32 [%j, %bb0], [0, %bb1]
+  // %add = add i32 %phi0, %phi1
+  // ==>
+  // %add = phi i32 [%j, %bb0], [%i, %bb1]
+  Constant *C = ConstantExpr::getBinOpIdentity(BO.getOpcode(), BO.getType(),
+                                               /*AllowRHSConstant*/ false);
+  if (C) {
+    SmallVector<Value *, 4> NewIncomingValues;
+    auto CanFoldIncomingValuePair = [&](std::tuple<Use &, Use &> T) {
+      auto &Phi0Use = std::get<0>(T);
+      auto &Phi1Use = std::get<1>(T);
+      if (Phi0->getIncomingBlock(Phi0Use) != Phi1->getIncomingBlock(Phi1Use))
+        return false;
+      Value *Phi0UseV = Phi0Use.get();
+      Value *Phi1UseV = Phi1Use.get();
+      if (Phi0UseV == C)
+        NewIncomingValues.push_back(Phi1UseV);
+      else if (Phi1UseV == C)
+        NewIncomingValues.push_back(Phi0UseV);
+      else
+        return false;
+      return true;
+    };
+
+    if (all_of(zip(Phi0->operands(), Phi1->operands()),
+               CanFoldIncomingValuePair)) {
+      PHINode *NewPhi =
+          PHINode::Create(Phi0->getType(), Phi0->getNumOperands());
+      assert(NewIncomingValues.size() == Phi0->getNumOperands() &&
+             "The number of collected incoming values should equal the number "
+             "of the original PHINode operands!");
+      for (unsigned I = 0; I < Phi0->getNumOperands(); I++)
+        NewPhi->addIncoming(NewIncomingValues[I], Phi0->getIncomingBlock(I));
+      return NewPhi;
+    }
+  }
+
+  if (Phi0->getNumOperands() != 2 || Phi1->getNumOperands() != 2)
+    return nullptr;
+
   // Match a pair of incoming constants for one of the predecessor blocks.
   BasicBlock *ConstBB, *OtherBB;
   Constant *C0, *C1;
@@ -1374,28 +1565,6 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
   return nullptr;
 }

-/// Given a pointer type and a constant offset, determine whether or not there
-/// is a sequence of GEP indices into the pointed type that will land us at the
-/// specified offset. If so, fill them into NewIndices and return the resultant
-/// element type, otherwise return null.
-static Type *findElementAtOffset(PointerType *PtrTy, int64_t IntOffset,
-                                 SmallVectorImpl<Value *> &NewIndices,
-                                 const DataLayout &DL) {
-  // Only used by visitGEPOfBitcast(), which is skipped for opaque pointers.
-  Type *Ty = PtrTy->getNonOpaquePointerElementType();
-  if (!Ty->isSized())
-    return nullptr;
-
-  APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset);
-  SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset);
-  if (!Offset.isZero())
-    return nullptr;
-
-  for (const APInt &Index : Indices)
-    NewIndices.push_back(ConstantInt::get(PtrTy->getContext(), Index));
-  return Ty;
-}
-
 static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   // If this GEP has only 0 indices, it is the same pointer as
   // Src. If Src is not a trivial GEP too, don't combine
@@ -1406,248 +1575,6 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
   return true;
 }

-/// Return a value X such that Val = X * Scale, or null if none.
-/// If the multiplication is known not to overflow, then NoSignedWrap is set.
-Value *InstCombinerImpl::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { - assert(isa<IntegerType>(Val->getType()) && "Can only descale integers!"); - assert(cast<IntegerType>(Val->getType())->getBitWidth() == - Scale.getBitWidth() && "Scale not compatible with value!"); - - // If Val is zero or Scale is one then Val = Val * Scale. - if (match(Val, m_Zero()) || Scale == 1) { - NoSignedWrap = true; - return Val; - } - - // If Scale is zero then it does not divide Val. - if (Scale.isMinValue()) - return nullptr; - - // Look through chains of multiplications, searching for a constant that is - // divisible by Scale. For example, descaling X*(Y*(Z*4)) by a factor of 4 - // will find the constant factor 4 and produce X*(Y*Z). Descaling X*(Y*8) by - // a factor of 4 will produce X*(Y*2). The principle of operation is to bore - // down from Val: - // - // Val = M1 * X || Analysis starts here and works down - // M1 = M2 * Y || Doesn't descend into terms with more - // M2 = Z * 4 \/ than one use - // - // Then to modify a term at the bottom: - // - // Val = M1 * X - // M1 = Z * Y || Replaced M2 with Z - // - // Then to work back up correcting nsw flags. - - // Op - the term we are currently analyzing. Starts at Val then drills down. - // Replaced with its descaled value before exiting from the drill down loop. - Value *Op = Val; - - // Parent - initially null, but after drilling down notes where Op came from. - // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the - // 0'th operand of Val. - std::pair<Instruction *, unsigned> Parent; - - // Set if the transform requires a descaling at deeper levels that doesn't - // overflow. - bool RequireNoSignedWrap = false; - - // Log base 2 of the scale. Negative if not a power of 2. - int32_t logScale = Scale.exactLogBase2(); - - for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - // If Op is a constant divisible by Scale then descale to the quotient. - APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth. - APInt::sdivrem(CI->getValue(), Scale, Quotient, Remainder); - if (!Remainder.isMinValue()) - // Not divisible by Scale. - return nullptr; - // Replace with the quotient in the parent. - Op = ConstantInt::get(CI->getType(), Quotient); - NoSignedWrap = true; - break; - } - - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) { - if (BO->getOpcode() == Instruction::Mul) { - // Multiplication. - NoSignedWrap = BO->hasNoSignedWrap(); - if (RequireNoSignedWrap && !NoSignedWrap) - return nullptr; - - // There are three cases for multiplication: multiplication by exactly - // the scale, multiplication by a constant different to the scale, and - // multiplication by something else. - Value *LHS = BO->getOperand(0); - Value *RHS = BO->getOperand(1); - - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // Multiplication by a constant. - if (CI->getValue() == Scale) { - // Multiplication by exactly the scale, replace the multiplication - // by its left-hand side in the parent. - Op = LHS; - break; - } - - // Otherwise drill down into the constant. - if (!Op->hasOneUse()) - return nullptr; - - Parent = std::make_pair(BO, 1); - continue; - } - - // Multiplication by something else. Drill down into the left-hand side - // since that's where the reassociate pass puts the good stuff. 
- if (!Op->hasOneUse()) - return nullptr; - - Parent = std::make_pair(BO, 0); - continue; - } - - if (logScale > 0 && BO->getOpcode() == Instruction::Shl && - isa<ConstantInt>(BO->getOperand(1))) { - // Multiplication by a power of 2. - NoSignedWrap = BO->hasNoSignedWrap(); - if (RequireNoSignedWrap && !NoSignedWrap) - return nullptr; - - Value *LHS = BO->getOperand(0); - int32_t Amt = cast<ConstantInt>(BO->getOperand(1))-> - getLimitedValue(Scale.getBitWidth()); - // Op = LHS << Amt. - - if (Amt == logScale) { - // Multiplication by exactly the scale, replace the multiplication - // by its left-hand side in the parent. - Op = LHS; - break; - } - if (Amt < logScale || !Op->hasOneUse()) - return nullptr; - - // Multiplication by more than the scale. Reduce the multiplying amount - // by the scale in the parent. - Parent = std::make_pair(BO, 1); - Op = ConstantInt::get(BO->getType(), Amt - logScale); - break; - } - } - - if (!Op->hasOneUse()) - return nullptr; - - if (CastInst *Cast = dyn_cast<CastInst>(Op)) { - if (Cast->getOpcode() == Instruction::SExt) { - // Op is sign-extended from a smaller type, descale in the smaller type. - unsigned SmallSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); - APInt SmallScale = Scale.trunc(SmallSize); - // Suppose Op = sext X, and we descale X as Y * SmallScale. We want to - // descale Op as (sext Y) * Scale. In order to have - // sext (Y * SmallScale) = (sext Y) * Scale - // some conditions need to hold however: SmallScale must sign-extend to - // Scale and the multiplication Y * SmallScale should not overflow. - if (SmallScale.sext(Scale.getBitWidth()) != Scale) - // SmallScale does not sign-extend to Scale. - return nullptr; - assert(SmallScale.exactLogBase2() == logScale); - // Require that Y * SmallScale must not overflow. - RequireNoSignedWrap = true; - - // Drill down through the cast. - Parent = std::make_pair(Cast, 0); - Scale = SmallScale; - continue; - } - - if (Cast->getOpcode() == Instruction::Trunc) { - // Op is truncated from a larger type, descale in the larger type. - // Suppose Op = trunc X, and we descale X as Y * sext Scale. Then - // trunc (Y * sext Scale) = (trunc Y) * Scale - // always holds. However (trunc Y) * Scale may overflow even if - // trunc (Y * sext Scale) does not, so nsw flags need to be cleared - // from this point up in the expression (see later). - if (RequireNoSignedWrap) - return nullptr; - - // Drill down through the cast. - unsigned LargeSize = Cast->getSrcTy()->getPrimitiveSizeInBits(); - Parent = std::make_pair(Cast, 0); - Scale = Scale.sext(LargeSize); - if (logScale + 1 == (int32_t)Cast->getType()->getPrimitiveSizeInBits()) - logScale = -1; - assert(Scale.exactLogBase2() == logScale); - continue; - } - } - - // Unsupported expression, bail out. - return nullptr; - } - - // If Op is zero then Val = Op * Scale. - if (match(Op, m_Zero())) { - NoSignedWrap = true; - return Op; - } - - // We know that we can successfully descale, so from here on we can safely - // modify the IR. Op holds the descaled version of the deepest term in the - // expression. NoSignedWrap is 'true' if multiplying Op by Scale is known - // not to overflow. - - if (!Parent.first) - // The expression only had one term. - return Op; - - // Rewrite the parent using the descaled version of its operand. 
- assert(Parent.first->hasOneUse() && "Drilled down when more than one use!"); - assert(Op != Parent.first->getOperand(Parent.second) && - "Descaling was a no-op?"); - replaceOperand(*Parent.first, Parent.second, Op); - Worklist.push(Parent.first); - - // Now work back up the expression correcting nsw flags. The logic is based - // on the following observation: if X * Y is known not to overflow as a signed - // multiplication, and Y is replaced by a value Z with smaller absolute value, - // then X * Z will not overflow as a signed multiplication either. As we work - // our way up, having NoSignedWrap 'true' means that the descaled value at the - // current level has strictly smaller absolute value than the original. - Instruction *Ancestor = Parent.first; - do { - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Ancestor)) { - // If the multiplication wasn't nsw then we can't say anything about the - // value of the descaled multiplication, and we have to clear nsw flags - // from this point on up. - bool OpNoSignedWrap = BO->hasNoSignedWrap(); - NoSignedWrap &= OpNoSignedWrap; - if (NoSignedWrap != OpNoSignedWrap) { - BO->setHasNoSignedWrap(NoSignedWrap); - Worklist.push(Ancestor); - } - } else if (Ancestor->getOpcode() == Instruction::Trunc) { - // The fact that the descaled input to the trunc has smaller absolute - // value than the original input doesn't tell us anything useful about - // the absolute values of the truncations. - NoSignedWrap = false; - } - assert((Ancestor->getOpcode() != Instruction::SExt || NoSignedWrap) && - "Failed to keep proper track of nsw flags while drilling down?"); - - if (Ancestor == Val) - // Got to the top, all done! - return Val; - - // Move up one level in the expression. - assert(Ancestor->hasOneUse() && "Drilled down when more than one use!"); - Ancestor = Ancestor->user_back(); - } while (true); -} - Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { if (!isa<VectorType>(Inst.getType())) return nullptr; @@ -1748,9 +1675,9 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // TODO: Allow arbitrary shuffles by shuffling after binop? // That might be legal, but we have to deal with poison. if (LShuf->isSelect() && - !is_contained(LShuf->getShuffleMask(), UndefMaskElem) && + !is_contained(LShuf->getShuffleMask(), PoisonMaskElem) && RShuf->isSelect() && - !is_contained(RShuf->getShuffleMask(), UndefMaskElem)) { + !is_contained(RShuf->getShuffleMask(), PoisonMaskElem)) { // Example: // LHS = shuffle V1, V2, <0, 5, 6, 3> // RHS = shuffle V2, V1, <0, 5, 6, 3> @@ -1991,50 +1918,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; - if (Src->getResultElementType() == GEP.getSourceElementType() && - Src->getNumOperands() == 2 && GEP.getNumOperands() == 2 && - Src->hasOneUse()) { - Value *GO1 = GEP.getOperand(1); - Value *SO1 = Src->getOperand(1); - - if (LI) { - // Try to reassociate loop invariant GEP chains to enable LICM. - if (Loop *L = LI->getLoopFor(GEP.getParent())) { - // Reassociate the two GEPs if SO1 is variant in the loop and GO1 is - // invariant: this breaks the dependence between GEPs and allows LICM - // to hoist the invariant part out of the loop. - if (L->isLoopInvariant(GO1) && !L->isLoopInvariant(SO1)) { - // The swapped GEPs are inbounds if both original GEPs are inbounds - // and the sign of the offsets is the same. For simplicity, only - // handle both offsets being non-negative. 
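In IR terms, the loop-invariant GEP reassociation removed here looked like the following, with %inv loop-invariant and %iv the induction variable (hypothetical loop):

    loop:
      %iv  = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
      %src = getelementptr i32, ptr %base, i64 %iv    ; loop-variant offset
      %gep = getelementptr i32, ptr %src, i64 %inv    ; loop-invariant offset
      ; reassociated so LICM can later hoist the invariant GEP:
      %src.inv = getelementptr i32, ptr %base, i64 %inv
      %gep     = getelementptr i32, ptr %src.inv, i64 %iv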
- bool IsInBounds = Src->isInBounds() && GEP.isInBounds() && - isKnownNonNegative(SO1, DL, 0, &AC, &GEP, &DT) && - isKnownNonNegative(GO1, DL, 0, &AC, &GEP, &DT); - // Put NewSrc at same location as %src. - Builder.SetInsertPoint(cast<Instruction>(Src)); - Value *NewSrc = Builder.CreateGEP(GEP.getSourceElementType(), - Src->getPointerOperand(), GO1, - Src->getName(), IsInBounds); - GetElementPtrInst *NewGEP = GetElementPtrInst::Create( - GEP.getSourceElementType(), NewSrc, {SO1}); - NewGEP->setIsInBounds(IsInBounds); - return NewGEP; - } - } - } - } - - // Note that if our source is a gep chain itself then we wait for that - // chain to be resolved before we perform this transformation. This - // avoids us creating a TON of code in some cases. - if (auto *SrcGEP = dyn_cast<GEPOperator>(Src->getOperand(0))) - if (SrcGEP->getNumOperands() == 2 && shouldMergeGEPs(*Src, *SrcGEP)) - return nullptr; // Wait until our source is folded to completion. - // For constant GEPs, use a more general offset-based folding approach. - // Only do this for opaque pointers, as the result element type may change. Type *PtrTy = Src->getType()->getScalarType(); - if (PtrTy->isOpaquePointerTy() && GEP.hasAllConstantIndices() && + if (GEP.hasAllConstantIndices() && (Src->hasOneUse() || Src->hasAllConstantIndices())) { // Split Src into a variable part and a constant suffix. gep_type_iterator GTI = gep_type_begin(*Src); @@ -2077,13 +1963,11 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, // If both GEP are constant-indexed, and cannot be merged in either way, // convert them to a GEP of i8. if (Src->hasAllConstantIndices()) - return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) - ? GetElementPtrInst::CreateInBounds( - Builder.getInt8Ty(), Src->getOperand(0), - Builder.getInt(OffsetOld), GEP.getName()) - : GetElementPtrInst::Create( - Builder.getInt8Ty(), Src->getOperand(0), - Builder.getInt(OffsetOld), GEP.getName()); + return replaceInstUsesWith( + GEP, Builder.CreateGEP( + Builder.getInt8Ty(), Src->getOperand(0), + Builder.getInt(OffsetOld), "", + isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)))); return nullptr; } @@ -2100,13 +1984,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, IsInBounds &= Idx.isNonNegative() == ConstIndices[0].isNonNegative(); } - return IsInBounds - ? GetElementPtrInst::CreateInBounds(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()) - : GetElementPtrInst::Create(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()); + return replaceInstUsesWith( + GEP, Builder.CreateGEP(Src->getSourceElementType(), Src->getOperand(0), + Indices, "", IsInBounds)); } if (Src->getResultElementType() != GEP.getSourceElementType()) @@ -2160,118 +2040,10 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, } if (!Indices.empty()) - return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) - ? GetElementPtrInst::CreateInBounds( - Src->getSourceElementType(), Src->getOperand(0), Indices, - GEP.getName()) - : GetElementPtrInst::Create(Src->getSourceElementType(), - Src->getOperand(0), Indices, - GEP.getName()); - - return nullptr; -} - -// Note that we may have also stripped an address space cast in between. -Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, - GetElementPtrInst &GEP) { - // With opaque pointers, there is no pointer element type we can use to - // adjust the GEP type. 
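A small IR example of the byte-offset fallback in this hunk (hypothetical types and offsets): when two constant-indexed GEPs cannot be merged index-by-index, they collapse into a single i8 GEP with the summed offset:

    %src = getelementptr inbounds { i32, i32 }, ptr %base, i64 0, i32 1  ; +4 bytes
    %gep = getelementptr inbounds i16, ptr %src, i64 1                   ; +2 bytes
    ; ==> offset 6 is not a field offset of { i32, i32 }, so emit one i8 GEP:
    %gep = getelementptr inbounds i8, ptr %base, i64 6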
- PointerType *SrcType = cast<PointerType>(BCI->getSrcTy()); - if (SrcType->isOpaque()) - return nullptr; - - Type *GEPEltType = GEP.getSourceElementType(); - Type *SrcEltType = SrcType->getNonOpaquePointerElementType(); - Value *SrcOp = BCI->getOperand(0); - - // GEP directly using the source operand if this GEP is accessing an element - // of a bitcasted pointer to vector or array of the same dimensions: - // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z - // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z - auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, - const DataLayout &DL) { - auto *VecVTy = cast<FixedVectorType>(VecTy); - return ArrTy->getArrayElementType() == VecVTy->getElementType() && - ArrTy->getArrayNumElements() == VecVTy->getNumElements() && - DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); - }; - if (GEP.getNumOperands() == 3 && - ((GEPEltType->isArrayTy() && isa<FixedVectorType>(SrcEltType) && - areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) || - (isa<FixedVectorType>(GEPEltType) && SrcEltType->isArrayTy() && - areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) { - - // Create a new GEP here, as using `setOperand()` followed by - // `setSourceElementType()` won't actually update the type of the - // existing GEP Value. Causing issues if this Value is accessed when - // constructing an AddrSpaceCastInst - SmallVector<Value *, 8> Indices(GEP.indices()); - Value *NGEP = - Builder.CreateGEP(SrcEltType, SrcOp, Indices, "", GEP.isInBounds()); - NGEP->takeName(&GEP); - - // Preserve GEP address space to satisfy users - if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - - return replaceInstUsesWith(GEP, NGEP); - } - - // See if we can simplify: - // X = bitcast A* to B* - // Y = gep X, <...constant indices...> - // into a gep of the original struct. This is important for SROA and alias - // analysis of unions. If "A" is also a bitcast, wait for A/X to be merged. - unsigned OffsetBits = DL.getIndexTypeSizeInBits(GEP.getType()); - APInt Offset(OffsetBits, 0); - - // If the bitcast argument is an allocation, The bitcast is for convertion - // to actual type of allocation. Removing such bitcasts, results in having - // GEPs with i8* base and pure byte offsets. That means GEP is not aware of - // struct or array hierarchy. - // By avoiding such GEPs, phi translation and MemoryDependencyAnalysis have - // a better chance to succeed. - if (!isa<BitCastInst>(SrcOp) && GEP.accumulateConstantOffset(DL, Offset) && - !isAllocationFn(SrcOp, &TLI)) { - // If this GEP instruction doesn't move the pointer, just replace the GEP - // with a bitcast of the real input to the dest type. - if (!Offset) { - // If the bitcast is of an allocation, and the allocation will be - // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(SrcOp)) { - // See if the bitcast simplifies, if so, don't nuke this GEP yet. - if (Instruction *I = visitBitCast(*BCI)) { - if (I != BCI) { - I->takeName(BCI); - I->insertInto(BCI->getParent(), BCI->getIterator()); - replaceInstUsesWith(*BCI, I); - } - return &GEP; - } - } - - if (SrcType->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(SrcOp, GEP.getType()); - return new BitCastInst(SrcOp, GEP.getType()); - } - - // Otherwise, if the offset is non-zero, we need to find out if there is a - // field at Offset in 'A's type. If so, we can pull the cast through the - // GEP. 
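Spelled out in legacy typed-pointer IR, the fold this removed code performed (hypothetical struct and names): a constant-offset GEP on a bitcast pointer is pulled through the cast when the offset lands on a field:

    %x = bitcast %struct.S* %p to i8*
    %y = getelementptr i8, i8* %x, i64 4
    ; with %struct.S = type { i32, i32 }, offset 4 lands on field 1, so:
    %y.gep = getelementptr %struct.S, %struct.S* %p, i64 0, i32 1
    %y     = bitcast i32* %y.gep to i8*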
- SmallVector<Value *, 8> NewIndices; - if (findElementAtOffset(SrcType, Offset.getSExtValue(), NewIndices, DL)) { - Value *NGEP = Builder.CreateGEP(SrcEltType, SrcOp, NewIndices, "", - GEP.isInBounds()); - - if (NGEP->getType() == GEP.getType()) - return replaceInstUsesWith(GEP, NGEP); - NGEP->takeName(&GEP); - - if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) - return new AddrSpaceCastInst(NGEP, GEP.getType()); - return new BitCastInst(NGEP, GEP.getType()); - } - } + return replaceInstUsesWith( + GEP, Builder.CreateGEP( + Src->getSourceElementType(), Src->getOperand(0), Indices, "", + isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)))); return nullptr; } @@ -2497,192 +2269,6 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPType->isVectorTy()) return nullptr; - // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). - Value *StrippedPtr = PtrOp->stripPointerCasts(); - PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); - - // TODO: The basic approach of these folds is not compatible with opaque - // pointers, because we can't use bitcasts as a hint for a desirable GEP - // type. Instead, we should perform canonicalization directly on the GEP - // type. For now, skip these. - if (StrippedPtr != PtrOp && !StrippedPtrTy->isOpaque()) { - bool HasZeroPointerIndex = false; - Type *StrippedPtrEltTy = StrippedPtrTy->getNonOpaquePointerElementType(); - - if (auto *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) - HasZeroPointerIndex = C->isZero(); - - // Transform: GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... - // into : GEP [10 x i8]* X, i32 0, ... - // - // Likewise, transform: GEP (bitcast i8* X to [0 x i8]*), i32 0, ... - // into : GEP i8* X, ... - // - // This occurs when the program declares an array extern like "int X[];" - if (HasZeroPointerIndex) { - if (auto *CATy = dyn_cast<ArrayType>(GEPEltType)) { - // GEP (bitcast i8* X to [0 x i8]*), i32 0, ... ? - if (CATy->getElementType() == StrippedPtrEltTy) { - // -> GEP i8* X, ... - SmallVector<Value *, 8> Idx(drop_begin(GEP.indices())); - GetElementPtrInst *Res = GetElementPtrInst::Create( - StrippedPtrEltTy, StrippedPtr, Idx, GEP.getName()); - Res->setIsInBounds(GEP.isInBounds()); - if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) - return Res; - // Insert Res, and create an addrspacecast. - // e.g., - // GEP (addrspacecast i8 addrspace(1)* X to [0 x i8]*), i32 0, ... - // -> - // %0 = GEP i8 addrspace(1)* X, ... - // addrspacecast i8 addrspace(1)* %0 to i8* - return new AddrSpaceCastInst(Builder.Insert(Res), GEPType); - } - - if (auto *XATy = dyn_cast<ArrayType>(StrippedPtrEltTy)) { - // GEP (bitcast [10 x i8]* X to [0 x i8]*), i32 0, ... ? - if (CATy->getElementType() == XATy->getElementType()) { - // -> GEP [10 x i8]* X, i32 0, ... - // At this point, we know that the cast source type is a pointer - // to an array of the same type as the destination pointer - // array. Because the array type is never stepped over (there - // is a leading zero) we can fold the cast into this GEP. - if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) { - GEP.setSourceElementType(XATy); - return replaceOperand(GEP, 0, StrippedPtr); - } - // Cannot replace the base pointer directly because StrippedPtr's - // address space is different. Instead, create a new GEP followed by - // an addrspacecast. - // e.g., - // GEP (addrspacecast [10 x i8] addrspace(1)* X to [0 x i8]*), - // i32 0, ... - // -> - // %0 = GEP [10 x i8] addrspace(1)* X, ... 
- // addrspacecast i8 addrspace(1)* %0 to i8* - SmallVector<Value *, 8> Idx(GEP.indices()); - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName(), GEP.isInBounds()); - return new AddrSpaceCastInst(NewGEP, GEPType); - } - } - } - } else if (GEP.getNumOperands() == 2 && !IsGEPSrcEleScalable) { - // Skip if GEP source element type is scalable. The type alloc size is - // unknown at compile-time. - // Transform things like: %t = getelementptr i32* - // bitcast ([2 x i32]* %str to i32*), i32 %V into: %t1 = getelementptr [2 - // x i32]* %str, i32 0, i32 %V; bitcast - if (StrippedPtrEltTy->isArrayTy() && - DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == - DL.getTypeAllocSize(GEPEltType)) { - Type *IdxType = DL.getIndexType(GEPType); - Value *Idx[2] = {Constant::getNullValue(IdxType), GEP.getOperand(1)}; - Value *NewGEP = Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Idx, - GEP.getName(), GEP.isInBounds()); - - // V and GEP are both pointer types --> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, GEPType); - } - - // Transform things like: - // %V = mul i64 %N, 4 - // %t = getelementptr i8* bitcast (i32* %arr to i8*), i32 %V - // into: %t1 = getelementptr i32* %arr, i32 %N; bitcast - if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { - // Check that changing the type amounts to dividing the index by a scale - // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); - uint64_t SrcSize = - DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue(); - if (ResSize && SrcSize % ResSize == 0) { - Value *Idx = GEP.getOperand(1); - unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); - uint64_t Scale = SrcSize / ResSize; - - // Earlier transforms ensure that the index has the right type - // according to Data Layout, which considerably simplifies the - // logic by eliminating implicit casts. - assert(Idx->getType() == DL.getIndexType(GEPType) && - "Index type does not match the Data Layout preferences"); - - bool NSW; - if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { - // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. - // If the multiplication NewIdx * Scale may overflow then the new - // GEP may not be "inbounds". - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, NewIdx, - GEP.getName(), GEP.isInBounds() && NSW); - - // The NewGEP must be pointer typed, so must the old one -> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEPType); - } - } - } - - // Similarly, transform things like: - // getelementptr i8* bitcast ([100 x double]* X to i8*), i32 %tmp - // (where tmp = 8*tmp2) into: - // getelementptr [100 x double]* %arr, i32 0, i32 %tmp2; bitcast - if (GEPEltType->isSized() && StrippedPtrEltTy->isSized() && - StrippedPtrEltTy->isArrayTy()) { - // Check that changing to the array element type amounts to dividing the - // index by a scale factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); - uint64_t ArrayEltSize = - DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) - .getFixedValue(); - if (ResSize && ArrayEltSize % ResSize == 0) { - Value *Idx = GEP.getOperand(1); - unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); - uint64_t Scale = ArrayEltSize / ResSize; - - // Earlier transforms ensure that the index has the right type - // according to the Data Layout, which considerably simplifies - // the logic by eliminating implicit casts. 
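The comment's example as complete typed-pointer IR (hypothetical names), assuming Descale can prove %tmp = 8 * %tmp2:

    %tmp = mul nsw i64 %tmp2, 8
    %x   = bitcast [100 x double]* %arr to i8*
    %t   = getelementptr i8, i8* %x, i64 %tmp
    ; ==> index descaled by sizeof(double) = 8:
    %t.arr = getelementptr [100 x double], [100 x double]* %arr, i64 0, i64 %tmp2
    %t     = bitcast double* %t.arr to i8*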
- assert(Idx->getType() == DL.getIndexType(GEPType) && - "Index type does not match the Data Layout preferences"); - - bool NSW; - if (Value *NewIdx = Descale(Idx, APInt(BitWidth, Scale), NSW)) { - // Successfully decomposed Idx as NewIdx * Scale, form a new GEP. - // If the multiplication NewIdx * Scale may overflow then the new - // GEP may not be "inbounds". - Type *IndTy = DL.getIndexType(GEPType); - Value *Off[2] = {Constant::getNullValue(IndTy), NewIdx}; - - Value *NewGEP = - Builder.CreateGEP(StrippedPtrEltTy, StrippedPtr, Off, - GEP.getName(), GEP.isInBounds() && NSW); - // The NewGEP must be pointer typed, so must the old one -> BitCast - return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP, - GEPType); - } - } - } - } - } - - // addrspacecast between types is canonicalized as a bitcast, then an - // addrspacecast. To take advantage of the below bitcast + struct GEP, look - // through the addrspacecast. - Value *ASCStrippedPtrOp = PtrOp; - if (auto *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) { - // X = bitcast A addrspace(1)* to B addrspace(1)* - // Y = addrspacecast A addrspace(1)* to B addrspace(2)* - // Z = gep Y, <...constant indices...> - // Into an addrspacecasted GEP of the struct. - if (auto *BC = dyn_cast<BitCastInst>(ASC->getOperand(0))) - ASCStrippedPtrOp = BC; - } - - if (auto *BCI = dyn_cast<BitCastInst>(ASCStrippedPtrOp)) - if (Instruction *I = visitGEPOfBitcast(BCI, GEP)) - return I; - if (!GEP.isInBounds()) { unsigned IdxWidth = DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); @@ -2690,12 +2276,13 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *UnderlyingPtrOp = PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, BasePtrOffset); - if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { + bool CanBeNull, CanBeFreed; + uint64_t DerefBytes = UnderlyingPtrOp->getPointerDereferenceableBytes( + DL, CanBeNull, CanBeFreed); + if (!CanBeNull && !CanBeFreed && DerefBytes != 0) { if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && BasePtrOffset.isNonNegative()) { - APInt AllocSize( - IdxWidth, - DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue()); + APInt AllocSize(IdxWidth, DerefBytes); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( GEP.getSourceElementType(), PtrOp, Indices, GEP.getName()); @@ -2881,8 +2468,11 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { - Value *Result = - lowerObjectSizeCall(II, DL, &TLI, AA, /*MustSucceed=*/true); + SmallVector<Instruction *> InsertedInstructions; + Value *Result = lowerObjectSizeCall( + II, DL, &TLI, AA, /*MustSucceed=*/true, &InsertedInstructions); + for (Instruction *Inserted : InsertedInstructions) + Worklist.add(Inserted); replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. @@ -3089,50 +2679,27 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) { return nullptr; } -static bool isMustTailCall(Value *V) { - if (auto *CI = dyn_cast<CallInst>(V)) - return CI->isMustTailCall(); - return false; -} - Instruction *InstCombinerImpl::visitReturnInst(ReturnInst &RI) { - if (RI.getNumOperands() == 0) // ret void - return nullptr; - - Value *ResultOp = RI.getOperand(0); - Type *VTy = ResultOp->getType(); - if (!VTy->isIntegerTy() || isa<Constant>(ResultOp)) - return nullptr; - - // Don't replace result of musttail calls. 
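The fold being removed in this hunk (the known-bits check just below) constant-folded returns whose value was fully pinned by dominating assumes. A minimal sketch (hypothetical function):

    declare void @llvm.assume(i1)

    define i32 @f(i32 %x) {
      %c = icmp eq i32 %x, 42
      call void @llvm.assume(i1 %c)
      ret i32 %x   ; known bits prove %x == 42, so this became: ret i32 42
    }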
- if (isMustTailCall(ResultOp)) - return nullptr; - - // There might be assume intrinsics dominating this return that completely - // determine the value. If so, constant fold it. - KnownBits Known = computeKnownBits(ResultOp, 0, &RI); - if (Known.isConstant()) - return replaceOperand(RI, 0, - Constant::getIntegerValue(VTy, Known.getConstant())); - + // Nothing for now. return nullptr; } // WARNING: keep in sync with SimplifyCFGOpt::simplifyUnreachable()! -Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { +bool InstCombinerImpl::removeInstructionsBeforeUnreachable(Instruction &I) { // Try to remove the previous instruction if it must lead to unreachable. // This includes instructions like stores and "llvm.assume" that may not get // removed by simple dead code elimination. + bool Changed = false; while (Instruction *Prev = I.getPrevNonDebugInstruction()) { // While we theoretically can erase EH, that would result in a block that // used to start with an EH no longer starting with EH, which is invalid. // To make it valid, we'd need to fixup predecessors to no longer refer to // this block, but that changes CFG, which is not allowed in InstCombine. if (Prev->isEHPad()) - return nullptr; // Can not drop any more instructions. We're done here. + break; // Can not drop any more instructions. We're done here. if (!isGuaranteedToTransferExecutionToSuccessor(Prev)) - return nullptr; // Can not drop any more instructions. We're done here. + break; // Can not drop any more instructions. We're done here. // Otherwise, this instruction can be freely erased, // even if it is not side-effect free. @@ -3140,9 +2707,13 @@ Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { // another unreachable block), so convert those to poison. replaceInstUsesWith(*Prev, PoisonValue::get(Prev->getType())); eraseInstFromFunction(*Prev); + Changed = true; } - assert(I.getParent()->sizeWithoutDebug() == 1 && "The block is now empty."); - // FIXME: recurse into unconditional predecessors? + return Changed; +} + +Instruction *InstCombinerImpl::visitUnreachableInst(UnreachableInst &I) { + removeInstructionsBeforeUnreachable(I); return nullptr; } @@ -3175,6 +2746,57 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) { return nullptr; } +// Under the assumption that I is unreachable, remove it and following +// instructions. +bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) { + bool Changed = false; + BasicBlock *BB = I->getParent(); + for (Instruction &Inst : make_early_inc_range( + make_range(std::next(BB->getTerminator()->getReverseIterator()), + std::next(I->getReverseIterator())))) { + if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) { + replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType())); + Changed = true; + } + if (Inst.isEHPad() || Inst.getType()->isTokenTy()) + continue; + eraseInstFromFunction(Inst); + Changed = true; + } + + // Replace phi node operands in successor blocks with poison. + for (BasicBlock *Succ : successors(BB)) + for (PHINode &PN : Succ->phis()) + for (Use &U : PN.incoming_values()) + if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) { + replaceUse(U, PoisonValue::get(PN.getType())); + addToWorklist(&PN); + Changed = true; + } + + // TODO: Successor blocks may also be dead. 
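As an IR-level sketch of removeInstructionsBeforeUnreachable and handleUnreachableFrom (hypothetical block and values): instructions ahead of an unreachable that ordinary DCE keeps, such as stores and assumes, are erased with their uses poisoned:

    bb:
      store i32 1, ptr %p
      call void @llvm.assume(i1 %c)
      unreachable
    ; ==> both instructions before the terminator are erased, leaving:
    bb:
      unreachable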
+ return Changed; +} + +bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB, + BasicBlock *LiveSucc) { + bool Changed = false; + for (BasicBlock *Succ : successors(BB)) { + // The live successor isn't dead. + if (Succ == LiveSucc) + continue; + + if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) { + return DT.dominates(BasicBlockEdge(BB, Succ), + BasicBlockEdge(Pred, Succ)); + })) + continue; + + Changed |= handleUnreachableFrom(&Succ->front()); + } + return Changed; +} + Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { if (BI.isUnconditional()) return visitUnconditionalBranchInst(BI); @@ -3218,6 +2840,14 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return &BI; } + if (isa<UndefValue>(Cond) && + handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr)) + return &BI; + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + if (handlePotentiallyDeadSuccessors(BI.getParent(), + BI.getSuccessor(!CI->getZExtValue()))) + return &BI; + return nullptr; } @@ -3236,6 +2866,14 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return replaceOperand(SI, 0, Op0); } + if (isa<UndefValue>(Cond) && + handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr)) + return &SI; + if (auto *CI = dyn_cast<ConstantInt>(Cond)) + if (handlePotentiallyDeadSuccessors( + SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor())) + return &SI; + KnownBits Known = computeKnownBits(Cond, 0, &SI); unsigned LeadingKnownZeros = Known.countMinLeadingZeros(); unsigned LeadingKnownOnes = Known.countMinLeadingOnes(); @@ -3243,10 +2881,10 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { // Compute the number of leading bits we can ignore. // TODO: A better way to determine this would use ComputeNumSignBits(). for (const auto &C : SI.cases()) { - LeadingKnownZeros = std::min( - LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); - LeadingKnownOnes = std::min( - LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes()); + LeadingKnownZeros = + std::min(LeadingKnownZeros, C.getCaseValue()->getValue().countl_zero()); + LeadingKnownOnes = + std::min(LeadingKnownOnes, C.getCaseValue()->getValue().countl_one()); } unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes); @@ -3412,6 +3050,11 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { return R; if (LoadInst *L = dyn_cast<LoadInst>(Agg)) { + // Bail out if the aggregate contains scalable vector type + if (auto *STy = dyn_cast<StructType>(Agg->getType()); + STy && STy->containsScalableVectorType()) + return nullptr; + // If the (non-volatile) load only has one use, we can rewrite this to a // load from a GEP. This reduces the size of the load. If a load is used // only by extractvalue instructions then this either must have been @@ -3965,6 +3608,17 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { return Changed; } +// Check if any direct or bitcast user of this value is a shuffle instruction. 
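A sketch of handlePotentiallyDeadSuccessors at work on the constant-branch case handled in this hunk (hypothetical blocks): a successor is treated as dead when the edge from the branch dominates every predecessor edge of that successor, which trivially holds for a single-predecessor block:

    entry:
      br i1 true, label %live, label %dead

    dead:                                  ; sole predecessor is the false edge
      %v = add i32 %a, 1
      br label %join
    ; ==> %v is erased (uses replaced by poison) and %join's phi input from
    ; %dead becomes poison; the CFG itself is left untouched.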
+static bool isUsedWithinShuffleVector(Value *V) { + for (auto *U : V->users()) { + if (isa<ShuffleVectorInst>(U)) + return true; + else if (match(U, m_BitCast(m_Specific(V))) && isUsedWithinShuffleVector(U)) + return true; + } + return false; +} + Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { Value *Op0 = I.getOperand(0); @@ -4014,8 +3668,14 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { return BestValue; }; - if (match(Op0, m_Undef())) + if (match(Op0, m_Undef())) { + // Don't fold freeze(undef/poison) if it's used as a vector operand in + // a shuffle. This may improve codegen for shuffles that allow + // unspecified inputs. + if (isUsedWithinShuffleVector(&I)) + return nullptr; return replaceInstUsesWith(I, getUndefReplacement(I.getType())); + } Constant *C; if (match(Op0, m_Constant(C)) && C->containsUndefOrPoisonElement()) { @@ -4078,8 +3738,8 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { /// beginning of DestBlock, which can only happen if it's safe to move the /// instruction past all of the instructions between it and the end of its /// block. -static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, - TargetLibraryInfo &TLI) { +bool InstCombinerImpl::tryToSinkInstruction(Instruction *I, + BasicBlock *DestBlock) { BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -4126,10 +3786,13 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, return false; } - I->dropDroppableUses([DestBlock](const Use *U) { - if (auto *I = dyn_cast<Instruction>(U->getUser())) - return I->getParent() != DestBlock; - return true; + I->dropDroppableUses([&](const Use *U) { + auto *I = dyn_cast<Instruction>(U->getUser()); + if (I && I->getParent() != DestBlock) { + Worklist.add(I); + return true; + } + return false; }); /// FIXME: We could remove droppable uses that are not dominated by /// the new position. @@ -4227,23 +3890,6 @@ bool InstCombinerImpl::run() { if (!DebugCounter::shouldExecute(VisitCounter)) continue; - // Instruction isn't dead, see if we can constant propagate it. - if (!I->use_empty() && - (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { - if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { - LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I - << '\n'); - - // Add operands to the worklist. - replaceInstUsesWith(*I, C); - ++NumConstProp; - if (isInstructionTriviallyDead(I, &TLI)) - eraseInstFromFunction(*I); - MadeIRChange = true; - continue; - } - } - // See if we can trivially sink this instruction to its user if we can // prove that the successor is not executed more frequently than our block. // Return the UserBlock if successful. @@ -4319,7 +3965,7 @@ bool InstCombinerImpl::run() { if (OptBB) { auto *UserParent = *OptBB; // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent, TLI)) { + if (tryToSinkInstruction(I, UserParent)) { LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); MadeIRChange = true; // We'll add uses of the sunk instruction below, but since @@ -4520,15 +4166,21 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Recursively visit successors. If this is a branch or switch on a // constant, only visit the reachable successor. 
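Back in visitFreeze, the new isUsedWithinShuffleVector guard keeps freezes of undef/poison alive when they feed a shuffle, directly or through a bitcast. A sketch of the case it preserves (hypothetical values):

    %f = freeze <4 x i32> poison
    %s = shufflevector <4 x i32> %v, <4 x i32> %f,
                       <8 x i32> <i32 0, i32 1, i32 2, i32 3,
                                  i32 4, i32 5, i32 6, i32 7>
    ; without the guard, %f would be rewritten to a concrete constant vector,
    ; losing lanes the backend could otherwise treat as unspecified.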
Instruction *TI = BB->getTerminator(); - if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { - if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) { - bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) { + if (isa<UndefValue>(BI->getCondition())) + // Branch on undef is UB. + continue; + if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) { + bool CondVal = Cond->getZExtValue(); BasicBlock *ReachableBB = BI->getSuccessor(!CondVal); Worklist.push_back(ReachableBB); continue; } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { + if (isa<UndefValue>(SI->getCondition())) + // Switch on undef is UB. + continue; + if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } @@ -4584,7 +4236,6 @@ static bool combineInstructionsOverFunction( DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { auto &DL = F.getParent()->getDataLayout(); - MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue()); /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. @@ -4601,13 +4252,6 @@ static bool combineInstructionsOverFunction( bool MadeIRChange = false; if (ShouldLowerDbgDeclare) MadeIRChange = LowerDbgDeclare(F); - // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will - // almost never return true when running an assignment tracking build. Take - // this opportunity to do some clean up for assignment tracking builds too. - if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) { - for (auto &BB : F) - RemoveRedundantDbgInstrs(&BB); - } // Iterate while there is work to do. unsigned Iteration = 0; @@ -4643,13 +4287,29 @@ static bool combineInstructionsOverFunction( MadeIRChange = true; } + if (Iteration == 1) + ++NumOneIteration; + else if (Iteration == 2) + ++NumTwoIterations; + else if (Iteration == 3) + ++NumThreeIterations; + else + ++NumFourOrMoreIterations; + return MadeIRChange; } -InstCombinePass::InstCombinePass() : MaxIterations(LimitMaxIterations) {} +InstCombinePass::InstCombinePass(InstCombineOptions Opts) : Options(Opts) {} -InstCombinePass::InstCombinePass(unsigned MaxIterations) - : MaxIterations(MaxIterations) {} +void InstCombinePass::printPipeline( + raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) { + static_cast<PassInfoMixin<InstCombinePass> *>(this)->printPipeline( + OS, MapClassName2PassName); + OS << '<'; + OS << "max-iterations=" << Options.MaxIterations << ";"; + OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info"; + OS << '>'; +} PreservedAnalyses InstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { @@ -4659,7 +4319,11 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); + // TODO: Only use LoopInfo when the option is set. This requires that the + // callers in the pass pipeline explicitly set the option. 
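With printPipeline defined above, the pass options now round-trip through textual pipelines; for example, with a max-iterations value of 1000 (the InstCombineDefaultMaxIterations default) and LoopInfo disabled, the pass prints as:

    instcombine<max-iterations=1000;no-use-loop-info>

and the same string can be fed back through opt's -passes option.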
   auto *LI = AM.getCachedResult<LoopAnalysis>(F);
+  if (!LI && Options.UseLoopInfo)
+    LI = &AM.getResult<LoopAnalysis>(F);
 
   auto *AA = &AM.getResult<AAManager>(F);
   auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
@@ -4669,7 +4333,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
           &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
 
   if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
-                                       BFI, PSI, MaxIterations, LI))
+                                       BFI, PSI, Options.MaxIterations, LI))
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
 
@@ -4718,18 +4382,13 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
           nullptr;
 
   return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
-                                         BFI, PSI, MaxIterations, LI);
+                                         BFI, PSI,
+                                         InstCombineDefaultMaxIterations, LI);
 }
 
 char InstructionCombiningPass::ID = 0;
 
-InstructionCombiningPass::InstructionCombiningPass()
-    : FunctionPass(ID), MaxIterations(InstCombineDefaultMaxIterations) {
-  initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
-}
-
-InstructionCombiningPass::InstructionCombiningPass(unsigned MaxIterations)
-    : FunctionPass(ID), MaxIterations(MaxIterations) {
+InstructionCombiningPass::InstructionCombiningPass() : FunctionPass(ID) {
   initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry());
 }
 
@@ -4752,18 +4411,6 @@ void llvm::initializeInstCombine(PassRegistry &Registry) {
   initializeInstructionCombiningPassPass(Registry);
 }
 
-void LLVMInitializeInstCombine(LLVMPassRegistryRef R) {
-  initializeInstructionCombiningPassPass(*unwrap(R));
-}
-
 FunctionPass *llvm::createInstructionCombiningPass() {
   return new InstructionCombiningPass();
 }
-
-FunctionPass *llvm::createInstructionCombiningPass(unsigned MaxIterations) {
-  return new InstructionCombiningPass(MaxIterations);
-}
-
-void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) {
-  unwrap(PM)->add(createInstructionCombiningPass());
-}