diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 327 |
1 files changed, 199 insertions, 128 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 2654c00929d8..edb0756e8c3b 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -1868,8 +1868,7 @@ SDValue DAGCombiner::combine(SDNode *N) { // If N is a commutative binary node, try to eliminate it if the commuted // version is already present in the DAG. - if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode()) && - N->getNumValues() == 1) { + if (!RV.getNode() && TLI.isCommutativeBinOp(N->getOpcode())) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -4159,6 +4158,10 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { if (SDValue RMUL = reassociateOps(ISD::MUL, DL, N0, N1, N->getFlags())) return RMUL; + // Simplify the operands using demanded-bits information. + if (SimplifyDemandedBits(SDValue(N, 0))) + return SDValue(N, 0); + return SDValue(); } @@ -5978,44 +5981,64 @@ static SDValue combineShiftAnd1ToBitTest(SDNode *And, SelectionDAG &DAG) { if (!TLI.isTypeLegal(VT)) return SDValue(); - // Look through an optional extension and find a 'not'. - // TODO: Should we favor test+set even without the 'not' op? - SDValue Not = And->getOperand(0), And1 = And->getOperand(1); - if (Not.getOpcode() == ISD::ANY_EXTEND) - Not = Not.getOperand(0); - if (!isBitwiseNot(Not) || !Not.hasOneUse() || !isOneConstant(And1)) + // Look through an optional extension. + SDValue And0 = And->getOperand(0), And1 = And->getOperand(1); + if (And0.getOpcode() == ISD::ANY_EXTEND && And0.hasOneUse()) + And0 = And0.getOperand(0); + if (!isOneConstant(And1) || !And0.hasOneUse()) return SDValue(); - // Look though an optional truncation. The source operand may not be the same - // type as the original 'and', but that is ok because we are masking off - // everything but the low bit. - SDValue Srl = Not.getOperand(0); - if (Srl.getOpcode() == ISD::TRUNCATE) - Srl = Srl.getOperand(0); + SDValue Src = And0; + + // Attempt to find a 'not' op. + // TODO: Should we favor test+set even without the 'not' op? + bool FoundNot = false; + if (isBitwiseNot(Src)) { + FoundNot = true; + Src = Src.getOperand(0); + + // Look though an optional truncation. The source operand may not be the + // same type as the original 'and', but that is ok because we are masking + // off everything but the low bit. + if (Src.getOpcode() == ISD::TRUNCATE && Src.hasOneUse()) + Src = Src.getOperand(0); + } // Match a shift-right by constant. - if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse() || - !isa<ConstantSDNode>(Srl.getOperand(1))) + if (Src.getOpcode() != ISD::SRL || !Src.hasOneUse()) return SDValue(); // We might have looked through casts that make this transform invalid. // TODO: If the source type is wider than the result type, do the mask and // compare in the source type. - const APInt &ShiftAmt = Srl.getConstantOperandAPInt(1); - unsigned VTBitWidth = VT.getSizeInBits(); - if (ShiftAmt.uge(VTBitWidth)) + unsigned VTBitWidth = VT.getScalarSizeInBits(); + SDValue ShiftAmt = Src.getOperand(1); + auto *ShiftAmtC = dyn_cast<ConstantSDNode>(ShiftAmt); + if (!ShiftAmtC || !ShiftAmtC->getAPIntValue().ult(VTBitWidth)) return SDValue(); - if (!TLI.hasBitTest(Srl.getOperand(0), Srl.getOperand(1))) + // Set source to shift source. + Src = Src.getOperand(0); + + // Try again to find a 'not' op. + // TODO: Should we favor test+set even with two 'not' ops? + if (!FoundNot) { + if (!isBitwiseNot(Src)) + return SDValue(); + Src = Src.getOperand(0); + } + + if (!TLI.hasBitTest(Src, ShiftAmt)) return SDValue(); // Turn this into a bit-test pattern using mask op + setcc: // and (not (srl X, C)), 1 --> (and X, 1<<C) == 0 + // and (srl (not X), C)), 1 --> (and X, 1<<C) == 0 SDLoc DL(And); - SDValue X = DAG.getZExtOrTrunc(Srl.getOperand(0), DL, VT); + SDValue X = DAG.getZExtOrTrunc(Src, DL, VT); EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); SDValue Mask = DAG.getConstant( - APInt::getOneBitSet(VTBitWidth, ShiftAmt.getZExtValue()), DL, VT); + APInt::getOneBitSet(VTBitWidth, ShiftAmtC->getZExtValue()), DL, VT); SDValue NewAnd = DAG.getNode(ISD::AND, DL, VT, X, Mask); SDValue Zero = DAG.getConstant(0, DL, VT); SDValue Setcc = DAG.getSetCC(DL, CCVT, NewAnd, Zero, ISD::SETEQ); @@ -6229,7 +6252,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // This can be a pure constant or a vector splat, in which case we treat the // vector as a scalar and use the splat value. APInt Constant = APInt::getZero(1); - if (const ConstantSDNode *C = isConstOrConstSplat(N1)) { + if (const ConstantSDNode *C = isConstOrConstSplat( + N1, /*AllowUndef=*/false, /*AllowTruncation=*/true)) { Constant = C->getAPIntValue(); } else if (BuildVectorSDNode *Vector = dyn_cast<BuildVectorSDNode>(N1)) { APInt SplatValue, SplatUndef; @@ -6339,18 +6363,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // fold (and (load x), 255) -> (zextload x, i8) // fold (and (extload x, i16), 255) -> (zextload x, i8) - // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8) - if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD || - (N0.getOpcode() == ISD::ANY_EXTEND && - N0.getOperand(0).getOpcode() == ISD::LOAD))) { - if (SDValue Res = reduceLoadWidth(N)) { - LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND - ? cast<LoadSDNode>(N0.getOperand(0)) : cast<LoadSDNode>(N0); - AddToWorklist(N); - DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 0), Res); - return SDValue(N, 0); - } - } + if (N1C && N0.getOpcode() == ISD::LOAD && !VT.isVector()) + if (SDValue Res = reduceLoadWidth(N)) + return Res; if (LegalTypes) { // Attempt to propagate the AND back up to the leaves which, if they're @@ -6856,20 +6871,23 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *N) { } /// OR combines for which the commuted variant will be tried as well. -static SDValue visitORCommutative( - SelectionDAG &DAG, SDValue N0, SDValue N1, SDNode *N) { +static SDValue visitORCommutative(SelectionDAG &DAG, SDValue N0, SDValue N1, + SDNode *N) { EVT VT = N0.getValueType(); if (N0.getOpcode() == ISD::AND) { + SDValue N00 = N0.getOperand(0); + SDValue N01 = N0.getOperand(1); + // fold (or (and X, (xor Y, -1)), Y) -> (or X, Y) // TODO: Set AllowUndefs = true. - if (getBitwiseNotOperand(N0.getOperand(1), N0.getOperand(0), + if (getBitwiseNotOperand(N01, N00, /* AllowUndefs */ false) == N1) - return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(0), N1); + return DAG.getNode(ISD::OR, SDLoc(N), VT, N00, N1); // fold (or (and (xor Y, -1), X), Y) -> (or X, Y) - if (getBitwiseNotOperand(N0.getOperand(0), N0.getOperand(1), + if (getBitwiseNotOperand(N00, N01, /* AllowUndefs */ false) == N1) - return DAG.getNode(ISD::OR, SDLoc(N), VT, N0.getOperand(1), N1); + return DAG.getNode(ISD::OR, SDLoc(N), VT, N01, N1); } if (SDValue R = foldLogicOfShifts(N, N0, N1, DAG)) @@ -7915,7 +7933,7 @@ SDValue DAGCombiner::mergeTruncStores(StoreSDNode *N) { int64_t FirstOffset = INT64_MAX; StoreSDNode *FirstStore = nullptr; Optional<BaseIndexOffset> Base; - for (auto Store : Stores) { + for (auto *Store : Stores) { // All the stores store different parts of the CombinedValue. A truncate is // required to get the partial value. SDValue Trunc = Store->getValue(); @@ -8488,28 +8506,6 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { return DAG.getNode(ISD::AND, DL, VT, NotX, N1); } - if ((N0Opcode == ISD::SRL || N0Opcode == ISD::SHL) && N0.hasOneUse()) { - ConstantSDNode *XorC = isConstOrConstSplat(N1); - ConstantSDNode *ShiftC = isConstOrConstSplat(N0.getOperand(1)); - unsigned BitWidth = VT.getScalarSizeInBits(); - if (XorC && ShiftC) { - // Don't crash on an oversized shift. We can not guarantee that a bogus - // shift has been simplified to undef. - uint64_t ShiftAmt = ShiftC->getLimitedValue(); - if (ShiftAmt < BitWidth) { - APInt Ones = APInt::getAllOnes(BitWidth); - Ones = N0Opcode == ISD::SHL ? Ones.shl(ShiftAmt) : Ones.lshr(ShiftAmt); - if (XorC->getAPIntValue() == Ones) { - // If the xor constant is a shifted -1, do a 'not' before the shift: - // xor (X << ShiftC), XorC --> (not X) << ShiftC - // xor (X >> ShiftC), XorC --> (not X) >> ShiftC - SDValue Not = DAG.getNOT(DL, N0.getOperand(0), VT); - return DAG.getNode(N0Opcode, DL, VT, Not, N0.getOperand(1)); - } - } - } - } - // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X) if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) { SDValue A = N0Opcode == ISD::ADD ? N0 : N1; @@ -11817,6 +11813,9 @@ SDValue DAGCombiner::foldSextSetcc(SDNode *N) { EVT N00VT = N00.getValueType(); SDLoc DL(N); + // Propagate fast-math-flags. + SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags()); + // On some architectures (such as SSE/NEON/etc) the SETCC result type is // the same size as the compared operands. Try to optimize sext(setcc()) // if this is the case. @@ -12358,6 +12357,9 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { return V; if (N0.getOpcode() == ISD::SETCC) { + // Propagate fast-math-flags. + SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags()); + // Only do this before legalize for now. if (!LegalOperations && VT.isVector() && N0.getValueType().getVectorElementType() == MVT::i1) { @@ -12549,6 +12551,9 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { } if (N0.getOpcode() == ISD::SETCC) { + // Propagate fast-math-flags. + SelectionDAG::FlagInserter FlagsInserter(DAG, N0->getFlags()); + // For vectors: // aext(setcc) -> vsetcc // aext(setcc) -> truncate(vsetcc) @@ -13155,6 +13160,19 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { return N0.getOperand(0); } + // Try to narrow a truncate-of-sext_in_reg to the destination type: + // trunc (sign_ext_inreg X, iM) to iN --> sign_ext_inreg (trunc X to iN), iM + if (!LegalTypes && N0.getOpcode() == ISD::SIGN_EXTEND_INREG && + N0.hasOneUse()) { + SDValue X = N0.getOperand(0); + SDValue ExtVal = N0.getOperand(1); + EVT ExtVT = cast<VTSDNode>(ExtVal)->getVT(); + if (ExtVT.bitsLT(VT)) { + SDValue TrX = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, X); + return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, TrX, ExtVal); + } + } + // If this is anyext(trunc), don't fold it, allow ourselves to be folded. if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ANY_EXTEND)) return SDValue(); @@ -19478,7 +19496,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { return Shuf; // Handle <1 x ???> vector insertion special cases. - if (VT.getVectorNumElements() == 1) { + if (NumElts == 1) { // insert_vector_elt(x, extract_vector_elt(y, 0), 0) -> y if (InVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT && InVal.getOperand(0).getValueType() == VT && @@ -19506,80 +19524,77 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) { } } - // Attempt to fold the insertion into a legal BUILD_VECTOR. + // Attempt to convert an insert_vector_elt chain into a legal build_vector. if (!LegalOperations || TLI.isOperationLegal(ISD::BUILD_VECTOR, VT)) { - auto UpdateBuildVector = [&](SmallVectorImpl<SDValue> &Ops) { - assert(Ops.size() == NumElts && "Unexpected vector size"); - - // Insert the element - if (Elt < Ops.size()) { - // All the operands of BUILD_VECTOR must have the same type; - // we enforce that here. - EVT OpVT = Ops[0].getValueType(); - Ops[Elt] = - OpVT.isInteger() ? DAG.getAnyExtOrTrunc(InVal, DL, OpVT) : InVal; + // vXi1 vector - we don't need to recurse. + if (NumElts == 1) + return DAG.getBuildVector(VT, DL, {InVal}); + + // If we haven't already collected the element, insert into the op list. + EVT MaxEltVT = InVal.getValueType(); + auto AddBuildVectorOp = [&](SmallVectorImpl<SDValue> &Ops, SDValue Elt, + unsigned Idx) { + if (!Ops[Idx]) { + Ops[Idx] = Elt; + if (VT.isInteger()) { + EVT EltVT = Elt.getValueType(); + MaxEltVT = MaxEltVT.bitsGE(EltVT) ? MaxEltVT : EltVT; + } } + }; - // Return the new vector + // Ensure all the operands are the same value type, fill any missing + // operands with UNDEF and create the BUILD_VECTOR. + auto CanonicalizeBuildVector = [&](SmallVectorImpl<SDValue> &Ops) { + assert(Ops.size() == NumElts && "Unexpected vector size"); + for (SDValue &Op : Ops) { + if (Op) + Op = VT.isInteger() ? DAG.getAnyExtOrTrunc(Op, DL, MaxEltVT) : Op; + else + Op = DAG.getUNDEF(MaxEltVT); + } return DAG.getBuildVector(VT, DL, Ops); }; - // Check that the operand is a BUILD_VECTOR (or UNDEF, which can essentially - // be converted to a BUILD_VECTOR). Fill in the Ops vector with the - // vector elements. - SmallVector<SDValue, 8> Ops; + SmallVector<SDValue, 8> Ops(NumElts, SDValue()); + Ops[Elt] = InVal; - // Do not combine these two vectors if the output vector will not replace - // the input vector. - if (InVec.getOpcode() == ISD::BUILD_VECTOR && InVec.hasOneUse()) { - Ops.append(InVec->op_begin(), InVec->op_end()); - return UpdateBuildVector(Ops); - } + // Recurse up a INSERT_VECTOR_ELT chain to build a BUILD_VECTOR. + for (SDValue CurVec = InVec; CurVec;) { + // UNDEF - build new BUILD_VECTOR from already inserted operands. + if (CurVec.isUndef()) + return CanonicalizeBuildVector(Ops); - if (InVec.getOpcode() == ISD::SCALAR_TO_VECTOR && InVec.hasOneUse()) { - Ops.push_back(InVec.getOperand(0)); - Ops.append(NumElts - 1, DAG.getUNDEF(InVec.getOperand(0).getValueType())); - return UpdateBuildVector(Ops); - } + // BUILD_VECTOR - insert unused operands and build new BUILD_VECTOR. + if (CurVec.getOpcode() == ISD::BUILD_VECTOR && CurVec.hasOneUse()) { + for (unsigned I = 0; I != NumElts; ++I) + AddBuildVectorOp(Ops, CurVec.getOperand(I), I); + return CanonicalizeBuildVector(Ops); + } - if (InVec.isUndef()) { - Ops.append(NumElts, DAG.getUNDEF(InVal.getValueType())); - return UpdateBuildVector(Ops); - } + // SCALAR_TO_VECTOR - insert unused scalar and build new BUILD_VECTOR. + if (CurVec.getOpcode() == ISD::SCALAR_TO_VECTOR && CurVec.hasOneUse()) { + AddBuildVectorOp(Ops, CurVec.getOperand(0), 0); + return CanonicalizeBuildVector(Ops); + } - // If we're inserting into the end of a vector as part of an sequence, see - // if we can create a BUILD_VECTOR by following the sequence back up the - // chain. - if (Elt == (NumElts - 1)) { - SmallVector<SDValue> ReverseInsertions; - ReverseInsertions.push_back(InVal); - - EVT MaxEltVT = InVal.getValueType(); - SDValue CurVec = InVec; - for (unsigned I = 1; I != NumElts; ++I) { - if (CurVec.getOpcode() != ISD::INSERT_VECTOR_ELT || !CurVec.hasOneUse()) - break; + // INSERT_VECTOR_ELT - insert operand and continue up the chain. + if (CurVec.getOpcode() == ISD::INSERT_VECTOR_ELT && CurVec.hasOneUse()) + if (auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2))) + if (CurIdx->getAPIntValue().ult(NumElts)) { + unsigned Idx = CurIdx->getZExtValue(); + AddBuildVectorOp(Ops, CurVec.getOperand(1), Idx); - auto *CurIdx = dyn_cast<ConstantSDNode>(CurVec.getOperand(2)); - if (!CurIdx || CurIdx->getAPIntValue() != ((NumElts - 1) - I)) - break; - SDValue CurVal = CurVec.getOperand(1); - ReverseInsertions.push_back(CurVal); - if (VT.isInteger()) { - EVT CurValVT = CurVal.getValueType(); - MaxEltVT = MaxEltVT.bitsGE(CurValVT) ? MaxEltVT : CurValVT; - } - CurVec = CurVec.getOperand(0); - } + // Found entire BUILD_VECTOR. + if (all_of(Ops, [](SDValue Op) { return !!Op; })) + return CanonicalizeBuildVector(Ops); - if (ReverseInsertions.size() == NumElts) { - for (unsigned I = 0; I != NumElts; ++I) { - SDValue Val = ReverseInsertions[(NumElts - 1) - I]; - Val = VT.isInteger() ? DAG.getAnyExtOrTrunc(Val, DL, MaxEltVT) : Val; - Ops.push_back(Val); - } - return DAG.getBuildVector(VT, DL, Ops); - } + CurVec = CurVec->getOperand(0); + continue; + } + + // Failed to find a match in the chain - bail. + break; } } @@ -22643,6 +22658,56 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { } } + // If we're not performing a select/blend shuffle, see if we can convert the + // shuffle into a AND node, with all the out-of-lane elements are known zero. + if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) { + bool IsInLaneMask = true; + ArrayRef<int> Mask = SVN->getMask(); + SmallVector<int, 16> ClearMask(NumElts, -1); + APInt DemandedLHS = APInt::getNullValue(NumElts); + APInt DemandedRHS = APInt::getNullValue(NumElts); + for (int I = 0; I != (int)NumElts; ++I) { + int M = Mask[I]; + if (M < 0) + continue; + ClearMask[I] = M == I ? I : (I + NumElts); + IsInLaneMask &= (M == I) || (M == (int)(I + NumElts)); + if (M != I) { + APInt &Demanded = M < (int)NumElts ? DemandedLHS : DemandedRHS; + Demanded.setBit(M % NumElts); + } + } + // TODO: Should we try to mask with N1 as well? + if (!IsInLaneMask && + (!DemandedLHS.isNullValue() || !DemandedRHS.isNullValue()) && + (DemandedLHS.isNullValue() || + DAG.MaskedVectorIsZero(N0, DemandedLHS)) && + (DemandedRHS.isNullValue() || + DAG.MaskedVectorIsZero(N1, DemandedRHS))) { + SDLoc DL(N); + EVT IntVT = VT.changeVectorElementTypeToInteger(); + EVT IntSVT = VT.getVectorElementType().changeTypeToInteger(); + SDValue ZeroElt = DAG.getConstant(0, DL, IntSVT); + SDValue AllOnesElt = DAG.getAllOnesConstant(DL, IntSVT); + SmallVector<SDValue, 16> AndMask(NumElts, DAG.getUNDEF(IntSVT)); + for (int I = 0; I != (int)NumElts; ++I) + if (0 <= Mask[I]) + AndMask[I] = Mask[I] == I ? AllOnesElt : ZeroElt; + + // See if a clear mask is legal instead of going via + // XformToShuffleWithZero which loses UNDEF mask elements. + if (TLI.isVectorClearMaskLegal(ClearMask, IntVT)) + return DAG.getBitcast( + VT, DAG.getVectorShuffle(IntVT, DL, DAG.getBitcast(IntVT, N0), + DAG.getConstant(0, DL, IntVT), ClearMask)); + + if (TLI.isOperationLegalOrCustom(ISD::AND, IntVT)) + return DAG.getBitcast( + VT, DAG.getNode(ISD::AND, DL, IntVT, DAG.getBitcast(IntVT, N0), + DAG.getBuildVector(IntVT, DL, AndMask))); + } + } + // Attempt to combine a shuffle of 2 inputs of 'scalar sources' - // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR. if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT)) @@ -23385,10 +23450,14 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG, int Index0, Index1; SDValue Src0 = DAG.getSplatSourceVector(N0, Index0); SDValue Src1 = DAG.getSplatSourceVector(N1, Index1); + // Extract element from splat_vector should be free. + // TODO: use DAG.isSplatValue instead? + bool IsBothSplatVector = N0.getOpcode() == ISD::SPLAT_VECTOR && + N1.getOpcode() == ISD::SPLAT_VECTOR; if (!Src0 || !Src1 || Index0 != Index1 || Src0.getValueType().getVectorElementType() != EltVT || Src1.getValueType().getVectorElementType() != EltVT || - !TLI.isExtractVecEltCheap(VT, Index0) || + !(IsBothSplatVector || TLI.isExtractVecEltCheap(VT, Index0)) || !TLI.isOperationLegalOrCustom(Opcode, EltVT)) return SDValue(); @@ -23410,6 +23479,8 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG, } // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index + if (VT.isScalableVector()) + return DAG.getSplatVector(VT, DL, ScalarBO); SmallVector<SDValue, 8> Ops(VT.getVectorNumElements(), ScalarBO); return DAG.getBuildVector(VT, DL, Ops); } |