aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp327
1 files changed, 199 insertions, 128 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2654c00929d8..edb0756e8c3b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1868,8 +1868,7 @@ SDValue DAGCombiner::combine(SDNode *N) {
// If N is a commutative binary node, try to eliminate it if the commuted
// version is already present in the DAG.
- if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) &&
- N->getNumValues() == 1) {
+ if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
@@ -4159,6 +4158,10 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags()))
return RMUL;
+ // Simplify the operands using demanded-bits information.
+ if (SimplifyDemandedBits(SDValue(N, 0)))
+ return SDValue(N, 0);
+
return SDValue();
}
@@ -5978,44 +5981,64 @@ static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) {
if (!TLI.isTypeLegal(VT))
return SDValue();
- // Look through an optional extension and find a 'not'.
- // TODO: Should we favor test+set even without the 'not' op?
- SDValue Not = And->getOperand(0), And1 = And->getOperand(1);
- if (Not.getOpcode() == ISD::ANY_EXTEND)
- Not = Not.getOperand(0);
- if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1))
+ // Look through an optional extension.
+ SDValue And0 = And->getOperand(0), And1 = And->getOperand(1);
+ if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse())
+ And0 = And0.getOperand(0);
+ if (!isOneConstant(And1) || !And0.hasOneUse())
return SDValue();
- // Look though an optional truncation. The source operand may not be the same
- // type as the original 'and', but that is ok because we are masking off
- // everything but the low bit.
- SDValue Srl = Not.getOperand(0);
- if (Srl.getOpcode() == ISD::TRUNCATE)
- Srl = Srl.getOperand(0);
+ SDValue Src = And0;
+
+ // Attempt to find a 'not' op.
+ // TODO: Should we favor test+set even without the 'not' op?
+ bool FoundNot = false;
+ if (isBitwiseNot(Src)) {
+ FoundNot = true;
+ Src = Src.getOperand(0);
+
+ // Look through an optional truncation. The source operand may not be the
+ // same type as the original 'and', but that is ok because we are masking
+ // off everything but the low bit.
+ if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse())
+ Src = Src.getOperand(0);
+ }
// Match a shift-right by constant.
- if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() ||
- !isa<ConstantSDNode>(Srl.getOperand(1)))
+ if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse())
return SDValue();
// We might have looked through casts that make this transform invalid.
// TODO: If the source type is wider than the result type, do the mask and
// compare in the source type.
- const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1);
- unsigned VTBitWidth = VT.getSizeInBits();
- if (ShiftAmt.uge(VTBitWidth))
+ unsigned VTBitWidth = VT.getScalarSizeInBits();
+ SDValue ShiftAmt = Src.getOperand(1);
+ auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt);
+ if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth))
return SDValue();
- if (!TLI.hasBitTest(Srl.getOperand(0), Srl.getOperand(1)))
+ // Set source to shift source.
+ Src = Src.getOperand(0);
+
+ // Try again to find a 'not' op.
+ // TODO: Should we favor test+set even with two 'not' ops?
+ if (!FoundNot) {
+ if (!isBitwiseNot(Src))
+ return SDValue();
+ Src = Src.getOperand(0);
+ }
+
+ if (!TLI.hasBitTest(Src, ShiftAmt))
return SDValue();
// Turn this into a bit-test pattern using mask op + setcc:
// and (not (srl X, C)), 1 --> (and X, 1<<C) == 0
+ // and (srl (not X), C), 1 --> (and X, 1<<C) == 0
SDLoc DL(And);
- SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT);
+ SDValue X = DAG.getZExtOrTrunc(Src, DL, VT);
EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
SDValue Mask = DAG.getConstant(
- APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT);
+ APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT);
SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask);
SDValue Zero = DAG.getConstant(0, DL, VT);
SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ);
@@ -6229,7 +6252,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// This can be a pure constant or a vector splat, in which case we treat the
// vector as a scalar and use the splat value.
APInt Constant = APInt::getZero(1);
- if (const ConstantSDNode *C = isConstOrConstSplat(N1)) {
+ if (const ConstantSDNode *C = isConstOrConstSplat(
+ N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) {
Constant = C->getAPIntValue();
} else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) {
APInt SplatValue, SplatUndef;
@@ -6339,18 +6363,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
// fold (and (load x), 255) -> (zextload x, i8)
// fold (and (extload x, i16), 255) -> (zextload x, i8)
- // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
- if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
- (N0.getOpcode() == ISD::ANY_EXTEND &&
- N0.getOperand(0).getOpcode() == ISD::LOAD))) {
- if (SDValue Res = reduceLoadWidth(N)) {
- LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
- ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0);
- AddToWorklist(N);
- DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res);
- return SDValue(N, 0);
- }
- }
+ if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector())
+ if (SDValue Res = reduceLoadWidth(N))
+ return Res;
if (LegalTypes) {
// Attempt to propagate the AND back up to the leaves which, if they're
@@ -6856,20 +6871,23 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) {
}
/// OR combines for which the commuted variant will be tried as well.
-static SDValue visitORCommutative(
- SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) {
+static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1,
+ SDNode *N) {
EVT VT = N0.getValueType();
if (N0.getOpcode() == ISD::AND) {
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+
// fold (or (and X, (xor Y, -1)), Y) -> (or X, Y)
// TODO: Set AllowUndefs = true.
- if (getBitwiseNotOperand(N0.getOperand(1), N0.getOperand(0),
+ if (getBitwiseNotOperand(N01, N00,
/* AllowUndefs */ false) == N1)
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1);
+ return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1);
// fold (or (and (xor Y, -1), X), Y) -> (or X, Y)
- if (getBitwiseNotOperand(N0.getOperand(0), N0.getOperand(1),
+ if (getBitwiseNotOperand(N00, N01,
/* AllowUndefs */ false) == N1)
- return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1);
+ return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1);
}
if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG))
@@ -7915,7 +7933,7 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) {
int64_t FirstOffset = INT64_MAX;
StoreSDNode *FirstStore = nullptr;
Optional<BaseIndexOffset> Base;
- for (auto Store : Stores) {
+ for (auto *Store : Stores) {
// All the stores store different parts of the CombinedValue. A truncate is
// required to get the partial value.
SDValue Trunc = Store->getValue();
@@ -8488,28 +8506,6 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
return DAG.getNode(ISD::AND, DL, VT, NotX, N1);
}
- if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) {
- ConstantSDNode *XorC = isConstOrConstSplat(N1);
- ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1));
- unsigned BitWidth = VT.getScalarSizeInBits();
- if (XorC && ShiftC) {
- // Don't crash on an oversized shift. We can not guarantee that a bogus
- // shift has been simplified to undef.
- uint64_t ShiftAmt = ShiftC->getLimitedValue();
- if (ShiftAmt < BitWidth) {
- APInt Ones = APInt::getAllOnes(BitWidth);
- Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt);
- if (XorC->getAPIntValue() == Ones) {
- // If the xor constant is a shifted -1, do a 'not' before the shift:
- // xor (X << ShiftC), XorC --> (not X) << ShiftC
- // xor (X >> ShiftC), XorC --> (not X) >> ShiftC
- SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT);
- return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1));
- }
- }
- }
- }
-
// fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
@@ -11817,6 +11813,9 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) {
EVT N00VT = N00.getValueType();
SDLoc DL(N);
+ // Propagate fast-math-flags.
+ SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
+
// On some architectures (such as SSE/NEON/etc) the SETCC result type is
// the same size as the compared operands. Try to optimize sext(setcc())
// if this is the case.
@@ -12358,6 +12357,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
return V;
if (N0.getOpcode() == ISD::SETCC) {
+ // Propagate fast-math-flags.
+ SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
+
// Only do this before legalize for now.
if (!LegalOperations && VT.isVector() &&
N0.getValueType().getVectorElementType() == MVT::i1) {
@@ -12549,6 +12551,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
}
if (N0.getOpcode() == ISD::SETCC) {
+ // Propagate fast-math-flags.
+ SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags());
+
// For vectors:
// aext(setcc) -> vsetcc
// aext(setcc) -> truncate(vsetcc)
@@ -13155,6 +13160,19 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
return N0.getOperand(0);
}
+ // Try to narrow a truncate-of-sext_in_reg to the destination type:
+ // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM
+ if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+ N0.hasOneUse()) {
+ SDValue X = N0.getOperand(0);
+ SDValue ExtVal = N0.getOperand(1);
+ EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT();
+ if (ExtVT.bitsLT(VT)) {
+ SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X);
+ return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal);
+ }
+ }
+
// If this is anyext(trunc), don't fold it, allow ourselves to be folded.
if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND))
return SDValue();
@@ -19478,7 +19496,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
return Shuf;
// Handle <1 x ???> vector insertion special cases.
- if (VT.getVectorNumElements() == 1) {
+ if (NumElts == 1) {
// insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y
if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
InVal.getOperand(0).getValueType() == VT &&
@@ -19506,80 +19524,77 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
}
}
- // Attempt to fold the insertion into a legal BUILD_VECTOR.
+ // Attempt to convert an insert_vector_elt chain into a legal build_vector.
if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) {
- auto UpdateBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
- assert(Ops.size() == NumElts && "Unexpected vector size");
-
- // Insert the element
- if (Elt < Ops.size()) {
- // All the operands of BUILD_VECTOR must have the same type;
- // we enforce that here.
- EVT OpVT = Ops[0].getValueType();
- Ops[Elt] =
- OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal;
+ // v1iX vector - we don't need to recurse.
+ if (NumElts == 1)
+ return DAG.getBuildVector(VT, DL, {InVal});
+
+ // If we haven't already collected the element, insert into the op list.
+ EVT MaxEltVT = InVal.getValueType();
+ auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt,
+ unsigned Idx) {
+ if (!Ops[Idx]) {
+ Ops[Idx] = Elt;
+ if (VT.isInteger()) {
+ EVT EltVT = Elt.getValueType();
+ MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT;
+ }
}
+ };
- // Return the new vector
+ // Ensure all the operands are the same value type, fill any missing
+ // operands with UNDEF and create the BUILD_VECTOR.
+ auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) {
+ assert(Ops.size() == NumElts && "Unexpected vector size");
+ for (SDValue &Op : Ops) {
+ if (Op)
+ Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op;
+ else
+ Op = DAG.getUNDEF(MaxEltVT);
+ }
return DAG.getBuildVector(VT, DL, Ops);
};
- // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially
- // be converted to a BUILD_VECTOR). Fill in the Ops vector with the
- // vector elements.
- SmallVector<SDValue, 8> Ops;
+ SmallVector<SDValue, 8> Ops(NumElts, SDValue());
+ Ops[Elt] = InVal;
- // Do not combine these two vectors if the output vector will not replace
- // the input vector.
- if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) {
- Ops.append(InVec->op_begin(), InVec->op_end());
- return UpdateBuildVector(Ops);
- }
+ // Recurse up an INSERT_VECTOR_ELT chain to build a BUILD_VECTOR.
+ for (SDValue CurVec = InVec; CurVec;) {
+ // UNDEF - build new BUILD_VECTOR from already inserted operands.
+ if (CurVec.isUndef())
+ return CanonicalizeBuildVector(Ops);
- if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR && InVec.hasOneUse()) {
- Ops.push_back(InVec.getOperand(0));
- Ops.append(NumElts - 1, DAG.getUNDEF(InVec.getOperand(0).getValueType()));
- return UpdateBuildVector(Ops);
- }
+ // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR.
+ if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) {
+ for (unsigned I = 0; I != NumElts; ++I)
+ AddBuildVectorOp(Ops, CurVec.getOperand(I), I);
+ return CanonicalizeBuildVector(Ops);
+ }
- if (InVec.isUndef()) {
- Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType()));
- return UpdateBuildVector(Ops);
- }
+ // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR.
+ if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) {
+ AddBuildVectorOp(Ops, CurVec.getOperand(0), 0);
+ return CanonicalizeBuildVector(Ops);
+ }
- // If we're inserting into the end of a vector as part of an sequence, see
- // if we can create a BUILD_VECTOR by following the sequence back up the
- // chain.
- if (Elt == (NumElts - 1)) {
- SmallVector<SDValue> ReverseInsertions;
- ReverseInsertions.push_back(InVal);
-
- EVT MaxEltVT = InVal.getValueType();
- SDValue CurVec = InVec;
- for (unsigned I = 1; I != NumElts; ++I) {
- if (CurVec.getOpcode() != ISD::INSERT_VECTOR_ELT || !CurVec.hasOneUse())
- break;
+ // INSERT_VECTOR_ELT - insert operand and continue up the chain.
+ if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse())
+ if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)))
+ if (CurIdx->getAPIntValue().ult(NumElts)) {
+ unsigned Idx = CurIdx->getZExtValue();
+ AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx);
- auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2));
- if (!CurIdx || CurIdx->getAPIntValue() != ((NumElts - 1) - I))
- break;
- SDValue CurVal = CurVec.getOperand(1);
- ReverseInsertions.push_back(CurVal);
- if (VT.isInteger()) {
- EVT CurValVT = CurVal.getValueType();
- MaxEltVT = MaxEltVT.bitsGE(CurValVT) ? MaxEltVT : CurValVT;
- }
- CurVec = CurVec.getOperand(0);
- }
+ // Found entire BUILD_VECTOR.
+ if (all_of(Ops, [](SDValue Op) { return !!Op; }))
+ return CanonicalizeBuildVector(Ops);
- if (ReverseInsertions.size() == NumElts) {
- for (unsigned I = 0; I != NumElts; ++I) {
- SDValue Val = ReverseInsertions[(NumElts - 1) - I];
- Val = VT.isInteger() ? DAG.getAnyExtOrTrunc(Val, DL, MaxEltVT) : Val;
- Ops.push_back(Val);
- }
- return DAG.getBuildVector(VT, DL, Ops);
- }
+ CurVec = CurVec->getOperand(0);
+ continue;
+ }
+
+ // Failed to find a match in the chain - bail.
+ break;
}
}
@@ -22643,6 +22658,56 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
}
}
+ // If we're not performing a select/blend shuffle, see if we can convert the
+ // shuffle into an AND node, if all the out-of-lane elements are known zero.
+ if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) {
+ bool IsInLaneMask = true;
+ ArrayRef<int> Mask = SVN->getMask();
+ SmallVector<int, 16> ClearMask(NumElts, -1);
+ APInt DemandedLHS = APInt::getNullValue(NumElts);
+ APInt DemandedRHS = APInt::getNullValue(NumElts);
+ for (int I = 0; I != (int)NumElts; ++I) {
+ int M = Mask[I];
+ if (M < 0)
+ continue;
+ ClearMask[I] = M == I ? I : (I + NumElts);
+ IsInLaneMask &= (M == I) || (M == (int)(I + NumElts));
+ if (M != I) {
+ APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS;
+ Demanded.setBit(M % NumElts);
+ }
+ }
+ // TODO: Should we try to mask with N1 as well?
+ if (!IsInLaneMask &&
+ (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) &&
+ (DemandedLHS.isNullValue() ||
+ DAG.MaskedVectorIsZero(N0, DemandedLHS)) &&
+ (DemandedRHS.isNullValue() ||
+ DAG.MaskedVectorIsZero(N1, DemandedRHS))) {
+ SDLoc DL(N);
+ EVT IntVT = VT.changeVectorElementTypeToInteger();
+ EVT IntSVT = VT.getVectorElementType().changeTypeToInteger();
+ SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT);
+ SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT);
+ SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT));
+ for (int I = 0; I != (int)NumElts; ++I)
+ if (0 <= Mask[I])
+ AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt;
+
+ // See if a clear mask is legal instead of going via
+ // XformToShuffleWithZero which loses UNDEF mask elements.
+ if (TLI.isVectorClearMaskLegal(ClearMask, IntVT))
+ return DAG.getBitcast(
+ VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0),
+ DAG.getConstant(0, DL, IntVT), ClearMask));
+
+ if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT))
+ return DAG.getBitcast(
+ VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0),
+ DAG.getBuildVector(IntVT, DL, AndMask)));
+ }
+ }
+
// Attempt to combine a shuffle of 2 inputs of 'scalar sources' -
// BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR.
if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT))
@@ -23385,10 +23450,14 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
int Index0, Index1;
SDValue Src0 = DAG.getSplatSourceVector(N0, Index0);
SDValue Src1 = DAG.getSplatSourceVector(N1, Index1);
+ // Extracting an element from a splat_vector should be free.
+ // TODO: use DAG.isSplatValue instead?
+ bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR &&
+ N1.getOpcode() == ISD::SPLAT_VECTOR;
if (!Src0 || !Src1 || Index0 != Index1 ||
Src0.getValueType().getVectorElementType() != EltVT ||
Src1.getValueType().getVectorElementType() != EltVT ||
- !TLI.isExtractVecEltCheap(VT, Index0) ||
+ !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) ||
!TLI.isOperationLegalOrCustom(Opcode, EltVT))
return SDValue();
@@ -23410,6 +23479,8 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
}
// bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
+ if (VT.isScalableVector())
+ return DAG.getSplatVector(VT, DL, ScalarBO);
SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO);
return DAG.getBuildVector(VT, DL, Ops);
}