path: root/llvm/lib/Analysis/InstructionSimplify.cpp
author    Dimitry Andric <dim@FreeBSD.org>    2021-07-29 20:15:26 +0000
committer Dimitry Andric <dim@FreeBSD.org>    2021-07-29 20:15:26 +0000
commit    344a3780b2e33f6ca763666c380202b18aab72a3 (patch)
tree      f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Analysis/InstructionSimplify.cpp
parent    b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff)
the upstream release/13.x branch was created.
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r--  llvm/lib/Analysis/InstructionSimplify.cpp  1094
1 file changed, 686 insertions(+), 408 deletions(-)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c40e5c36cdc7..23083bc8178e 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -17,7 +17,10 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/InstructionSimplify.h"
+
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -26,6 +29,7 @@
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/OverflowInstAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
@@ -68,6 +72,8 @@ static Value *SimplifyCastInst(unsigned, Value *, Type *,
const SimplifyQuery &, unsigned);
static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &,
unsigned);
+static Value *SimplifySelectInst(Value *, Value *, Value *,
+ const SimplifyQuery &, unsigned);
static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
Value *FalseVal) {
@@ -185,12 +191,15 @@ static Value *handleOtherCmpSelSimplifications(Value *TCmp, Value *FCmp,
// If the false value simplified to false, then the result of the compare
// is equal to "Cond && TCmp". This also catches the case when the false
// value simplified to false and the true value to true, returning "Cond".
- if (match(FCmp, m_Zero()))
+ // Folding select to and/or isn't poison-safe in general; impliesPoison
+ // checks whether folding it does not convert a well-defined value into
+ // poison.
+ if (match(FCmp, m_Zero()) && impliesPoison(TCmp, Cond))
if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse))
return V;
// If the true value simplified to true, then the result of the compare
// is equal to "Cond || FCmp".
- if (match(TCmp, m_One()))
+ if (match(TCmp, m_One()) && impliesPoison(FCmp, Cond))
if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse))
return V;
// Finally, if the false value simplified to true and the true value to
@@ -221,8 +230,8 @@ static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
// Otherwise, if the instruction is in the entry block and is not an invoke,
// then it obviously dominates all phi nodes.
- if (I->getParent() == &I->getFunction()->getEntryBlock() &&
- !isa<InvokeInst>(I) && !isa<CallBrInst>(I))
+ if (I->getParent()->isEntryBlock() && !isa<InvokeInst>(I) &&
+ !isa<CallBrInst>(I))
return true;
return false;
@@ -730,6 +739,11 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
return C;
+ // X - poison -> poison
+ // poison - X -> poison
+ if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1))
+ return PoisonValue::get(Op0->getType());
+
// X - undef -> undef
// undef - X -> undef
if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1))
@@ -865,6 +879,10 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
return C;
+ // X * poison -> poison
+ if (isa<PoisonValue>(Op1))
+ return Op1;
+
// X * undef -> 0
// X * 0 -> 0
if (Q.isUndefValue(Op1) || match(Op1, m_Zero()))
@@ -920,8 +938,11 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
/// Check for common or similar folds of integer division or integer remainder.
/// This applies to all 4 opcodes (sdiv/udiv/srem/urem).
-static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
- const SimplifyQuery &Q) {
+static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
+ Value *Op1, const SimplifyQuery &Q) {
+ bool IsDiv = (Opcode == Instruction::SDiv || Opcode == Instruction::UDiv);
+ bool IsSigned = (Opcode == Instruction::SDiv || Opcode == Instruction::SRem);
+
Type *Ty = Op0->getType();
// X / undef -> poison
@@ -948,6 +969,11 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
}
}
+ // poison / X -> poison
+ // poison % X -> poison
+ if (isa<PoisonValue>(Op0))
+ return Op0;
+
// undef / X -> 0
// undef % X -> 0
if (Q.isUndefValue(Op0))
@@ -973,6 +999,21 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
(match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
return IsDiv ? Op0 : Constant::getNullValue(Ty);
+ // If X * Y does not overflow, then:
+ // X * Y / Y -> X
+ // X * Y % Y -> 0
+ if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+ auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+ // The multiplication can't overflow if it is defined not to, or if
+ // X == A / Y for some A.
+ if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
+ (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)) ||
+ (IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+ (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) {
+ return IsDiv ? X : Constant::getNullValue(Op0->getType());
+ }
+ }
+
return nullptr;
}
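
[Editor's aside, not part of the patch: a minimal standalone C++ sketch of the arithmetic fact the new simplifyDivRem fold relies on, shown here for the unsigned case: X * Y / Y == X and X * Y % Y == 0 whenever the multiplication does not wrap, which is what the nuw/nsw and "X == A / Y" checks in the hunk establish. Names and values below are illustrative only.]

// Illustration only (not from the patch): the fact behind "X * Y / Y -> X".
#include <cassert>
#include <cstdint>

static bool mulDoesNotWrap(uint32_t X, uint32_t Y) {
  uint32_t Product;
  return !__builtin_mul_overflow(X, Y, &Product); // GCC/Clang builtin
}

int main() {
  const uint32_t Xs[] = {0, 1, 7, 12345};
  const uint32_t Ys[] = {1, 3, 65537};
  for (uint32_t X : Xs)
    for (uint32_t Y : Ys)
      if (mulDoesNotWrap(X, Y)) {
        assert(X * Y / Y == X); // X * Y / Y -> X
        assert(X * Y % Y == 0); // X * Y % Y -> 0
      }
  return 0;
}
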
@@ -1044,25 +1085,11 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
- if (Value *V = simplifyDivRem(Op0, Op1, true, Q))
+ if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
return V;
bool IsSigned = Opcode == Instruction::SDiv;
- // (X * Y) / Y -> X if the multiplication does not overflow.
- Value *X;
- if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
- auto *Mul = cast<OverflowingBinaryOperator>(Op0);
- // If the Mul does not overflow, then we are good to go.
- if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
- (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)))
- return X;
- // If X has the form X = A / Y, then X * Y cannot overflow.
- if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
- (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
- return X;
- }
-
// (X rem Y) / Y -> 0
if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
(!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
@@ -1070,7 +1097,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
// (X /u C1) /u C2 -> 0 if C1 * C2 overflow
ConstantInt *C1, *C2;
- if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) &&
+ if (!IsSigned && match(Op0, m_UDiv(m_Value(), m_ConstantInt(C1))) &&
match(Op1, m_ConstantInt(C2))) {
bool Overflow;
(void)C1->getValue().umul_ov(C2->getValue(), Overflow);
@@ -1102,7 +1129,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
- if (Value *V = simplifyDivRem(Op0, Op1, false, Q))
+ if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
return V;
// (X % Y) % Y -> X % Y
@@ -1209,8 +1236,7 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) {
// Shifting by the bitwidth or more is undefined.
if (ConstantInt *CI = dyn_cast<ConstantInt>(C))
- if (CI->getValue().getLimitedValue() >=
- CI->getType()->getScalarSizeInBits())
+ if (CI->getValue().uge(CI->getType()->getScalarSizeInBits()))
return true;
// If all lanes of a vector shift are undefined the whole shift is.
@@ -1229,10 +1255,15 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) {
/// Given operands for an Shl, LShr or AShr, see if we can fold the result.
/// If not, this returns null.
static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
- Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) {
+ Value *Op1, bool IsNSW, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
+ // poison shift by X -> poison
+ if (isa<PoisonValue>(Op0))
+ return Op0;
+
// 0 shift by X -> 0
if (match(Op0, m_Zero()))
return Constant::getNullValue(Op0->getType());
@@ -1263,16 +1294,31 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// If any bits in the shift amount make that value greater than or equal to
// the number of bits in the type, the shift is undefined.
- KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
- if (Known.One.getLimitedValue() >= Known.getBitWidth())
+ KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
return PoisonValue::get(Op0->getType());
// If all valid bits in the shift amount are known zero, the first operand is
// unchanged.
- unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth());
- if (Known.countMinTrailingZeros() >= NumValidShiftBits)
+ unsigned NumValidShiftBits = Log2_32_Ceil(KnownAmt.getBitWidth());
+ if (KnownAmt.countMinTrailingZeros() >= NumValidShiftBits)
return Op0;
+ // Check for nsw shl leading to a poison value.
+ if (IsNSW) {
+ assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
+ KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
+
+ if (KnownVal.Zero.isSignBitSet())
+ KnownShl.Zero.setSignBit();
+ if (KnownVal.One.isSignBitSet())
+ KnownShl.One.setSignBit();
+
+ if (KnownShl.hasConflict())
+ return PoisonValue::get(Op0->getType());
+ }
+
return nullptr;
}
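
[Editor's aside, not part of the patch: what "nsw shl leading to a poison value" means for concrete integers. Roughly, a shl carrying the nsw flag yields poison when the signed result is no longer representable in the original width; the hunk above detects a subset of such cases from the operands' known bits. A minimal sketch, assuming a 32-bit value and an in-range shift amount.]

#include <cassert>
#include <cstdint>

// True if "shl nsw i32 X, S" would be poison, assuming 0 <= S < 32.
static bool shlNswIsPoison(int32_t X, unsigned S) {
  int64_t Wide = static_cast<int64_t>(X) << S;  // exact result in 64 bits
  return Wide != static_cast<int32_t>(Wide);    // not representable back in i32
}

int main() {
  assert(!shlNswIsPoison(1, 4));          // 16 fits in i32
  assert(shlNswIsPoison(1, 31));          // 2^31 does not fit in i32
  assert(shlNswIsPoison(-1073741824, 2)); // -2^30 << 2 = -2^32 overflows i32
  return 0;
}
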
@@ -1281,7 +1327,8 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
Value *Op1, bool isExact, const SimplifyQuery &Q,
unsigned MaxRecurse) {
- if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+ if (Value *V =
+ SimplifyShift(Opcode, Op0, Op1, /*IsNSW*/ false, Q, MaxRecurse))
return V;
// X >> X -> 0
@@ -1307,7 +1354,8 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
/// If not, this returns null.
static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse))
+ if (Value *V =
+ SimplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse))
return V;
// undef << X -> 0
@@ -1928,77 +1976,6 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q,
return nullptr;
}
-/// Check that the Op1 is in expected form, i.e.:
-/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-/// %Op1 = extractvalue { i4, i1 } %Agg, 1
-static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1,
- Value *X) {
- auto *Extract = dyn_cast<ExtractValueInst>(Op1);
- // We should only be extracting the overflow bit.
- if (!Extract || !Extract->getIndices().equals(1))
- return false;
- Value *Agg = Extract->getAggregateOperand();
- // This should be a multiplication-with-overflow intrinsic.
- if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
- m_Intrinsic<Intrinsic::smul_with_overflow>())))
- return false;
- // One of its multipliers should be the value we checked for zero before.
- if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)),
- m_Argument<1>(m_Specific(X)))))
- return false;
- return true;
-}
-
-/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
-/// other form of check, e.g. one that was using division; it may have been
-/// guarded against division-by-zero. We can drop that check now.
-/// Look for:
-/// %Op0 = icmp ne i4 %X, 0
-/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-/// %Op1 = extractvalue { i4, i1 } %Agg, 1
-/// %??? = and i1 %Op0, %Op1
-/// We can just return %Op1
-static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) {
- ICmpInst::Predicate Pred;
- Value *X;
- if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
- Pred != ICmpInst::Predicate::ICMP_NE)
- return nullptr;
- // Is Op1 in expected form?
- if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
- return nullptr;
- // Can omit 'and', and just return the overflow bit.
- return Op1;
-}
-
-/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
-/// other form of check, e.g. one that was using division; it may have been
-/// guarded against division-by-zero. We can drop that check now.
-/// Look for:
-/// %Op0 = icmp eq i4 %X, 0
-/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-/// %Op1 = extractvalue { i4, i1 } %Agg, 1
-/// %NotOp1 = xor i1 %Op1, true
-/// %or = or i1 %Op0, %NotOp1
-/// We can just return %NotOp1
-static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0,
- Value *NotOp1) {
- ICmpInst::Predicate Pred;
- Value *X;
- if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
- Pred != ICmpInst::Predicate::ICMP_EQ)
- return nullptr;
- // We expect the other hand of an 'or' to be a 'not'.
- Value *Op1;
- if (!match(NotOp1, m_Not(m_Value(Op1))))
- return nullptr;
- // Is Op1 in expected form?
- if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
- return nullptr;
- // Can omit 'and', and just return the inverted overflow bit.
- return NotOp1;
-}
-
/// Given a bitwise logic op, check if the operands are add/sub with a common
/// source value and inverted constant (identity: C - X -> ~(X + ~C)).
static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1,
@@ -2030,6 +2007,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q))
return C;
+ // X & poison -> poison
+ if (isa<PoisonValue>(Op1))
+ return Op1;
+
// X & undef -> 0
if (Q.isUndefValue(Op1))
return Constant::getNullValue(Op0->getType());
@@ -2083,10 +2064,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// If we have a multiplication overflow check that is being 'and'ed with a
// check that one of the multipliers is not zero, we can omit the 'and', and
// only keep the overflow check.
- if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1))
- return V;
- if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0))
- return V;
+ if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true))
+ return Op1;
+ if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true))
+ return Op0;
// A & (-A) = A if A is a power of two or zero.
if (match(Op0, m_Neg(m_Specific(Op1))) ||
@@ -2198,6 +2179,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q))
return C;
+ // X | poison -> poison
+ if (isa<PoisonValue>(Op1))
+ return Op1;
+
// X | undef -> -1
// X | -1 = -1
// Do not return Op1 because it may contain undef elements if it's a vector.
@@ -2297,10 +2282,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// If we have a multiplication overflow check that is being 'and'ed with a
// check that one of the multipliers is not zero, we can omit the 'and', and
// only keep the overflow check.
- if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1))
- return V;
- if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0))
- return V;
+ if (isCheckForZeroAndMulWithOverflow(Op0, Op1, false))
+ return Op1;
+ if (isCheckForZeroAndMulWithOverflow(Op1, Op0, false))
+ return Op0;
// Try some generic simplifications for associative operations.
if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
@@ -2469,10 +2454,14 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
// area, it may be possible to update LLVM's semantics accordingly and reinstate
// this optimization.
static Constant *
-computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
- const DominatorTree *DT, CmpInst::Predicate Pred,
- AssumptionCache *AC, const Instruction *CxtI,
- const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) {
+computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
+ const SimplifyQuery &Q) {
+ const DataLayout &DL = Q.DL;
+ const TargetLibraryInfo *TLI = Q.TLI;
+ const DominatorTree *DT = Q.DT;
+ const Instruction *CxtI = Q.CxtI;
+ const InstrInfoQuery &IIQ = Q.IIQ;
+
// First, skip past any trivial no-ops.
LHS = LHS->stripPointerCasts();
RHS = RHS->stripPointerCasts();
@@ -3395,6 +3384,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
Type *ITy = GetCompareTy(LHS); // The return type.
+ // icmp poison, X -> poison
+ if (isa<PoisonValue>(RHS))
+ return PoisonValue::get(ITy);
+
// For EQ and NE, we can always pick a value for the undef to make the
// predicate pass or fail, so we can return undef.
// Matches behavior in llvm::ConstantFoldCompareInstruction.
@@ -3409,6 +3402,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q))
return V;
+ // TODO: Sink/common this with other potentially expensive calls that use
+ // ValueTracking? See comment below for isKnownNonEqual().
if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q))
return V;
@@ -3428,13 +3423,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
auto LHS_CR = getConstantRangeFromMetadata(
*LHS_Instr->getMetadata(LLVMContext::MD_range));
- auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
- if (Satisfied_CR.contains(LHS_CR))
+ if (LHS_CR.icmp(Pred, RHS_CR))
return ConstantInt::getTrue(RHS->getContext());
- auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
- CmpInst::getInversePredicate(Pred), RHS_CR);
- if (InversedSatisfied_CR.contains(LHS_CR))
+ if (LHS_CR.icmp(CmpInst::getInversePredicate(Pred), RHS_CR))
return ConstantInt::getFalse(RHS->getContext());
}
}
@@ -3617,7 +3609,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
}
// icmp eq|ne X, Y -> false|true if X != Y
- if (ICmpInst::isEquality(Pred) &&
+ // This is potentially expensive, and we have already computedKnownBits for
+ // compares with 0 above here, so only try this for a non-zero compare.
+ if (ICmpInst::isEquality(Pred) && !match(RHS, m_Zero()) &&
isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) {
return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy);
}
@@ -3634,8 +3628,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
// Simplify comparisons of related pointers using a powerful, recursive
// GEP-walk when we have target data available..
if (LHS->getType()->isPointerTy())
- if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
- Q.IIQ, LHS, RHS))
+ if (auto *C = computePointerICmp(Pred, LHS, RHS, Q))
return C;
if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS))
if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS))
@@ -3643,9 +3636,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
Q.DL.getTypeSizeInBits(CLHS->getType()) &&
Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) ==
Q.DL.getTypeSizeInBits(CRHS->getType()))
- if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
- Q.IIQ, CLHS->getPointerOperand(),
- CRHS->getPointerOperand()))
+ if (auto *C = computePointerICmp(Pred, CLHS->getPointerOperand(),
+ CRHS->getPointerOperand(), Q))
return C;
if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) {
@@ -3728,6 +3720,11 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
if (match(RHS, m_NaN()))
return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+ // fcmp pred x, poison and fcmp pred poison, x
+ // fold to poison
+ if (isa<PoisonValue>(LHS) || isa<PoisonValue>(RHS))
+ return PoisonValue::get(RetTy);
+
// fcmp pred x, undef and fcmp pred undef, x
// fold to true if unordered, false if ordered
if (Q.isUndefValue(LHS) || Q.isUndefValue(RHS)) {
@@ -3896,10 +3893,12 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
}
-static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
unsigned MaxRecurse) {
+ assert(!Op->getType()->isVectorTy() && "This is not safe for vectors");
+
// Trivial replacement.
if (V == Op)
return RepOp;
@@ -3909,109 +3908,110 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
return nullptr;
auto *I = dyn_cast<Instruction>(V);
- if (!I)
- return nullptr;
-
- // Consider:
- // %cmp = icmp eq i32 %x, 2147483647
- // %add = add nsw i32 %x, 1
- // %sel = select i1 %cmp, i32 -2147483648, i32 %add
- //
- // We can't replace %sel with %add unless we strip away the flags (which will
- // be done in InstCombine).
- // TODO: This is unsound, because it only catches some forms of refinement.
- if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
+ if (!I || !is_contained(I->operands(), Op))
return nullptr;
- // The simplification queries below may return the original value. Consider:
- // %div = udiv i32 %arg, %arg2
- // %mul = mul nsw i32 %div, %arg2
- // %cmp = icmp eq i32 %mul, %arg
- // %sel = select i1 %cmp, i32 %div, i32 undef
- // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which
- // simplifies back to %arg. This can only happen because %mul does not
- // dominate %div. To ensure a consistent return value contract, we make sure
- // that this case returns nullptr as well.
- auto PreventSelfSimplify = [V](Value *Simplified) {
- return Simplified != V ? Simplified : nullptr;
- };
-
- // If this is a binary operator, try to simplify it with the replaced op.
- if (auto *B = dyn_cast<BinaryOperator>(I)) {
- if (MaxRecurse) {
- if (B->getOperand(0) == Op)
- return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), RepOp,
- B->getOperand(1), Q,
- MaxRecurse - 1));
- if (B->getOperand(1) == Op)
- return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(),
- B->getOperand(0), RepOp, Q,
- MaxRecurse - 1));
+ // Replace Op with RepOp in instruction operands.
+ SmallVector<Value *, 8> NewOps(I->getNumOperands());
+ transform(I->operands(), NewOps.begin(),
+ [&](Value *V) { return V == Op ? RepOp : V; });
+
+ if (!AllowRefinement) {
+ // General InstSimplify functions may refine the result, e.g. by returning
+ // a constant for a potentially poison value. To avoid this, implement only
+ // a few non-refining but profitable transforms here.
+
+ if (auto *BO = dyn_cast<BinaryOperator>(I)) {
+ unsigned Opcode = BO->getOpcode();
+ // id op x -> x, x op id -> x
+ if (NewOps[0] == ConstantExpr::getBinOpIdentity(Opcode, I->getType()))
+ return NewOps[1];
+ if (NewOps[1] == ConstantExpr::getBinOpIdentity(Opcode, I->getType(),
+ /* RHS */ true))
+ return NewOps[0];
+
+ // x & x -> x, x | x -> x
+ if ((Opcode == Instruction::And || Opcode == Instruction::Or) &&
+ NewOps[0] == NewOps[1])
+ return NewOps[0];
}
- }
- // Same for CmpInsts.
- if (CmpInst *C = dyn_cast<CmpInst>(I)) {
- if (MaxRecurse) {
- if (C->getOperand(0) == Op)
- return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), RepOp,
- C->getOperand(1), Q,
- MaxRecurse - 1));
- if (C->getOperand(1) == Op)
- return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(),
- C->getOperand(0), RepOp, Q,
- MaxRecurse - 1));
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+ // getelementptr x, 0 -> x
+ if (NewOps.size() == 2 && match(NewOps[1], m_Zero()) &&
+ !GEP->isInBounds())
+ return NewOps[0];
}
- }
+ } else if (MaxRecurse) {
+ // The simplification queries below may return the original value. Consider:
+ // %div = udiv i32 %arg, %arg2
+ // %mul = mul nsw i32 %div, %arg2
+ // %cmp = icmp eq i32 %mul, %arg
+ // %sel = select i1 %cmp, i32 %div, i32 undef
+ // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which
+ // simplifies back to %arg. This can only happen because %mul does not
+ // dominate %div. To ensure a consistent return value contract, we make sure
+ // that this case returns nullptr as well.
+ auto PreventSelfSimplify = [V](Value *Simplified) {
+ return Simplified != V ? Simplified : nullptr;
+ };
+
+ if (auto *B = dyn_cast<BinaryOperator>(I))
+ return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), NewOps[0],
+ NewOps[1], Q, MaxRecurse - 1));
- // Same for GEPs.
- if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
- if (MaxRecurse) {
- SmallVector<Value *, 8> NewOps(GEP->getNumOperands());
- transform(GEP->operands(), NewOps.begin(),
- [&](Value *V) { return V == Op ? RepOp : V; });
+ if (CmpInst *C = dyn_cast<CmpInst>(I))
+ return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), NewOps[0],
+ NewOps[1], Q, MaxRecurse - 1));
+
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
return PreventSelfSimplify(SimplifyGEPInst(GEP->getSourceElementType(),
NewOps, Q, MaxRecurse - 1));
- }
- }
- // TODO: We could hand off more cases to instsimplify here.
+ if (isa<SelectInst>(I))
+ return PreventSelfSimplify(
+ SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q,
+ MaxRecurse - 1));
+ // TODO: We could hand off more cases to instsimplify here.
+ }
// If all operands are constant after substituting Op for RepOp then we can
// constant fold the instruction.
- if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
- // Build a list of all constant operands.
- SmallVector<Constant *, 8> ConstOps;
- for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
- if (I->getOperand(i) == Op)
- ConstOps.push_back(CRepOp);
- else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
- ConstOps.push_back(COp);
- else
- break;
- }
+ SmallVector<Constant *, 8> ConstOps;
+ for (Value *NewOp : NewOps) {
+ if (Constant *ConstOp = dyn_cast<Constant>(NewOp))
+ ConstOps.push_back(ConstOp);
+ else
+ return nullptr;
+ }
- // All operands were constants, fold it.
- if (ConstOps.size() == I->getNumOperands()) {
- if (CmpInst *C = dyn_cast<CmpInst>(I))
- return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
- ConstOps[1], Q.DL, Q.TLI);
+ // Consider:
+ // %cmp = icmp eq i32 %x, 2147483647
+ // %add = add nsw i32 %x, 1
+ // %sel = select i1 %cmp, i32 -2147483648, i32 %add
+ //
+ // We can't replace %sel with %add unless we strip away the flags (which
+ // will be done in InstCombine).
+ // TODO: This may be unsound, because it only catches some forms of
+ // refinement.
+ if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
+ return nullptr;
- if (LoadInst *LI = dyn_cast<LoadInst>(I))
- if (!LI->isVolatile())
- return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL);
+ if (CmpInst *C = dyn_cast<CmpInst>(I))
+ return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+ ConstOps[1], Q.DL, Q.TLI);
- return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
- }
- }
+ if (LoadInst *LI = dyn_cast<LoadInst>(I))
+ if (!LI->isVolatile())
+ return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL);
- return nullptr;
+ return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
}
-Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement) {
- return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+ return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
RecursionLimit);
}
@@ -4127,21 +4127,23 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
TrueVal, FalseVal))
return V;
- // If we have an equality comparison, then we know the value in one of the
- // arms of the select. See if substituting this value into the arm and
+ // If we have a scalar equality comparison, then we know the value in one of
+ // the arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
- if (Pred == ICmpInst::ICMP_EQ) {
- if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+ // Note that the equivalence/replacement opportunity does not hold for vectors
+ // because each element of a vector select is chosen independently.
+ if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
+ if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ false, MaxRecurse) ==
TrueVal ||
- SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+ simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
/* AllowRefinement */ false, MaxRecurse) ==
TrueVal)
return FalseVal;
- if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+ if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ true, MaxRecurse) ==
FalseVal ||
- SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+ simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
/* AllowRefinement */ true, MaxRecurse) ==
FalseVal)
return FalseVal;
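
[Editor's aside, not part of the patch: the idea behind substituting the compared value into a select arm. When the (scalar) condition is an equality "X == Y", X may be replaced by Y inside an arm; if the arm then simplifies to the other arm's value, the whole select collapses. A small illustrative C++ check with hypothetical values.]

#include <cassert>

// select (X == Y), (X & Y), Y : substituting X := Y into the true arm gives
// Y & Y, which simplifies to Y, the false arm's value, so the whole select
// is equivalent to plain "Y".
static unsigned selectFold(unsigned X, unsigned Y) { return X == Y ? (X & Y) : Y; }

int main() {
  for (unsigned X = 0; X < 20; ++X)
    for (unsigned Y = 0; Y < 20; ++Y)
      assert(selectFold(X, Y) == Y);
  return 0;
}
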
@@ -4190,17 +4192,21 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
if (auto *FalseC = dyn_cast<Constant>(FalseVal))
return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);
+ // select poison, X, Y -> poison
+ if (isa<PoisonValue>(CondC))
+ return PoisonValue::get(TrueVal->getType());
+
// select undef, X, Y -> X or Y
if (Q.isUndefValue(CondC))
return isa<Constant>(FalseVal) ? FalseVal : TrueVal;
- // TODO: Vector constants with undef elements don't simplify.
-
- // select true, X, Y -> X
- if (CondC->isAllOnesValue())
+ // select true, X, Y --> X
+ // select false, X, Y --> Y
+ // For vectors, allow undef/poison elements in the condition to match the
+ // defined elements, so we can eliminate the select.
+ if (match(CondC, m_One()))
return TrueVal;
- // select false, X, Y -> Y
- if (CondC->isNullValue())
+ if (match(CondC, m_Zero()))
return FalseVal;
}
@@ -4217,15 +4223,20 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
if (TrueVal == FalseVal)
return TrueVal;
+ // If the true or false value is poison, we can fold to the other value.
// If the true or false value is undef, we can fold to the other value as
// long as the other value isn't poison.
- // select ?, undef, X -> X
- if (Q.isUndefValue(TrueVal) &&
- isGuaranteedNotToBeUndefOrPoison(FalseVal, Q.AC, Q.CxtI, Q.DT))
+ // select ?, poison, X -> X
+ // select ?, undef, X -> X
+ if (isa<PoisonValue>(TrueVal) ||
+ (Q.isUndefValue(TrueVal) &&
+ isGuaranteedNotToBePoison(FalseVal, Q.AC, Q.CxtI, Q.DT)))
return FalseVal;
- // select ?, X, undef -> X
- if (Q.isUndefValue(FalseVal) &&
- isGuaranteedNotToBeUndefOrPoison(TrueVal, Q.AC, Q.CxtI, Q.DT))
+ // select ?, X, poison -> X
+ // select ?, X, undef -> X
+ if (isa<PoisonValue>(FalseVal) ||
+ (Q.isUndefValue(FalseVal) &&
+ isGuaranteedNotToBePoison(TrueVal, Q.AC, Q.CxtI, Q.DT)))
return TrueVal;
// Deal with partial undef vector constants: select ?, VecC, VecC' --> VecC''
@@ -4247,11 +4258,11 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
// one element is undef, choose the defined element as the safe result.
if (TEltC == FEltC)
NewC.push_back(TEltC);
- else if (Q.isUndefValue(TEltC) &&
- isGuaranteedNotToBeUndefOrPoison(FEltC))
+ else if (isa<PoisonValue>(TEltC) ||
+ (Q.isUndefValue(TEltC) && isGuaranteedNotToBePoison(FEltC)))
NewC.push_back(FEltC);
- else if (Q.isUndefValue(FEltC) &&
- isGuaranteedNotToBeUndefOrPoison(TEltC))
+ else if (isa<PoisonValue>(FEltC) ||
+ (Q.isUndefValue(FEltC) && isGuaranteedNotToBePoison(TEltC)))
NewC.push_back(TEltC);
else
break;
@@ -4297,10 +4308,14 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
// Compute the (pointer) type returned by the GEP instruction.
Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1));
Type *GEPTy = PointerType::get(LastType, AS);
- if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType()))
- GEPTy = VectorType::get(GEPTy, VT->getElementCount());
- else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType()))
- GEPTy = VectorType::get(GEPTy, VT->getElementCount());
+ for (Value *Op : Ops) {
+ // If one of the operands is a vector, the result type is a vector of
+ // pointers. All vector operands must have the same number of elements.
+ if (VectorType *VT = dyn_cast<VectorType>(Op->getType())) {
+ GEPTy = VectorType::get(GEPTy, VT->getElementCount());
+ break;
+ }
+ }
// getelementptr poison, idx -> poison
// getelementptr baseptr, poison -> poison
@@ -4310,7 +4325,10 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
if (Q.isUndefValue(Ops[0]))
return UndefValue::get(GEPTy);
- bool IsScalableVec = isa<ScalableVectorType>(SrcTy);
+ bool IsScalableVec =
+ isa<ScalableVectorType>(SrcTy) || any_of(Ops, [](const Value *V) {
+ return isa<ScalableVectorType>(V->getType());
+ });
if (Ops.size() == 2) {
// getelementptr P, 0 -> P.
@@ -4330,40 +4348,32 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
// doesn't truncate the pointers.
if (Ops[1]->getType()->getScalarSizeInBits() ==
Q.DL.getPointerSizeInBits(AS)) {
- auto PtrToInt = [GEPTy](Value *P) -> Value * {
- Value *Temp;
- if (match(P, m_PtrToInt(m_Value(Temp))))
- if (Temp->getType() == GEPTy)
- return Temp;
- return nullptr;
+ auto CanSimplify = [GEPTy, &P, V = Ops[0]]() -> bool {
+ return P->getType() == GEPTy &&
+ getUnderlyingObject(P) == getUnderlyingObject(V);
};
-
- // FIXME: The following transforms are only legal if P and V have the
- // same provenance (PR44403). Check whether getUnderlyingObject() is
- // the same?
-
// getelementptr V, (sub P, V) -> P if P points to a type of size 1.
if (TyAllocSize == 1 &&
- match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0])))))
- if (Value *R = PtrToInt(P))
- return R;
-
- // getelementptr V, (ashr (sub P, V), C) -> Q
- // if P points to a type of size 1 << C.
- if (match(Ops[1],
- m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),
- m_ConstantInt(C))) &&
- TyAllocSize == 1ULL << C)
- if (Value *R = PtrToInt(P))
- return R;
-
- // getelementptr V, (sdiv (sub P, V), C) -> Q
- // if P points to a type of size C.
- if (match(Ops[1],
- m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),
- m_SpecificInt(TyAllocSize))))
- if (Value *R = PtrToInt(P))
- return R;
+ match(Ops[1], m_Sub(m_PtrToInt(m_Value(P)),
+ m_PtrToInt(m_Specific(Ops[0])))) &&
+ CanSimplify())
+ return P;
+
+ // getelementptr V, (ashr (sub P, V), C) -> P if P points to a type of
+ // size 1 << C.
+ if (match(Ops[1], m_AShr(m_Sub(m_PtrToInt(m_Value(P)),
+ m_PtrToInt(m_Specific(Ops[0]))),
+ m_ConstantInt(C))) &&
+ TyAllocSize == 1ULL << C && CanSimplify())
+ return P;
+
+ // getelementptr V, (sdiv (sub P, V), C) -> P if P points to a type of
+ // size C.
+ if (match(Ops[1], m_SDiv(m_Sub(m_PtrToInt(m_Value(P)),
+ m_PtrToInt(m_Specific(Ops[0]))),
+ m_SpecificInt(TyAllocSize))) &&
+ CanSimplify())
+ return P;
}
}
}
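
[Editor's aside, not part of the patch: the pointer identity behind the "getelementptr V, (sub P, V) -> P" family of folds, which the hunk now guards with CanSimplify so that P and V share the same underlying object. With a one-byte element type, adding the byte distance P - V back to V lands exactly on P. Illustrative values only.]

#include <cassert>
#include <cstdint>

int main() {
  char Buf[16];
  char *V = Buf;     // base pointer, element size 1
  char *P = Buf + 9; // another pointer into the same object
  uintptr_t Diff =
      reinterpret_cast<uintptr_t>(P) - reinterpret_cast<uintptr_t>(V);
  assert(V + Diff == P); // gep V, (sub (ptrtoint P), (ptrtoint V)) -> P
  return 0;
}
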
@@ -4523,30 +4533,33 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx,
if (auto *CIdx = dyn_cast<Constant>(Idx))
return ConstantExpr::getExtractElement(CVec, CIdx);
- // The index is not relevant if our vector is a splat.
- if (auto *Splat = CVec->getSplatValue())
- return Splat;
-
if (Q.isUndefValue(Vec))
return UndefValue::get(VecVTy->getElementType());
}
+ // An undef extract index can be arbitrarily chosen to be an out-of-range
+ // index value, which would result in the instruction being poison.
+ if (Q.isUndefValue(Idx))
+ return PoisonValue::get(VecVTy->getElementType());
+
// If extracting a specified index from the vector, see if we can recursively
// find a previously computed scalar that was inserted into the vector.
if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
// For fixed-length vector, fold into undef if index is out of bounds.
- if (isa<FixedVectorType>(VecVTy) &&
- IdxC->getValue().uge(cast<FixedVectorType>(VecVTy)->getNumElements()))
+ unsigned MinNumElts = VecVTy->getElementCount().getKnownMinValue();
+ if (isa<FixedVectorType>(VecVTy) && IdxC->getValue().uge(MinNumElts))
return PoisonValue::get(VecVTy->getElementType());
+ // Handle case where an element is extracted from a splat.
+ if (IdxC->getValue().ult(MinNumElts))
+ if (auto *Splat = getSplatValue(Vec))
+ return Splat;
if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
return Elt;
+ } else {
+ // The index is not relevant if our vector is a splat.
+ if (Value *Splat = getSplatValue(Vec))
+ return Splat;
}
-
- // An undef extract index can be arbitrarily chosen to be an out-of-range
- // index value, which would result in the instruction being poison.
- if (Q.isUndefValue(Idx))
- return PoisonValue::get(VecVTy->getElementType());
-
return nullptr;
}
@@ -4556,7 +4569,8 @@ Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx,
}
/// See if we can fold the given phi. If not, returns null.
-static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
+static Value *SimplifyPHINode(PHINode *PN, ArrayRef<Value *> IncomingValues,
+ const SimplifyQuery &Q) {
// WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE
// here, because the PHI we may succeed simplifying to was not
// def-reachable from the original PHI!
@@ -4565,7 +4579,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
// with the common value.
Value *CommonValue = nullptr;
bool HasUndefInput = false;
- for (Value *Incoming : PN->incoming_values()) {
+ for (Value *Incoming : IncomingValues) {
// If the incoming value is the phi node itself, it can safely be skipped.
if (Incoming == PN) continue;
if (Q.isUndefValue(Incoming)) {
@@ -4842,11 +4856,17 @@ static Constant *propagateNaN(Constant *In) {
}
/// Perform folds that are common to any floating-point operation. This implies
-/// transforms based on undef/NaN because the operation itself makes no
+/// transforms based on poison/undef/NaN because the operation itself makes no
/// difference to the result.
-static Constant *simplifyFPOp(ArrayRef<Value *> Ops,
- FastMathFlags FMF,
- const SimplifyQuery &Q) {
+static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ // Poison is independent of anything else. It always propagates from an
+ // operand to a math result.
+ if (any_of(Ops, [](Value *V) { return match(V, m_Poison()); }))
+ return PoisonValue::get(Ops[0]->getType());
+
for (Value *V : Ops) {
bool IsNan = match(V, m_NaN());
bool IsInf = match(V, m_Inf());
@@ -4860,22 +4880,34 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops,
if (FMF.noInfs() && (IsInf || IsUndef))
return PoisonValue::get(V->getType());
- if (IsUndef || IsNan)
- return propagateNaN(cast<Constant>(V));
+ if (isDefaultFPEnvironment(ExBehavior, Rounding)) {
+ if (IsUndef || IsNan)
+ return propagateNaN(cast<Constant>(V));
+ } else if (ExBehavior != fp::ebStrict) {
+ if (IsNan)
+ return propagateNaN(cast<Constant>(V));
+ }
}
return nullptr;
}
/// Given operands for an FAdd, see if we can fold the result. If not, this
/// returns null.
-static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
- return C;
+static Value *
+SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse,
+ fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
+ RoundingMode Rounding = RoundingMode::NearestTiesToEven) {
+ if (isDefaultFPEnvironment(ExBehavior, Rounding))
+ if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
+ return C;
- if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))
+ if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding))
return C;
+ if (!isDefaultFPEnvironment(ExBehavior, Rounding))
+ return nullptr;
+
// fadd X, -0 ==> X
if (match(Op1, m_NegZeroFP()))
return Op0;
@@ -4915,14 +4947,21 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
/// Given operands for an FSub, see if we can fold the result. If not, this
/// returns null.
-static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
- return C;
+static Value *
+SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse,
+ fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
+ RoundingMode Rounding = RoundingMode::NearestTiesToEven) {
+ if (isDefaultFPEnvironment(ExBehavior, Rounding))
+ if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
+ return C;
- if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))
+ if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding))
return C;
+ if (!isDefaultFPEnvironment(ExBehavior, Rounding))
+ return nullptr;
+
// fsub X, +0 ==> X
if (match(Op1, m_PosZeroFP()))
return Op0;
@@ -4961,10 +5000,15 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
}
static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))
+ const SimplifyQuery &Q, unsigned MaxRecurse,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding))
return C;
+ if (!isDefaultFPEnvironment(ExBehavior, Rounding))
+ return nullptr;
+
// fmul X, 1.0 ==> X
if (match(Op1, m_FPOne()))
return Op0;
@@ -4994,43 +5038,65 @@ static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
}
/// Given the operands for an FMul, see if we can fold the result
-static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
- return C;
+static Value *
+SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse,
+ fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
+ RoundingMode Rounding = RoundingMode::NearestTiesToEven) {
+ if (isDefaultFPEnvironment(ExBehavior, Rounding))
+ if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+ return C;
// Now apply simplifications that do not require rounding.
- return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
+ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse, ExBehavior, Rounding);
}
Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
}
-
Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit);
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
}
Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
}
Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit);
-}
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
+}
+
+static Value *
+SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned,
+ fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
+ RoundingMode Rounding = RoundingMode::NearestTiesToEven) {
+ if (isDefaultFPEnvironment(ExBehavior, Rounding))
+ if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
+ return C;
-static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
+ if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding))
return C;
- if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))
- return C;
+ if (!isDefaultFPEnvironment(ExBehavior, Rounding))
+ return nullptr;
// X / 1.0 -> X
if (match(Op1, m_FPOne()))
@@ -5065,17 +5131,27 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
}
Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit);
-}
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
+}
+
+static Value *
+SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned,
+ fp::ExceptionBehavior ExBehavior = fp::ebIgnore,
+ RoundingMode Rounding = RoundingMode::NearestTiesToEven) {
+ if (isDefaultFPEnvironment(ExBehavior, Rounding))
+ if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
+ return C;
-static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
+ if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q, ExBehavior, Rounding))
return C;
- if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))
- return C;
+ if (!isDefaultFPEnvironment(ExBehavior, Rounding))
+ return nullptr;
// Unlike fdiv, the result of frem always matches the sign of the dividend.
// The constant match may include undef elements in a vector, so return a full
@@ -5093,8 +5169,11 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
}
Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q) {
- return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit);
+ const SimplifyQuery &Q,
+ fp::ExceptionBehavior ExBehavior,
+ RoundingMode Rounding) {
+ return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior,
+ Rounding);
}
//=== Helper functions for higher up the class hierarchy.
@@ -5373,6 +5452,12 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
return Op0;
break;
}
+ case Intrinsic::experimental_vector_reverse:
+ // experimental.vector.reverse(experimental.vector.reverse(x)) -> x
+ if (match(Op0,
+ m_Intrinsic<Intrinsic::experimental_vector_reverse>(m_Value(X))))
+ return X;
+ break;
default:
break;
}
@@ -5380,16 +5465,6 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
return nullptr;
}
-static Intrinsic::ID getMaxMinOpposite(Intrinsic::ID IID) {
- switch (IID) {
- case Intrinsic::smax: return Intrinsic::smin;
- case Intrinsic::smin: return Intrinsic::smax;
- case Intrinsic::umax: return Intrinsic::umin;
- case Intrinsic::umin: return Intrinsic::umax;
- default: llvm_unreachable("Unexpected intrinsic");
- }
-}
-
static APInt getMaxMinLimit(Intrinsic::ID IID, unsigned BitWidth) {
switch (IID) {
case Intrinsic::smax: return APInt::getSignedMaxValue(BitWidth);
@@ -5429,7 +5504,7 @@ static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) {
if (IID0 == IID)
return MM0;
// max (min X, Y), X --> X
- if (IID0 == getMaxMinOpposite(IID))
+ if (IID0 == getInverseMinMaxIntrinsic(IID))
return Op1;
}
return nullptr;
@@ -5449,6 +5524,20 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
return Op0;
break;
+ case Intrinsic::cttz: {
+ Value *X;
+ if (match(Op0, m_Shl(m_One(), m_Value(X))))
+ return X;
+ break;
+ }
+ case Intrinsic::ctlz: {
+ Value *X;
+ if (match(Op0, m_LShr(m_Negative(), m_Value(X))))
+ return X;
+ if (match(Op0, m_AShr(m_Negative(), m_Value())))
+ return Constant::getNullValue(ReturnType);
+ break;
+ }
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::umax:
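
[Editor's aside, not part of the patch: the new "cttz(1 << X) -> X" fold rests on the fact that a single set bit at position X has exactly X trailing zeros (and the shl is poison anyway if X is out of range). A minimal check using the GCC/Clang builtin.]

#include <cassert>

int main() {
  for (unsigned X = 0; X < 32; ++X)
    assert(static_cast<unsigned>(__builtin_ctz(1u << X)) == X);
  return 0;
}
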
@@ -5475,7 +5564,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
// If the constant op is the opposite of the limit value, the other must
// be larger/smaller or equal. For example:
// umin(i8 %x, i8 255) --> %x
- if (*C == getMaxMinLimit(getMaxMinOpposite(IID), BitWidth))
+ if (*C == getMaxMinLimit(getInverseMinMaxIntrinsic(IID), BitWidth))
return Op0;
// Remove nested call if constant operands allow it. Example:
@@ -5661,6 +5750,19 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
break;
}
+ case Intrinsic::experimental_vector_extract: {
+ Type *ReturnType = F->getReturnType();
+
+ // (extract_vector (insert_vector _, X, 0), 0) -> X
+ unsigned IdxN = cast<ConstantInt>(Op1)->getZExtValue();
+ Value *X = nullptr;
+ if (match(Op0, m_Intrinsic<Intrinsic::experimental_vector_insert>(
+ m_Value(), m_Value(X), m_Zero())) &&
+ IdxN == 0 && X->getType() == ReturnType)
+ return X;
+
+ break;
+ }
default:
break;
}
@@ -5717,15 +5819,115 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
}
return nullptr;
}
+ case Intrinsic::experimental_constrained_fma: {
+ Value *Op0 = Call->getArgOperand(0);
+ Value *Op1 = Call->getArgOperand(1);
+ Value *Op2 = Call->getArgOperand(2);
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue()))
+ return V;
+ return nullptr;
+ }
case Intrinsic::fma:
case Intrinsic::fmuladd: {
Value *Op0 = Call->getArgOperand(0);
Value *Op1 = Call->getArgOperand(1);
Value *Op2 = Call->getArgOperand(2);
- if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q))
+ if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q, fp::ebIgnore,
+ RoundingMode::NearestTiesToEven))
return V;
return nullptr;
}
+ case Intrinsic::smul_fix:
+ case Intrinsic::smul_fix_sat: {
+ Value *Op0 = Call->getArgOperand(0);
+ Value *Op1 = Call->getArgOperand(1);
+ Value *Op2 = Call->getArgOperand(2);
+ Type *ReturnType = F->getReturnType();
+
+ // Canonicalize constant operand as Op1 (ConstantFolding handles the case
+ // when both Op0 and Op1 are constant so we do not care about that special
+ // case here).
+ if (isa<Constant>(Op0))
+ std::swap(Op0, Op1);
+
+ // X * 0 -> 0
+ if (match(Op1, m_Zero()))
+ return Constant::getNullValue(ReturnType);
+
+ // X * undef -> 0
+ if (Q.isUndefValue(Op1))
+ return Constant::getNullValue(ReturnType);
+
+ // X * (1 << Scale) -> X
+ APInt ScaledOne =
+ APInt::getOneBitSet(ReturnType->getScalarSizeInBits(),
+ cast<ConstantInt>(Op2)->getZExtValue());
+ if (ScaledOne.isNonNegative() && match(Op1, m_SpecificInt(ScaledOne)))
+ return Op0;
+
+ return nullptr;
+ }
+ case Intrinsic::experimental_vector_insert: {
+ Value *Vec = Call->getArgOperand(0);
+ Value *SubVec = Call->getArgOperand(1);
+ Value *Idx = Call->getArgOperand(2);
+ Type *ReturnType = F->getReturnType();
+
+ // (insert_vector Y, (extract_vector X, 0), 0) -> X
+ // where: Y is X, or Y is undef
+ unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue();
+ Value *X = nullptr;
+ if (match(SubVec, m_Intrinsic<Intrinsic::experimental_vector_extract>(
+ m_Value(X), m_Zero())) &&
+ (Q.isUndefValue(Vec) || Vec == X) && IdxN == 0 &&
+ X->getType() == ReturnType)
+ return X;
+
+ return nullptr;
+ }
+ case Intrinsic::experimental_constrained_fadd: {
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ return SimplifyFAddInst(FPI->getArgOperand(0), FPI->getArgOperand(1),
+ FPI->getFastMathFlags(), Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue());
+ break;
+ }
+ case Intrinsic::experimental_constrained_fsub: {
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ return SimplifyFSubInst(FPI->getArgOperand(0), FPI->getArgOperand(1),
+ FPI->getFastMathFlags(), Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue());
+ break;
+ }
+ case Intrinsic::experimental_constrained_fmul: {
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ return SimplifyFMulInst(FPI->getArgOperand(0), FPI->getArgOperand(1),
+ FPI->getFastMathFlags(), Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue());
+ break;
+ }
+ case Intrinsic::experimental_constrained_fdiv: {
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ return SimplifyFDivInst(FPI->getArgOperand(0), FPI->getArgOperand(1),
+ FPI->getFastMathFlags(), Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue());
+ break;
+ }
+ case Intrinsic::experimental_constrained_frem: {
+ auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
+ return SimplifyFRemInst(FPI->getArgOperand(0), FPI->getArgOperand(1),
+ FPI->getFastMathFlags(), Q,
+ FPI->getExceptionBehavior().getValue(),
+ FPI->getRoundingMode().getValue());
+ break;
+ }
default:
return nullptr;
}
@@ -5788,162 +5990,223 @@ Value *llvm::SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) {
return ::SimplifyFreezeInst(Op0, Q);
}
+static Constant *ConstructLoadOperandConstant(Value *Op) {
+ SmallVector<Value *, 4> Worklist;
+ // Invalid IR in unreachable code may contain self-referential values. Don't infinitely loop.
+ SmallPtrSet<Value *, 4> Visited;
+ Worklist.push_back(Op);
+ while (true) {
+ Value *CurOp = Worklist.back();
+ if (!Visited.insert(CurOp).second)
+ return nullptr;
+ if (isa<Constant>(CurOp))
+ break;
+ if (auto *BC = dyn_cast<BitCastOperator>(CurOp)) {
+ Worklist.push_back(BC->getOperand(0));
+ } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) {
+ for (unsigned I = 1; I != GEP->getNumOperands(); ++I) {
+ if (!isa<Constant>(GEP->getOperand(I)))
+ return nullptr;
+ }
+ Worklist.push_back(GEP->getOperand(0));
+ } else if (auto *II = dyn_cast<IntrinsicInst>(CurOp)) {
+ if (II->isLaunderOrStripInvariantGroup())
+ Worklist.push_back(II->getOperand(0));
+ else
+ return nullptr;
+ } else {
+ return nullptr;
+ }
+ }
+
+ Constant *NewOp = cast<Constant>(Worklist.pop_back_val());
+ while (!Worklist.empty()) {
+ Value *CurOp = Worklist.pop_back_val();
+ if (isa<BitCastOperator>(CurOp)) {
+ NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType());
+ } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) {
+ SmallVector<Constant *> Idxs;
+ Idxs.reserve(GEP->getNumOperands() - 1);
+ for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) {
+ Idxs.push_back(cast<Constant>(GEP->getOperand(I)));
+ }
+ NewOp = ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), NewOp,
+ Idxs, GEP->isInBounds(),
+ GEP->getInRangeIndex());
+ } else {
+ assert(isa<IntrinsicInst>(CurOp) &&
+ cast<IntrinsicInst>(CurOp)->isLaunderOrStripInvariantGroup() &&
+ "expected invariant group intrinsic");
+ NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType());
+ }
+ }
+ return NewOp;
+}
+
+static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp,
+ const SimplifyQuery &Q) {
+ if (LI->isVolatile())
+ return nullptr;
+
+ // Try to make the load operand a constant, specifically handle
+ // invariant.group intrinsics.
+ auto *PtrOpC = dyn_cast<Constant>(PtrOp);
+ if (!PtrOpC)
+ PtrOpC = ConstructLoadOperandConstant(PtrOp);
+
+ if (PtrOpC)
+ return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Q.DL);
+
+ return nullptr;
+}
+
/// See if we can compute a simplified version of this instruction.
/// If not, this returns null.
-Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
- OptimizationRemarkEmitter *ORE) {
+static Value *simplifyInstructionWithOperands(Instruction *I,
+ ArrayRef<Value *> NewOps,
+ const SimplifyQuery &SQ,
+ OptimizationRemarkEmitter *ORE) {
const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I);
- Value *Result;
+ Value *Result = nullptr;
switch (I->getOpcode()) {
default:
- Result = ConstantFoldInstruction(I, Q.DL, Q.TLI);
+ if (llvm::all_of(NewOps, [](Value *V) { return isa<Constant>(V); })) {
+ SmallVector<Constant *, 8> NewConstOps(NewOps.size());
+ transform(NewOps, NewConstOps.begin(),
+ [](Value *V) { return cast<Constant>(V); });
+ Result = ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI);
+ }
break;
case Instruction::FNeg:
- Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q);
+ Result = SimplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q);
break;
case Instruction::FAdd:
- Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1),
- I->getFastMathFlags(), Q);
+ Result = SimplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::Add:
- Result =
- SimplifyAddInst(I->getOperand(0), I->getOperand(1),
- Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ Result = SimplifyAddInst(
+ NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
break;
case Instruction::FSub:
- Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1),
- I->getFastMathFlags(), Q);
+ Result = SimplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::Sub:
- Result =
- SimplifySubInst(I->getOperand(0), I->getOperand(1),
- Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ Result = SimplifySubInst(
+ NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
break;
case Instruction::FMul:
- Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1),
- I->getFastMathFlags(), Q);
+ Result = SimplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::Mul:
- Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyMulInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::SDiv:
- Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifySDivInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::UDiv:
- Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyUDivInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::FDiv:
- Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1),
- I->getFastMathFlags(), Q);
+ Result = SimplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::SRem:
- Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifySRemInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::URem:
- Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyURemInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::FRem:
- Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1),
- I->getFastMathFlags(), Q);
+ Result = SimplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::Shl:
- Result =
- SimplifyShlInst(I->getOperand(0), I->getOperand(1),
- Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
+ Result = SimplifyShlInst(
+ NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
break;
case Instruction::LShr:
- Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1),
+ Result = SimplifyLShrInst(NewOps[0], NewOps[1],
Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
break;
case Instruction::AShr:
- Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1),
+ Result = SimplifyAShrInst(NewOps[0], NewOps[1],
Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
break;
case Instruction::And:
- Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyAndInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::Or:
- Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyOrInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::Xor:
- Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyXorInst(NewOps[0], NewOps[1], Q);
break;
case Instruction::ICmp:
- Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(),
- I->getOperand(0), I->getOperand(1), Q);
+ Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1], Q);
break;
case Instruction::FCmp:
- Result =
- SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0),
- I->getOperand(1), I->getFastMathFlags(), Q);
+ Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1], I->getFastMathFlags(), Q);
break;
case Instruction::Select:
- Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1),
- I->getOperand(2), Q);
+ Result = SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q);
break;
case Instruction::GetElementPtr: {
- SmallVector<Value *, 8> Ops(I->operands());
Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(),
- Ops, Q);
+ NewOps, Q);
break;
}
case Instruction::InsertValue: {
InsertValueInst *IV = cast<InsertValueInst>(I);
- Result = SimplifyInsertValueInst(IV->getAggregateOperand(),
- IV->getInsertedValueOperand(),
- IV->getIndices(), Q);
+ Result = SimplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q);
break;
}
case Instruction::InsertElement: {
- auto *IE = cast<InsertElementInst>(I);
- Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1),
- IE->getOperand(2), Q);
+ Result = SimplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q);
break;
}
case Instruction::ExtractValue: {
auto *EVI = cast<ExtractValueInst>(I);
- Result = SimplifyExtractValueInst(EVI->getAggregateOperand(),
- EVI->getIndices(), Q);
+ Result = SimplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q);
break;
}
case Instruction::ExtractElement: {
- auto *EEI = cast<ExtractElementInst>(I);
- Result = SimplifyExtractElementInst(EEI->getVectorOperand(),
- EEI->getIndexOperand(), Q);
+ Result = SimplifyExtractElementInst(NewOps[0], NewOps[1], Q);
break;
}
case Instruction::ShuffleVector: {
auto *SVI = cast<ShuffleVectorInst>(I);
- Result =
- SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1),
- SVI->getShuffleMask(), SVI->getType(), Q);
+ Result = SimplifyShuffleVectorInst(
+ NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q);
break;
}
case Instruction::PHI:
- Result = SimplifyPHINode(cast<PHINode>(I), Q);
+ Result = SimplifyPHINode(cast<PHINode>(I), NewOps, Q);
break;
case Instruction::Call: {
+ // TODO: Use NewOps
Result = SimplifyCall(cast<CallInst>(I), Q);
break;
}
case Instruction::Freeze:
- Result = SimplifyFreezeInst(I->getOperand(0), Q);
+ Result = llvm::SimplifyFreezeInst(NewOps[0], Q);
break;
#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:
#include "llvm/IR/Instruction.def"
#undef HANDLE_CAST_INST
- Result =
- SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q);
+ Result = SimplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q);
break;
case Instruction::Alloca:
// No simplifications for Alloca and it can't be constant folded.
Result = nullptr;
break;
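+  // Loads get a dedicated path (see SimplifyLoadInst above) so that a pointer
+  // operand which only becomes constant after looking through invariant.group
+  // intrinsics can still be folded.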
+ case Instruction::Load:
+ Result = SimplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q);
+ break;
}
/// If called on unreachable code, the above logic may report that the
@@ -5952,6 +6215,21 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
return Result == I ? UndefValue::get(I->getType()) : Result;
}
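+/// Public entry point: simplify \p I as if its operands were the values in
+/// \p NewOps; \p I itself is left unmodified.
+///
+/// A minimal usage sketch (illustrative, not part of this patch): ask what an
+/// instruction would fold to if its first operand were replaced by V:
+///   SmallVector<Value *, 8> Ops(I->operands());
+///   Ops[0] = V;
+///   if (Value *S = SimplifyInstructionWithOperands(I, Ops, SQ))
+///     ... use S instead of I ...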
+Value *llvm::SimplifyInstructionWithOperands(Instruction *I,
+ ArrayRef<Value *> NewOps,
+ const SimplifyQuery &SQ,
+ OptimizationRemarkEmitter *ORE) {
+ assert(NewOps.size() == I->getNumOperands() &&
+ "Number of operands should match the instruction!");
+ return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE);
+}
+
+Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
+ OptimizationRemarkEmitter *ORE) {
+ SmallVector<Value *, 8> Ops(I->operands());
+ return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE);
+}
+
/// Implementation of recursive simplification through an instruction's
/// uses.
///