diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
commit | 344a3780b2e33f6ca763666c380202b18aab72a3 (patch) | |
tree | f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Analysis/InstructionSimplify.cpp | |
parent | b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff) | |
download | src-344a3780b2e33f6ca763666c380202b18aab72a3.tar.gz src-344a3780b2e33f6ca763666c380202b18aab72a3.zip |
Vendor import of llvm-project main 88e66fa60ae5, the last commit before
the upstream release/13.x branch was created.
Tags: vendor/llvm-project/llvmorg-13-init-16847-g88e66fa60ae5, vendor/llvm-project/llvmorg-12.0.1-rc2-0-ge7dac564cd0e, vendor/llvm-project/llvmorg-12.0.1-0-gfed41342a82f
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 1094 |
1 file changed, 686 insertions, 408 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index c40e5c36cdc7..23083bc8178e 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -17,7 +17,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/InstructionSimplify.h" + +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" @@ -26,6 +29,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OverflowInstAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" @@ -68,6 +72,8 @@ static Value *SimplifyCastInst(unsigned, Value *, Type *, const SimplifyQuery &, unsigned); static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &, unsigned); +static Value *SimplifySelectInst(Value *, Value *, Value *, + const SimplifyQuery &, unsigned); static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, Value *FalseVal) { @@ -185,12 +191,15 @@ static Value *handleOtherCmpSelSimplifications(Value *TCmp, Value *FCmp, // If the false value simplified to false, then the result of the compare // is equal to "Cond && TCmp". This also catches the case when the false // value simplified to false and the true value to true, returning "Cond". - if (match(FCmp, m_Zero())) + // Folding select to and/or isn't poison-safe in general; impliesPoison + // checks whether folding it does not convert a well-defined value into + // poison. 
+ if (match(FCmp, m_Zero()) && impliesPoison(TCmp, Cond)) if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse)) return V; // If the true value simplified to true, then the result of the compare // is equal to "Cond || FCmp". - if (match(TCmp, m_One())) + if (match(TCmp, m_One()) && impliesPoison(FCmp, Cond)) if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse)) return V; // Finally, if the false value simplified to true and the true value to @@ -221,8 +230,8 @@ static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { // Otherwise, if the instruction is in the entry block and is not an invoke, // then it obviously dominates all phi nodes. - if (I->getParent() == &I->getFunction()->getEntryBlock() && - !isa<InvokeInst>(I) && !isa<CallBrInst>(I)) + if (I->getParent()->isEntryBlock() && !isa<InvokeInst>(I) && + !isa<CallBrInst>(I)) return true; return false; @@ -730,6 +739,11 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) return C; + // X - poison -> poison + // poison - X -> poison + if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1)) + return PoisonValue::get(Op0->getType()); + // X - undef -> undef // undef - X -> undef if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) @@ -865,6 +879,10 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) return C; + // X * poison -> poison + if (isa<PoisonValue>(Op1)) + return Op1; + // X * undef -> 0 // X * 0 -> 0 if (Q.isUndefValue(Op1) || match(Op1, m_Zero())) @@ -920,8 +938,11 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { /// Check for common or similar folds of integer division or integer remainder. /// This applies to all 4 opcodes (sdiv/udiv/srem/urem). 
-static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, - const SimplifyQuery &Q) { +static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, + Value *Op1, const SimplifyQuery &Q) { + bool IsDiv = (Opcode == Instruction::SDiv || Opcode == Instruction::UDiv); + bool IsSigned = (Opcode == Instruction::SDiv || Opcode == Instruction::SRem); + Type *Ty = Op0->getType(); // X / undef -> poison @@ -948,6 +969,11 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, } } + // poison / X -> poison + // poison % X -> poison + if (isa<PoisonValue>(Op0)) + return Op0; + // undef / X -> 0 // undef % X -> 0 if (Q.isUndefValue(Op0)) @@ -973,6 +999,21 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) return IsDiv ? Op0 : Constant::getNullValue(Ty); + // If X * Y does not overflow, then: + // X * Y / Y -> X + // X * Y % Y -> 0 + if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { + auto *Mul = cast<OverflowingBinaryOperator>(Op0); + // The multiplication can't overflow if it is defined not to, or if + // X == A / Y for some A. + if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) || + (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)) || + (IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || + (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) { + return IsDiv ? X : Constant::getNullValue(Op0->getType()); + } + } + return nullptr; } @@ -1044,25 +1085,11 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, true, Q)) + if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q)) return V; bool IsSigned = Opcode == Instruction::SDiv; - // (X * Y) / Y -> X if the multiplication does not overflow. 
- Value *X; - if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { - auto *Mul = cast<OverflowingBinaryOperator>(Op0); - // If the Mul does not overflow, then we are good to go. - if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) || - (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul))) - return X; - // If X has the form X = A / Y, then X * Y cannot overflow. - if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || - (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) - return X; - } - // (X rem Y) / Y -> 0 if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) @@ -1070,7 +1097,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, // (X /u C1) /u C2 -> 0 if C1 * C2 overflow ConstantInt *C1, *C2; - if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && + if (!IsSigned && match(Op0, m_UDiv(m_Value(), m_ConstantInt(C1))) && match(Op1, m_ConstantInt(C2))) { bool Overflow; (void)C1->getValue().umul_ov(C2->getValue(), Overflow); @@ -1102,7 +1129,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, false, Q)) + if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q)) return V; // (X % Y) % Y -> X % Y @@ -1209,8 +1236,7 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) { // Shifting by the bitwidth or more is undefined. if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) - if (CI->getValue().getLimitedValue() >= - CI->getType()->getScalarSizeInBits()) + if (CI->getValue().uge(CI->getType()->getScalarSizeInBits())) return true; // If all lanes of a vector shift are undefined the whole shift is. @@ -1229,10 +1255,15 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) { /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. 
static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, - Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { + Value *Op1, bool IsNSW, const SimplifyQuery &Q, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; + // poison shift by X -> poison + if (isa<PoisonValue>(Op0)) + return Op0; + // 0 shift by X -> 0 if (match(Op0, m_Zero())) return Constant::getNullValue(Op0->getType()); @@ -1263,16 +1294,31 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // If any bits in the shift amount make that value greater than or equal to // the number of bits in the type, the shift is undefined. - KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (Known.One.getLimitedValue() >= Known.getBitWidth()) + KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth())) return PoisonValue::get(Op0->getType()); // If all valid bits in the shift amount are known zero, the first operand is // unchanged. - unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth()); - if (Known.countMinTrailingZeros() >= NumValidShiftBits) + unsigned NumValidShiftBits = Log2_32_Ceil(KnownAmt.getBitWidth()); + if (KnownAmt.countMinTrailingZeros() >= NumValidShiftBits) return Op0; + // Check for nsw shl leading to a poison value. 
+ if (IsNSW) { + assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction"); + KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt); + + if (KnownVal.Zero.isSignBitSet()) + KnownShl.Zero.setSignBit(); + if (KnownVal.One.isSignBitSet()) + KnownShl.One.setSignBit(); + + if (KnownShl.hasConflict()) + return PoisonValue::get(Op0->getType()); + } + return nullptr; } @@ -1281,7 +1327,8 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, bool isExact, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) + if (Value *V = + SimplifyShift(Opcode, Op0, Op1, /*IsNSW*/ false, Q, MaxRecurse)) return V; // X >> X -> 0 @@ -1307,7 +1354,8 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, /// If not, this returns null. static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse)) + if (Value *V = + SimplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse)) return V; // undef << X -> 0 @@ -1928,77 +1976,6 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, return nullptr; } -/// Check that the Op1 is in expected form, i.e.: -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) -/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1, - Value *X) { - auto *Extract = dyn_cast<ExtractValueInst>(Op1); - // We should only be extracting the overflow bit. - if (!Extract || !Extract->getIndices().equals(1)) - return false; - Value *Agg = Extract->getAggregateOperand(); - // This should be a multiplication-with-overflow intrinsic. 
- if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(), - m_Intrinsic<Intrinsic::smul_with_overflow>()))) - return false; - // One of its multipliers should be the value we checked for zero before. - if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)), - m_Argument<1>(m_Specific(X))))) - return false; - return true; -} - -/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some -/// other form of check, e.g. one that was using division; it may have been -/// guarded against division-by-zero. We can drop that check now. -/// Look for: -/// %Op0 = icmp ne i4 %X, 0 -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) -/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -/// %??? = and i1 %Op0, %Op1 -/// We can just return %Op1 -static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) { - ICmpInst::Predicate Pred; - Value *X; - if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || - Pred != ICmpInst::Predicate::ICMP_NE) - return nullptr; - // Is Op1 in expected form? - if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) - return nullptr; - // Can omit 'and', and just return the overflow bit. - return Op1; -} - -/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some -/// other form of check, e.g. one that was using division; it may have been -/// guarded against division-by-zero. We can drop that check now. -/// Look for: -/// %Op0 = icmp eq i4 %X, 0 -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) 
-/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -/// %NotOp1 = xor i1 %Op1, true -/// %or = or i1 %Op0, %NotOp1 -/// We can just return %NotOp1 -static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0, - Value *NotOp1) { - ICmpInst::Predicate Pred; - Value *X; - if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || - Pred != ICmpInst::Predicate::ICMP_EQ) - return nullptr; - // We expect the other hand of an 'or' to be a 'not'. - Value *Op1; - if (!match(NotOp1, m_Not(m_Value(Op1)))) - return nullptr; - // Is Op1 in expected form? - if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) - return nullptr; - // Can omit 'and', and just return the inverted overflow bit. - return NotOp1; -} - /// Given a bitwise logic op, check if the operands are add/sub with a common /// source value and inverted constant (identity: C - X -> ~(X + ~C)). static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, @@ -2030,6 +2007,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) return C; + // X & poison -> poison + if (isa<PoisonValue>(Op1)) + return Op1; + // X & undef -> 0 if (Q.isUndefValue(Op1)) return Constant::getNullValue(Op0->getType()); @@ -2083,10 +2064,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If we have a multiplication overflow check that is being 'and'ed with a // check that one of the multipliers is not zero, we can omit the 'and', and // only keep the overflow check. - if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1)) - return V; - if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0)) - return V; + if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true)) + return Op1; + if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true)) + return Op0; // A & (-A) = A if A is a power of two or zero. 
if (match(Op0, m_Neg(m_Specific(Op1))) || @@ -2198,6 +2179,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) return C; + // X | poison -> poison + if (isa<PoisonValue>(Op1)) + return Op1; + // X | undef -> -1 // X | -1 = -1 // Do not return Op1 because it may contain undef elements if it's a vector. @@ -2297,10 +2282,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If we have a multiplication overflow check that is being 'and'ed with a // check that one of the multipliers is not zero, we can omit the 'and', and // only keep the overflow check. - if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1)) - return V; - if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0)) - return V; + if (isCheckForZeroAndMulWithOverflow(Op0, Op1, false)) + return Op1; + if (isCheckForZeroAndMulWithOverflow(Op1, Op0, false)) + return Op0; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, @@ -2469,10 +2454,14 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred, // area, it may be possible to update LLVM's semantics accordingly and reinstate // this optimization. static Constant * -computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT, CmpInst::Predicate Pred, - AssumptionCache *AC, const Instruction *CxtI, - const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) { +computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + const DataLayout &DL = Q.DL; + const TargetLibraryInfo *TLI = Q.TLI; + const DominatorTree *DT = Q.DT; + const Instruction *CxtI = Q.CxtI; + const InstrInfoQuery &IIQ = Q.IIQ; + // First, skip past any trivial no-ops. 
LHS = LHS->stripPointerCasts(); RHS = RHS->stripPointerCasts(); @@ -3395,6 +3384,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Type *ITy = GetCompareTy(LHS); // The return type. + // icmp poison, X -> poison + if (isa<PoisonValue>(RHS)) + return PoisonValue::get(ITy); + // For EQ and NE, we can always pick a value for the undef to make the // predicate pass or fail, so we can return undef. // Matches behavior in llvm::ConstantFoldCompareInstruction. @@ -3409,6 +3402,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) return V; + // TODO: Sink/common this with other potentially expensive calls that use + // ValueTracking? See comment below for isKnownNonEqual(). if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) return V; @@ -3428,13 +3423,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, auto LHS_CR = getConstantRangeFromMetadata( *LHS_Instr->getMetadata(LLVMContext::MD_range)); - auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); - if (Satisfied_CR.contains(LHS_CR)) + if (LHS_CR.icmp(Pred, RHS_CR)) return ConstantInt::getTrue(RHS->getContext()); - auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), RHS_CR); - if (InversedSatisfied_CR.contains(LHS_CR)) + if (LHS_CR.icmp(CmpInst::getInversePredicate(Pred), RHS_CR)) return ConstantInt::getFalse(RHS->getContext()); } } @@ -3617,7 +3609,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // icmp eq|ne X, Y -> false|true if X != Y - if (ICmpInst::isEquality(Pred) && + // This is potentially expensive, and we have already computedKnownBits for + // compares with 0 above here, so only try this for a non-zero compare. 
+ if (ICmpInst::isEquality(Pred) && !match(RHS, m_Zero()) && isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) { return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy); } @@ -3634,8 +3628,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) - if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, - Q.IIQ, LHS, RHS)) + if (auto *C = computePointerICmp(Pred, LHS, RHS, Q)) return C; if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) @@ -3643,9 +3636,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.DL.getTypeSizeInBits(CLHS->getType()) && Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == Q.DL.getTypeSizeInBits(CRHS->getType())) - if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, - Q.IIQ, CLHS->getPointerOperand(), - CRHS->getPointerOperand())) + if (auto *C = computePointerICmp(Pred, CLHS->getPointerOperand(), + CRHS->getPointerOperand(), Q)) return C; if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) { @@ -3728,6 +3720,11 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (match(RHS, m_NaN())) return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); + // fcmp pred x, poison and fcmp pred poison, x + // fold to poison + if (isa<PoisonValue>(LHS) || isa<PoisonValue>(RHS)) + return PoisonValue::get(RetTy); + // fcmp pred x, undef and fcmp pred undef, x // fold to true if unordered, false if ordered if (Q.isUndefValue(LHS) || Q.isUndefValue(RHS)) { @@ -3896,10 +3893,12 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } -static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, 
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement, unsigned MaxRecurse) { + assert(!Op->getType()->isVectorTy() && "This is not safe for vectors"); + // Trivial replacement. if (V == Op) return RepOp; @@ -3909,109 +3908,110 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; auto *I = dyn_cast<Instruction>(V); - if (!I) - return nullptr; - - // Consider: - // %cmp = icmp eq i32 %x, 2147483647 - // %add = add nsw i32 %x, 1 - // %sel = select i1 %cmp, i32 -2147483648, i32 %add - // - // We can't replace %sel with %add unless we strip away the flags (which will - // be done in InstCombine). - // TODO: This is unsound, because it only catches some forms of refinement. - if (!AllowRefinement && canCreatePoison(cast<Operator>(I))) + if (!I || !is_contained(I->operands(), Op)) return nullptr; - // The simplification queries below may return the original value. Consider: - // %div = udiv i32 %arg, %arg2 - // %mul = mul nsw i32 %div, %arg2 - // %cmp = icmp eq i32 %mul, %arg - // %sel = select i1 %cmp, i32 %div, i32 undef - // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which - // simplifies back to %arg. This can only happen because %mul does not - // dominate %div. To ensure a consistent return value contract, we make sure - // that this case returns nullptr as well. - auto PreventSelfSimplify = [V](Value *Simplified) { - return Simplified != V ? Simplified : nullptr; - }; - - // If this is a binary operator, try to simplify it with the replaced op. - if (auto *B = dyn_cast<BinaryOperator>(I)) { - if (MaxRecurse) { - if (B->getOperand(0) == Op) - return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), RepOp, - B->getOperand(1), Q, - MaxRecurse - 1)); - if (B->getOperand(1) == Op) - return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), - B->getOperand(0), RepOp, Q, - MaxRecurse - 1)); + // Replace Op with RepOp in instruction operands. 
+ SmallVector<Value *, 8> NewOps(I->getNumOperands()); + transform(I->operands(), NewOps.begin(), + [&](Value *V) { return V == Op ? RepOp : V; }); + + if (!AllowRefinement) { + // General InstSimplify functions may refine the result, e.g. by returning + // a constant for a potentially poison value. To avoid this, implement only + // a few non-refining but profitable transforms here. + + if (auto *BO = dyn_cast<BinaryOperator>(I)) { + unsigned Opcode = BO->getOpcode(); + // id op x -> x, x op id -> x + if (NewOps[0] == ConstantExpr::getBinOpIdentity(Opcode, I->getType())) + return NewOps[1]; + if (NewOps[1] == ConstantExpr::getBinOpIdentity(Opcode, I->getType(), + /* RHS */ true)) + return NewOps[0]; + + // x & x -> x, x | x -> x + if ((Opcode == Instruction::And || Opcode == Instruction::Or) && + NewOps[0] == NewOps[1]) + return NewOps[0]; } - } - // Same for CmpInsts. - if (CmpInst *C = dyn_cast<CmpInst>(I)) { - if (MaxRecurse) { - if (C->getOperand(0) == Op) - return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), RepOp, - C->getOperand(1), Q, - MaxRecurse - 1)); - if (C->getOperand(1) == Op) - return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), - C->getOperand(0), RepOp, Q, - MaxRecurse - 1)); + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + // getelementptr x, 0 -> x + if (NewOps.size() == 2 && match(NewOps[1], m_Zero()) && + !GEP->isInBounds()) + return NewOps[0]; } - } + } else if (MaxRecurse) { + // The simplification queries below may return the original value. Consider: + // %div = udiv i32 %arg, %arg2 + // %mul = mul nsw i32 %div, %arg2 + // %cmp = icmp eq i32 %mul, %arg + // %sel = select i1 %cmp, i32 %div, i32 undef + // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which + // simplifies back to %arg. This can only happen because %mul does not + // dominate %div. To ensure a consistent return value contract, we make sure + // that this case returns nullptr as well. 
+ auto PreventSelfSimplify = [V](Value *Simplified) { + return Simplified != V ? Simplified : nullptr; + }; + + if (auto *B = dyn_cast<BinaryOperator>(I)) + return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), NewOps[0], + NewOps[1], Q, MaxRecurse - 1)); - // Same for GEPs. - if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { - if (MaxRecurse) { - SmallVector<Value *, 8> NewOps(GEP->getNumOperands()); - transform(GEP->operands(), NewOps.begin(), - [&](Value *V) { return V == Op ? RepOp : V; }); + if (CmpInst *C = dyn_cast<CmpInst>(I)) + return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), NewOps[0], + NewOps[1], Q, MaxRecurse - 1)); + + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) return PreventSelfSimplify(SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, MaxRecurse - 1)); - } - } - // TODO: We could hand off more cases to instsimplify here. + if (isa<SelectInst>(I)) + return PreventSelfSimplify( + SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q, + MaxRecurse - 1)); + // TODO: We could hand off more cases to instsimplify here. + } // If all operands are constant after substituting Op for RepOp then we can // constant fold the instruction. - if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) { - // Build a list of all constant operands. - SmallVector<Constant *, 8> ConstOps; - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - if (I->getOperand(i) == Op) - ConstOps.push_back(CRepOp); - else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i))) - ConstOps.push_back(COp); - else - break; - } + SmallVector<Constant *, 8> ConstOps; + for (Value *NewOp : NewOps) { + if (Constant *ConstOp = dyn_cast<Constant>(NewOp)) + ConstOps.push_back(ConstOp); + else + return nullptr; + } - // All operands were constants, fold it. 
- if (ConstOps.size() == I->getNumOperands()) { - if (CmpInst *C = dyn_cast<CmpInst>(I)) - return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], - ConstOps[1], Q.DL, Q.TLI); + // Consider: + // %cmp = icmp eq i32 %x, 2147483647 + // %add = add nsw i32 %x, 1 + // %sel = select i1 %cmp, i32 -2147483648, i32 %add + // + // We can't replace %sel with %add unless we strip away the flags (which + // will be done in InstCombine). + // TODO: This may be unsound, because it only catches some forms of + // refinement. + if (!AllowRefinement && canCreatePoison(cast<Operator>(I))) + return nullptr; - if (LoadInst *LI = dyn_cast<LoadInst>(I)) - if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); + if (CmpInst *C = dyn_cast<CmpInst>(I)) + return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], + ConstOps[1], Q.DL, Q.TLI); - return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); - } - } + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + if (!LI->isVolatile()) + return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); - return nullptr; + return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); } -Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, +Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement) { - return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, + return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, RecursionLimit); } @@ -4127,21 +4127,23 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, TrueVal, FalseVal)) return V; - // If we have an equality comparison, then we know the value in one of the - // arms of the select. See if substituting this value into the arm and + // If we have a scalar equality comparison, then we know the value in one of + // the arms of the select. 
See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, + // Note that the equivalence/replacement opportunity does not hold for vectors + // because each element of a vector select is chosen independently. + if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) { + if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ false, MaxRecurse) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, + simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, /* AllowRefinement */ false, MaxRecurse) == TrueVal) return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ true, MaxRecurse) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, + simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, /* AllowRefinement */ true, MaxRecurse) == FalseVal) return FalseVal; @@ -4190,17 +4192,21 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, if (auto *FalseC = dyn_cast<Constant>(FalseVal)) return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); + // select poison, X, Y -> poison + if (isa<PoisonValue>(CondC)) + return PoisonValue::get(TrueVal->getType()); + // select undef, X, Y -> X or Y if (Q.isUndefValue(CondC)) return isa<Constant>(FalseVal) ? FalseVal : TrueVal; - // TODO: Vector constants with undef elements don't simplify. - - // select true, X, Y -> X - if (CondC->isAllOnesValue()) + // select true, X, Y --> X + // select false, X, Y --> Y + // For vectors, allow undef/poison elements in the condition to match the + // defined elements, so we can eliminate the select. 
+ if (match(CondC, m_One())) return TrueVal; - // select false, X, Y -> Y - if (CondC->isNullValue()) + if (match(CondC, m_Zero())) return FalseVal; } @@ -4217,15 +4223,20 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, if (TrueVal == FalseVal) return TrueVal; + // If the true or false value is poison, we can fold to the other value. // If the true or false value is undef, we can fold to the other value as // long as the other value isn't poison. - // select ?, undef, X -> X - if (Q.isUndefValue(TrueVal) && - isGuaranteedNotToBeUndefOrPoison(FalseVal, Q.AC, Q.CxtI, Q.DT)) + // select ?, poison, X -> X + // select ?, undef, X -> X + if (isa<PoisonValue>(TrueVal) || + (Q.isUndefValue(TrueVal) && + isGuaranteedNotToBePoison(FalseVal, Q.AC, Q.CxtI, Q.DT))) return FalseVal; - // select ?, X, undef -> X - if (Q.isUndefValue(FalseVal) && - isGuaranteedNotToBeUndefOrPoison(TrueVal, Q.AC, Q.CxtI, Q.DT)) + // select ?, X, poison -> X + // select ?, X, undef -> X + if (isa<PoisonValue>(FalseVal) || + (Q.isUndefValue(FalseVal) && + isGuaranteedNotToBePoison(TrueVal, Q.AC, Q.CxtI, Q.DT))) return TrueVal; // Deal with partial undef vector constants: select ?, VecC, VecC' --> VecC'' @@ -4247,11 +4258,11 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // one element is undef, choose the defined element as the safe result. 
if (TEltC == FEltC) NewC.push_back(TEltC); - else if (Q.isUndefValue(TEltC) && - isGuaranteedNotToBeUndefOrPoison(FEltC)) + else if (isa<PoisonValue>(TEltC) || + (Q.isUndefValue(TEltC) && isGuaranteedNotToBePoison(FEltC))) NewC.push_back(FEltC); - else if (Q.isUndefValue(FEltC) && - isGuaranteedNotToBeUndefOrPoison(TEltC)) + else if (isa<PoisonValue>(FEltC) || + (Q.isUndefValue(FEltC) && isGuaranteedNotToBePoison(TEltC))) NewC.push_back(TEltC); else break; @@ -4297,10 +4308,14 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, // Compute the (pointer) type returned by the GEP instruction. Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1)); Type *GEPTy = PointerType::get(LastType, AS); - if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getElementCount()); - else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getElementCount()); + for (Value *Op : Ops) { + // If one of the operands is a vector, the result type is a vector of + // pointers. All vector operands must have the same number of elements. + if (VectorType *VT = dyn_cast<VectorType>(Op->getType())) { + GEPTy = VectorType::get(GEPTy, VT->getElementCount()); + break; + } + } // getelementptr poison, idx -> poison // getelementptr baseptr, poison -> poison @@ -4310,7 +4325,10 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, if (Q.isUndefValue(Ops[0])) return UndefValue::get(GEPTy); - bool IsScalableVec = isa<ScalableVectorType>(SrcTy); + bool IsScalableVec = + isa<ScalableVectorType>(SrcTy) || any_of(Ops, [](const Value *V) { + return isa<ScalableVectorType>(V->getType()); + }); if (Ops.size() == 2) { // getelementptr P, 0 -> P. @@ -4330,40 +4348,32 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, // doesn't truncate the pointers. 
if (Ops[1]->getType()->getScalarSizeInBits() == Q.DL.getPointerSizeInBits(AS)) { - auto PtrToInt = [GEPTy](Value *P) -> Value * { - Value *Temp; - if (match(P, m_PtrToInt(m_Value(Temp)))) - if (Temp->getType() == GEPTy) - return Temp; - return nullptr; + auto CanSimplify = [GEPTy, &P, V = Ops[0]]() -> bool { + return P->getType() == GEPTy && + getUnderlyingObject(P) == getUnderlyingObject(V); }; - - // FIXME: The following transforms are only legal if P and V have the - // same provenance (PR44403). Check whether getUnderlyingObject() is - // the same? - // getelementptr V, (sub P, V) -> P if P points to a type of size 1. if (TyAllocSize == 1 && - match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) - if (Value *R = PtrToInt(P)) - return R; - - // getelementptr V, (ashr (sub P, V), C) -> Q - // if P points to a type of size 1 << C. - if (match(Ops[1], - m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), - m_ConstantInt(C))) && - TyAllocSize == 1ULL << C) - if (Value *R = PtrToInt(P)) - return R; - - // getelementptr V, (sdiv (sub P, V), C) -> Q - // if P points to a type of size C. - if (match(Ops[1], - m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), - m_SpecificInt(TyAllocSize)))) - if (Value *R = PtrToInt(P)) - return R; + match(Ops[1], m_Sub(m_PtrToInt(m_Value(P)), + m_PtrToInt(m_Specific(Ops[0])))) && + CanSimplify()) + return P; + + // getelementptr V, (ashr (sub P, V), C) -> P if P points to a type of + // size 1 << C. + if (match(Ops[1], m_AShr(m_Sub(m_PtrToInt(m_Value(P)), + m_PtrToInt(m_Specific(Ops[0]))), + m_ConstantInt(C))) && + TyAllocSize == 1ULL << C && CanSimplify()) + return P; + + // getelementptr V, (sdiv (sub P, V), C) -> P if P points to a type of + // size C. 
+ if (match(Ops[1], m_SDiv(m_Sub(m_PtrToInt(m_Value(P)), + m_PtrToInt(m_Specific(Ops[0]))), + m_SpecificInt(TyAllocSize))) && + CanSimplify()) + return P; } } } @@ -4523,30 +4533,33 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, if (auto *CIdx = dyn_cast<Constant>(Idx)) return ConstantExpr::getExtractElement(CVec, CIdx); - // The index is not relevant if our vector is a splat. - if (auto *Splat = CVec->getSplatValue()) - return Splat; - if (Q.isUndefValue(Vec)) return UndefValue::get(VecVTy->getElementType()); } + // An undef extract index can be arbitrarily chosen to be an out-of-range + // index value, which would result in the instruction being poison. + if (Q.isUndefValue(Idx)) + return PoisonValue::get(VecVTy->getElementType()); + // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) { // For fixed-length vector, fold into undef if index is out of bounds. - if (isa<FixedVectorType>(VecVTy) && - IdxC->getValue().uge(cast<FixedVectorType>(VecVTy)->getNumElements())) + unsigned MinNumElts = VecVTy->getElementCount().getKnownMinValue(); + if (isa<FixedVectorType>(VecVTy) && IdxC->getValue().uge(MinNumElts)) return PoisonValue::get(VecVTy->getElementType()); + // Handle case where an element is extracted from a splat. + if (IdxC->getValue().ult(MinNumElts)) + if (auto *Splat = getSplatValue(Vec)) + return Splat; if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) return Elt; + } else { + // The index is not relevant if our vector is a splat. + if (Value *Splat = getSplatValue(Vec)) + return Splat; } - - // An undef extract index can be arbitrarily chosen to be an out-of-range - // index value, which would result in the instruction being poison. 
- if (Q.isUndefValue(Idx)) - return PoisonValue::get(VecVTy->getElementType()); - return nullptr; } @@ -4556,7 +4569,8 @@ Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { +static Value *SimplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues, + const SimplifyQuery &Q) { // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE // here, because the PHI we may succeed simplifying to was not // def-reachable from the original PHI! @@ -4565,7 +4579,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { // with the common value. Value *CommonValue = nullptr; bool HasUndefInput = false; - for (Value *Incoming : PN->incoming_values()) { + for (Value *Incoming : IncomingValues) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PN) continue; if (Q.isUndefValue(Incoming)) { @@ -4842,11 +4856,17 @@ static Constant *propagateNaN(Constant *In) { } /// Perform folds that are common to any floating-point operation. This implies -/// transforms based on undef/NaN because the operation itself makes no +/// transforms based on poison/undef/NaN because the operation itself makes no /// difference to the result. -static Constant *simplifyFPOp(ArrayRef<Value *> Ops, - FastMathFlags FMF, - const SimplifyQuery &Q) { +static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF, + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + // Poison is independent of anything else. It always propagates from an + // operand to a math result. 
+ if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); })) + return PoisonValue::get(Ops[0]->getType()); + for (Value *V : Ops) { bool IsNan = match(V, m_NaN()); bool IsInf = match(V, m_Inf()); @@ -4860,22 +4880,34 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, if (FMF.noInfs() && (IsInf || IsUndef)) return PoisonValue::get(V->getType()); - if (IsUndef || IsNan) - return propagateNaN(cast<Constant>(V)); + if (isDefaultFPEnvironment(ExBehavior, Rounding)) { + if (IsUndef || IsNan) + return propagateNaN(cast<Constant>(V)); + } else if (ExBehavior != fp::ebStrict) { + if (IsNan) + return propagateNaN(cast<Constant>(V)); + } } return nullptr; } /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. -static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) - return C; +static Value * +SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, + fp::ExceptionBehavior ExBehavior = fp::ebIgnore, + RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) + return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; + // fadd X, -0 ==> X if (match(Op1, m_NegZeroFP())) return Op0; @@ -4915,14 +4947,21 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. 
-static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) - return C; +static Value * +SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, + fp::ExceptionBehavior ExBehavior = fp::ebIgnore, + RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; + // fsub X, +0 ==> X if (match(Op1, m_PosZeroFP())) return Op0; @@ -4961,10 +5000,15 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, } static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) + const SimplifyQuery &Q, unsigned MaxRecurse, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; + // fmul X, 1.0 ==> X if (match(Op1, m_FPOne())) return Op0; @@ -4994,43 +5038,65 @@ static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, } /// Given the operands for an FMul, see if we can fold the result -static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) - return C; +static Value * +SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, + fp::ExceptionBehavior 
ExBehavior = fp::ebIgnore, + RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) + if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) + return C; // Now apply simplifications that do not require rounding. - return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse); + return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse, ExBehavior, Rounding); } Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); } - Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); } Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit); -} + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); +} + +static Value * +SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned, + fp::ExceptionBehavior ExBehavior = fp::ebIgnore, + RoundingMode Rounding = 
RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) + if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + return C; -static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned) { - if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) - return C; + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; // X / 1.0 -> X if (match(Op1, m_FPOne())) @@ -5065,17 +5131,27 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit); -} + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); +} + +static Value * +SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned, + fp::ExceptionBehavior ExBehavior = fp::ebIgnore, + RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + if (isDefaultFPEnvironment(ExBehavior, Rounding)) + if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + return C; -static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned) { - if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) - return C; + if (!isDefaultFPEnvironment(ExBehavior, Rounding)) + return nullptr; // Unlike fdiv, the result of frem always matches the sign of the dividend. 
// The constant match may include undef elements in a vector, so return a full @@ -5093,8 +5169,11 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, } Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior, + RoundingMode Rounding) { + return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, + Rounding); } //=== Helper functions for higher up the class hierarchy. @@ -5373,6 +5452,12 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, return Op0; break; } + case Intrinsic::experimental_vector_reverse: + // experimental.vector.reverse(experimental.vector.reverse(x)) -> x + if (match(Op0, + m_Intrinsic<Intrinsic::experimental_vector_reverse>(m_Value(X)))) + return X; + break; default: break; } @@ -5380,16 +5465,6 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, return nullptr; } -static Intrinsic::ID getMaxMinOpposite(Intrinsic::ID IID) { - switch (IID) { - case Intrinsic::smax: return Intrinsic::smin; - case Intrinsic::smin: return Intrinsic::smax; - case Intrinsic::umax: return Intrinsic::umin; - case Intrinsic::umin: return Intrinsic::umax; - default: llvm_unreachable("Unexpected intrinsic"); - } -} - static APInt getMaxMinLimit(Intrinsic::ID IID, unsigned BitWidth) { switch (IID) { case Intrinsic::smax: return APInt::getSignedMaxValue(BitWidth); @@ -5429,7 +5504,7 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { if (IID0 == IID) return MM0; // max (min X, Y), X --> X - if (IID0 == getMaxMinOpposite(IID)) + if (IID0 == getInverseMinMaxIntrinsic(IID)) return Op1; } return nullptr; @@ -5449,6 +5524,20 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return Op0; break; + case Intrinsic::cttz: { + Value *X; + if (match(Op0, m_Shl(m_One(), m_Value(X)))) + return X; + break; 
+ } + case Intrinsic::ctlz: { + Value *X; + if (match(Op0, m_LShr(m_Negative(), m_Value(X)))) + return X; + if (match(Op0, m_AShr(m_Negative(), m_Value()))) + return Constant::getNullValue(ReturnType); + break; + } case Intrinsic::smax: case Intrinsic::smin: case Intrinsic::umax: @@ -5475,7 +5564,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // If the constant op is the opposite of the limit value, the other must // be larger/smaller or equal. For example: // umin(i8 %x, i8 255) --> %x - if (*C == getMaxMinLimit(getMaxMinOpposite(IID), BitWidth)) + if (*C == getMaxMinLimit(getInverseMinMaxIntrinsic(IID), BitWidth)) return Op0; // Remove nested call if constant operands allow it. Example: @@ -5661,6 +5750,19 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, break; } + case Intrinsic::experimental_vector_extract: { + Type *ReturnType = F->getReturnType(); + + // (extract_vector (insert_vector _, X, 0), 0) -> X + unsigned IdxN = cast<ConstantInt>(Op1)->getZExtValue(); + Value *X = nullptr; + if (match(Op0, m_Intrinsic<Intrinsic::experimental_vector_insert>( + m_Value(), m_Value(X), m_Zero())) && + IdxN == 0 && X->getType() == ReturnType) + return X; + + break; + } default: break; } @@ -5717,15 +5819,115 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } return nullptr; } + case Intrinsic::experimental_constrained_fma: { + Value *Op0 = Call->getArgOperand(0); + Value *Op1 = Call->getArgOperand(1); + Value *Op2 = Call->getArgOperand(2); + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue())) + return V; + return nullptr; + } case Intrinsic::fma: case Intrinsic::fmuladd: { Value *Op0 = Call->getArgOperand(0); Value *Op1 = Call->getArgOperand(1); Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q)) + if (Value *V = 
simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore, + RoundingMode::NearestTiesToEven)) return V; return nullptr; } + case Intrinsic::smul_fix: + case Intrinsic::smul_fix_sat: { + Value *Op0 = Call->getArgOperand(0); + Value *Op1 = Call->getArgOperand(1); + Value *Op2 = Call->getArgOperand(2); + Type *ReturnType = F->getReturnType(); + + // Canonicalize constant operand as Op1 (ConstantFolding handles the case + // when both Op0 and Op1 are constant so we do not care about that special + // case here). + if (isa<Constant>(Op0)) + std::swap(Op0, Op1); + + // X * 0 -> 0 + if (match(Op1, m_Zero())) + return Constant::getNullValue(ReturnType); + + // X * undef -> 0 + if (Q.isUndefValue(Op1)) + return Constant::getNullValue(ReturnType); + + // X * (1 << Scale) -> X + APInt ScaledOne = + APInt::getOneBitSet(ReturnType->getScalarSizeInBits(), + cast<ConstantInt>(Op2)->getZExtValue()); + if (ScaledOne.isNonNegative() && match(Op1, m_SpecificInt(ScaledOne))) + return Op0; + + return nullptr; + } + case Intrinsic::experimental_vector_insert: { + Value *Vec = Call->getArgOperand(0); + Value *SubVec = Call->getArgOperand(1); + Value *Idx = Call->getArgOperand(2); + Type *ReturnType = F->getReturnType(); + + // (insert_vector Y, (extract_vector X, 0), 0) -> X + // where: Y is X, or Y is undef + unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); + Value *X = nullptr; + if (match(SubVec, m_Intrinsic<Intrinsic::experimental_vector_extract>( + m_Value(X), m_Zero())) && + (Q.isUndefValue(Vec) || Vec == X) && IdxN == 0 && + X->getType() == ReturnType) + return X; + + return nullptr; + } + case Intrinsic::experimental_constrained_fadd: { + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + return SimplifyFAddInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + FPI->getFastMathFlags(), Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue()); + break; + } + case Intrinsic::experimental_constrained_fsub: { + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + 
return SimplifyFSubInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + FPI->getFastMathFlags(), Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue()); + break; + } + case Intrinsic::experimental_constrained_fmul: { + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + return SimplifyFMulInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + FPI->getFastMathFlags(), Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue()); + break; + } + case Intrinsic::experimental_constrained_fdiv: { + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + return SimplifyFDivInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + FPI->getFastMathFlags(), Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue()); + break; + } + case Intrinsic::experimental_constrained_frem: { + auto *FPI = cast<ConstrainedFPIntrinsic>(Call); + return SimplifyFRemInst(FPI->getArgOperand(0), FPI->getArgOperand(1), + FPI->getFastMathFlags(), Q, + FPI->getExceptionBehavior().getValue(), + FPI->getRoundingMode().getValue()); + break; + } default: return nullptr; } @@ -5788,162 +5990,223 @@ Value *llvm::SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { return ::SimplifyFreezeInst(Op0, Q); } +static Constant *ConstructLoadOperandConstant(Value *Op) { + SmallVector<Value *, 4> Worklist; + // Invalid IR in unreachable code may contain self-referential values. Don't infinitely loop. 
+ SmallPtrSet<Value *, 4> Visited; + Worklist.push_back(Op); + while (true) { + Value *CurOp = Worklist.back(); + if (!Visited.insert(CurOp).second) + return nullptr; + if (isa<Constant>(CurOp)) + break; + if (auto *BC = dyn_cast<BitCastOperator>(CurOp)) { + Worklist.push_back(BC->getOperand(0)); + } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) { + for (unsigned I = 1; I != GEP->getNumOperands(); ++I) { + if (!isa<Constant>(GEP->getOperand(I))) + return nullptr; + } + Worklist.push_back(GEP->getOperand(0)); + } else if (auto *II = dyn_cast<IntrinsicInst>(CurOp)) { + if (II->isLaunderOrStripInvariantGroup()) + Worklist.push_back(II->getOperand(0)); + else + return nullptr; + } else { + return nullptr; + } + } + + Constant *NewOp = cast<Constant>(Worklist.pop_back_val()); + while (!Worklist.empty()) { + Value *CurOp = Worklist.pop_back_val(); + if (isa<BitCastOperator>(CurOp)) { + NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType()); + } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) { + SmallVector<Constant *> Idxs; + Idxs.reserve(GEP->getNumOperands() - 1); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) { + Idxs.push_back(cast<Constant>(GEP->getOperand(I))); + } + NewOp = ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), NewOp, + Idxs, GEP->isInBounds(), + GEP->getInRangeIndex()); + } else { + assert(isa<IntrinsicInst>(CurOp) && + cast<IntrinsicInst>(CurOp)->isLaunderOrStripInvariantGroup() && + "expected invariant group intrinsic"); + NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType()); + } + } + return NewOp; +} + +static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp, + const SimplifyQuery &Q) { + if (LI->isVolatile()) + return nullptr; + + // Try to make the load operand a constant, specifically handle + // invariant.group intrinsics. 
+ auto *PtrOpC = dyn_cast<Constant>(PtrOp); + if (!PtrOpC) + PtrOpC = ConstructLoadOperandConstant(PtrOp); + + if (PtrOpC) + return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Q.DL); + + return nullptr; +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +static Value *simplifyInstructionWithOperands(Instruction *I, + ArrayRef<Value *> NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); - Value *Result; + Value *Result = nullptr; switch (I->getOpcode()) { default: - Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); + if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) { + SmallVector<Constant *, 8> NewConstOps(NewOps.size()); + transform(NewOps, NewConstOps.begin(), + [](Value *V) { return cast<Constant>(V); }); + Result = ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI); + } break; case Instruction::FNeg: - Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); + Result = SimplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Add: - Result = - SimplifyAddInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + Result = SimplifyAddInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FSub: - Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; 
case Instruction::Sub: - Result = - SimplifySubInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + Result = SimplifySubInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FMul: - Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyMulInst(NewOps[0], NewOps[1], Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyUDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySRemInst(NewOps[0], NewOps[1], Q); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyURemInst(NewOps[0], NewOps[1], Q); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = - SimplifyShlInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), - Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); + Result = SimplifyShlInst( + NewOps[0], NewOps[1], 
Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::LShr: - Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyLShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::AShr: - Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyAShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyAndInst(NewOps[0], NewOps[1], Q); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyOrInst(NewOps[0], NewOps[1], Q); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyXorInst(NewOps[0], NewOps[1], Q); break; case Instruction::ICmp: - Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1], Q); break; case Instruction::FCmp: - Result = - SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), - I->getOperand(1), I->getFastMathFlags(), Q); + Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0], + NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Select: - Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), Q); + Result = SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); break; case Instruction::GetElementPtr: { - SmallVector<Value *, 8> Ops(I->operands()); Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), - Ops, Q); + NewOps, Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); - Result = SimplifyInsertValueInst(IV->getAggregateOperand(), - IV->getInsertedValueOperand(), - 
IV->getIndices(), Q); + Result = SimplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); break; } case Instruction::InsertElement: { - auto *IE = cast<InsertElementInst>(I); - Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1), - IE->getOperand(2), Q); + Result = SimplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); break; } case Instruction::ExtractValue: { auto *EVI = cast<ExtractValueInst>(I); - Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), - EVI->getIndices(), Q); + Result = SimplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { - auto *EEI = cast<ExtractElementInst>(I); - Result = SimplifyExtractElementInst(EEI->getVectorOperand(), - EEI->getIndexOperand(), Q); + Result = SimplifyExtractElementInst(NewOps[0], NewOps[1], Q); break; } case Instruction::ShuffleVector: { auto *SVI = cast<ShuffleVectorInst>(I); - Result = - SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), - SVI->getShuffleMask(), SVI->getType(), Q); + Result = SimplifyShuffleVectorInst( + NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Q); + Result = SimplifyPHINode(cast<PHINode>(I), NewOps, Q); break; case Instruction::Call: { + // TODO: Use NewOps Result = SimplifyCall(cast<CallInst>(I), Q); break; } case Instruction::Freeze: - Result = SimplifyFreezeInst(I->getOperand(0), Q); + Result = llvm::SimplifyFreezeInst(NewOps[0], Q); break; #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = - SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + Result = SimplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. 
Result = nullptr; break; + case Instruction::Load: + Result = SimplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q); + break; } /// If called on unreachable code, the above logic may report that the @@ -5952,6 +6215,21 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, return Result == I ? UndefValue::get(I->getType()) : Result; } +Value *llvm::SimplifyInstructionWithOperands(Instruction *I, + ArrayRef<Value *> NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + assert(NewOps.size() == I->getNumOperands() && + "Number of operands should match the instruction!"); + return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); +} + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + SmallVector<Value *, 8> Ops(I->operands()); + return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); +} + /// Implementation of recursive simplification through an instruction's /// uses. /// |