diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 1486 |
1 files changed, 855 insertions, 631 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 24ab65171a17..96df20039b15 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -83,7 +83,7 @@ bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI, const CCValAssign &ArgLoc = ArgLocs[I]; if (!ArgLoc.isRegLoc()) continue; - Register Reg = ArgLoc.getLocReg(); + MCRegister Reg = ArgLoc.getLocReg(); // Only look at callee saved registers. if (MachineOperand::clobbersPhysReg(CallerPreservedMask, Reg)) continue; @@ -93,7 +93,7 @@ bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI, SDValue Value = OutVals[I]; if (Value->getOpcode() != ISD::CopyFromReg) return false; - unsigned ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg(); + MCRegister ArgReg = cast<RegisterSDNode>(Value->getOperand(1))->getReg(); if (MRI.getLiveInPhysReg(ArgReg) != Reg) return false; } @@ -110,14 +110,18 @@ void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call, IsSRet = Call->paramHasAttr(ArgIdx, Attribute::StructRet); IsNest = Call->paramHasAttr(ArgIdx, Attribute::Nest); IsByVal = Call->paramHasAttr(ArgIdx, Attribute::ByVal); + IsPreallocated = Call->paramHasAttr(ArgIdx, Attribute::Preallocated); IsInAlloca = Call->paramHasAttr(ArgIdx, Attribute::InAlloca); IsReturned = Call->paramHasAttr(ArgIdx, Attribute::Returned); IsSwiftSelf = Call->paramHasAttr(ArgIdx, Attribute::SwiftSelf); IsSwiftError = Call->paramHasAttr(ArgIdx, Attribute::SwiftError); - Alignment = Call->getParamAlignment(ArgIdx); + Alignment = Call->getParamAlign(ArgIdx); ByValType = nullptr; - if (Call->paramHasAttr(ArgIdx, Attribute::ByVal)) + if (IsByVal) ByValType = Call->getParamByValType(ArgIdx); + PreallocatedType = nullptr; + if (IsPreallocated) + PreallocatedType = Call->getParamPreallocatedType(ArgIdx); } /// Generate a libcall taking the given operands as arguments and returning a @@ -176,38 +180,24 @@ TargetLowering::makeLibCall(SelectionDAG &DAG, RTLIB::Libcall LC, EVT RetVT, return LowerCallTo(CLI); } -bool -TargetLowering::findOptimalMemOpLowering(std::vector<EVT> &MemOps, - unsigned Limit, uint64_t Size, - unsigned DstAlign, unsigned SrcAlign, - bool IsMemset, - bool ZeroMemset, - bool MemcpyStrSrc, - bool AllowOverlap, - unsigned DstAS, unsigned SrcAS, - const AttributeList &FuncAttributes) const { - // If 'SrcAlign' is zero, that means the memory operation does not need to - // load the value, i.e. memset or memcpy from constant string. Otherwise, - // it's the inferred alignment of the source. 'DstAlign', on the other hand, - // is the specified alignment of the memory operation. If it is zero, that - // means it's possible to change the alignment of the destination. - // 'MemcpyStrSrc' indicates whether the memcpy source is constant so it does - // not need to be loaded. - if (!(SrcAlign == 0 || SrcAlign >= DstAlign)) +bool TargetLowering::findOptimalMemOpLowering( + std::vector<EVT> &MemOps, unsigned Limit, const MemOp &Op, unsigned DstAS, + unsigned SrcAS, const AttributeList &FuncAttributes) const { + if (Op.isMemcpyWithFixedDstAlign() && Op.getSrcAlign() < Op.getDstAlign()) return false; - EVT VT = getOptimalMemOpType(Size, DstAlign, SrcAlign, - IsMemset, ZeroMemset, MemcpyStrSrc, - FuncAttributes); + EVT VT = getOptimalMemOpType(Op, FuncAttributes); if (VT == MVT::Other) { // Use the largest integer type whose alignment constraints are satisfied. // We only need to check DstAlign here as SrcAlign is always greater or // equal to DstAlign (or zero). VT = MVT::i64; - while (DstAlign && DstAlign < VT.getSizeInBits() / 8 && - !allowsMisalignedMemoryAccesses(VT, DstAS, DstAlign)) - VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1); + if (Op.isFixedDstAlign()) + while ( + Op.getDstAlign() < (VT.getSizeInBits() / 8) && + !allowsMisalignedMemoryAccesses(VT, DstAS, Op.getDstAlign().value())) + VT = (MVT::SimpleValueType)(VT.getSimpleVT().SimpleTy - 1); assert(VT.isInteger()); // Find the largest legal integer type. @@ -223,7 +213,8 @@ TargetLowering::findOptimalMemOpLowering(std::vector<EVT> &MemOps, } unsigned NumMemOps = 0; - while (Size != 0) { + uint64_t Size = Op.size(); + while (Size) { unsigned VTSize = VT.getSizeInBits() / 8; while (VTSize > Size) { // For now, only use non-vector load / store's for the left-over pieces. @@ -257,9 +248,10 @@ TargetLowering::findOptimalMemOpLowering(std::vector<EVT> &MemOps, // If the new VT cannot cover all of the remaining bits, then consider // issuing a (or a pair of) unaligned and overlapping load / store. bool Fast; - if (NumMemOps && AllowOverlap && NewVTSize < Size && - allowsMisalignedMemoryAccesses(VT, DstAS, DstAlign, - MachineMemOperand::MONone, &Fast) && + if (NumMemOps && Op.allowOverlap() && NewVTSize < Size && + allowsMisalignedMemoryAccesses( + VT, DstAS, Op.isFixedDstAlign() ? Op.getDstAlign().value() : 0, + MachineMemOperand::MONone, &Fast) && Fast) VTSize = Size; else { @@ -491,13 +483,15 @@ TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const { /// If the specified instruction has a constant integer operand and there are /// bits set in that constant that are not demanded, then clear those bits and /// return true. -bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, +bool TargetLowering::ShrinkDemandedConstant(SDValue Op, + const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const { SDLoc DL(Op); unsigned Opcode = Op.getOpcode(); // Do target-specific constant optimization. - if (targetShrinkDemandedConstant(Op, Demanded, TLO)) + if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return TLO.New.getNode(); // FIXME: ISD::SELECT, ISD::SELECT_CC @@ -513,12 +507,12 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, // If this is a 'not' op, don't touch it because that's a canonical form. const APInt &C = Op1C->getAPIntValue(); - if (Opcode == ISD::XOR && Demanded.isSubsetOf(C)) + if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C)) return false; - if (!C.isSubsetOf(Demanded)) { + if (!C.isSubsetOf(DemandedBits)) { EVT VT = Op.getValueType(); - SDValue NewC = TLO.DAG.getConstant(Demanded & C, DL, VT); + SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT); SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC); return TLO.CombineTo(Op, NewOp); } @@ -530,6 +524,16 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, return false; } +bool TargetLowering::ShrinkDemandedConstant(SDValue Op, + const APInt &DemandedBits, + TargetLoweringOpt &TLO) const { + EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO); +} + /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free. /// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be /// generalized for targets with other types of implicit widening casts. @@ -598,6 +602,16 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits, unsigned Depth, bool AssumeSingleUse) const { EVT VT = Op.getValueType(); + + // TODO: We can probably do more work on calculating the known bits and + // simplifying the operations for scalable vectors, but for now we just + // bail out. + if (VT.isScalableVector()) { + // Pretend we don't know anything for now. + Known = KnownBits(DemandedBits.getBitWidth()); + return false; + } + APInt DemandedElts = VT.isVector() ? APInt::getAllOnesValue(VT.getVectorNumElements()) : APInt(1, 1); @@ -623,15 +637,18 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( return DAG.getUNDEF(Op.getValueType()); unsigned NumElts = DemandedElts.getBitWidth(); + unsigned BitWidth = DemandedBits.getBitWidth(); KnownBits LHSKnown, RHSKnown; switch (Op.getOpcode()) { case ISD::BITCAST: { SDValue Src = peekThroughBitcasts(Op.getOperand(0)); EVT SrcVT = Src.getValueType(); EVT DstVT = Op.getValueType(); + if (SrcVT == DstVT) + return Src; + unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits(); unsigned NumDstEltBits = DstVT.getScalarSizeInBits(); - if (NumSrcEltBits == NumDstEltBits) if (SDValue V = SimplifyMultipleUseDemandedBits( Src, DemandedBits, DemandedElts, DAG, Depth + 1)) @@ -719,6 +736,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( return Op.getOperand(1); break; } + case ISD::SHL: { + // If we are only demanding sign bits then we can use the shift source + // directly. + if (const APInt *MaxSA = + DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) { + SDValue Op0 = Op.getOperand(0); + unsigned ShAmt = MaxSA->getZExtValue(); + unsigned NumSignBits = + DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1); + unsigned UpperDemandedBits = BitWidth - DemandedBits.countTrailingZeros(); + if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits)) + return Op0; + } + break; + } case ISD::SETCC: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); @@ -727,7 +759,7 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( // width as the setcc result, and (3) the result of a setcc conforms to 0 or // -1, we may be able to bypass the setcc. if (DemandedBits.isSignMask() && - Op0.getScalarValueSizeInBits() == DemandedBits.getBitWidth() && + Op0.getScalarValueSizeInBits() == BitWidth && getBooleanContents(Op0.getValueType()) == BooleanContent::ZeroOrNegativeOneBooleanContent) { // If we're testing X < 0, then this compare isn't needed - just use X! @@ -742,9 +774,30 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( } case ISD::SIGN_EXTEND_INREG: { // If none of the extended bits are demanded, eliminate the sextinreg. + SDValue Op0 = Op.getOperand(0); EVT ExVT = cast<VTSDNode>(Op.getOperand(1))->getVT(); - if (DemandedBits.getActiveBits() <= ExVT.getScalarSizeInBits()) - return Op.getOperand(0); + unsigned ExBits = ExVT.getScalarSizeInBits(); + if (DemandedBits.getActiveBits() <= ExBits) + return Op0; + // If the input is already sign extended, just drop the extension. + unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1); + if (NumSignBits >= (BitWidth - ExBits + 1)) + return Op0; + break; + } + case ISD::ANY_EXTEND_VECTOR_INREG: + case ISD::SIGN_EXTEND_VECTOR_INREG: + case ISD::ZERO_EXTEND_VECTOR_INREG: { + // If we only want the lowest element and none of extended bits, then we can + // return the bitcasted source vector. + SDValue Src = Op.getOperand(0); + EVT SrcVT = Src.getValueType(); + EVT DstVT = Op.getValueType(); + if (DemandedElts == 1 && DstVT.getSizeInBits() == SrcVT.getSizeInBits() && + DAG.getDataLayout().isLittleEndian() && + DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) { + return DAG.getBitcast(DstVT, Src); + } break; } case ISD::INSERT_VECTOR_ELT: { @@ -757,6 +810,16 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( return Vec; break; } + case ISD::INSERT_SUBVECTOR: { + // If we don't demand the inserted subvector, return the base vector. + SDValue Vec = Op.getOperand(0); + SDValue Sub = Op.getOperand(1); + uint64_t Idx = Op.getConstantOperandVal(2); + unsigned NumSubElts = Sub.getValueType().getVectorNumElements(); + if (DemandedElts.extractBits(NumSubElts, Idx) == 0) + return Vec; + break; + } case ISD::VECTOR_SHUFFLE: { ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask(); @@ -790,6 +853,25 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits( return SDValue(); } +SDValue TargetLowering::SimplifyMultipleUseDemandedBits( + SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG, + unsigned Depth) const { + EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG, + Depth); +} + +SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts( + SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG, + unsigned Depth) const { + APInt DemandedBits = APInt::getAllOnesValue(Op.getScalarValueSizeInBits()); + return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG, + Depth); +} + /// Look at Op. At this point, we know that only the OriginalDemandedBits of the /// result of Op are ever used downstream. If we can use this information to /// simplify Op, create a new simplified DAG node and return true, returning the @@ -805,6 +887,15 @@ bool TargetLowering::SimplifyDemandedBits( assert(Op.getScalarValueSizeInBits() == BitWidth && "Mask size mismatches value type size!"); + // Don't know anything. + Known = KnownBits(BitWidth); + + // TODO: We can probably do more work on calculating the known bits and + // simplifying the operations for scalable vectors, but for now we just + // bail out. + if (Op.getValueType().isScalableVector()) + return false; + unsigned NumElts = OriginalDemandedElts.getBitWidth(); assert((!Op.getValueType().isVector() || NumElts == Op.getValueType().getVectorNumElements()) && @@ -815,9 +906,6 @@ bool TargetLowering::SimplifyDemandedBits( SDLoc dl(Op); auto &DL = TLO.DAG.getDataLayout(); - // Don't know anything. - Known = KnownBits(BitWidth); - // Undef operand. if (Op.isUndef()) return false; @@ -850,7 +938,7 @@ bool TargetLowering::SimplifyDemandedBits( return false; } - KnownBits Known2, KnownOut; + KnownBits Known2; switch (Op.getOpcode()) { case ISD::TargetConstant: llvm_unreachable("Can't simplify this node"); @@ -864,7 +952,11 @@ bool TargetLowering::SimplifyDemandedBits( APInt SrcDemandedBits = DemandedBits.zextOrSelf(SrcBitWidth); if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcKnown, TLO, Depth + 1)) return true; - Known = SrcKnown.zextOrTrunc(BitWidth, false); + + // Upper elements are undef, so only get the knownbits if we just demand + // the bottom element. + if (DemandedElts == 1) + Known = SrcKnown.anyextOrTrunc(BitWidth); break; } case ISD::BUILD_VECTOR: @@ -877,6 +969,12 @@ bool TargetLowering::SimplifyDemandedBits( if (getTargetConstantFromLoad(LD)) { Known = TLO.DAG.computeKnownBits(Op, DemandedElts, Depth); return false; // Don't fall through, will infinitely loop. + } else if (ISD::isZEXTLoad(Op.getNode()) && Op.getResNo() == 0) { + // If this is a ZEXTLoad and we are looking at the loaded value. + EVT MemVT = LD->getMemoryVT(); + unsigned MemBits = MemVT.getScalarSizeInBits(); + Known.Zero.setBitsFrom(MemBits); + return false; // Don't fall through, will infinitely loop. } break; } @@ -904,7 +1002,7 @@ bool TargetLowering::SimplifyDemandedBits( if (SimplifyDemandedBits(Scl, DemandedSclBits, KnownScl, TLO, Depth + 1)) return true; - Known = KnownScl.zextOrTrunc(BitWidth, false); + Known = KnownScl.anyextOrTrunc(BitWidth); KnownBits KnownVec; if (SimplifyDemandedBits(Vec, DemandedBits, DemandedVecElts, KnownVec, TLO, @@ -919,57 +1017,75 @@ bool TargetLowering::SimplifyDemandedBits( return false; } case ISD::INSERT_SUBVECTOR: { - SDValue Base = Op.getOperand(0); + // Demand any elements from the subvector and the remainder from the src its + // inserted into. + SDValue Src = Op.getOperand(0); SDValue Sub = Op.getOperand(1); - EVT SubVT = Sub.getValueType(); - unsigned NumSubElts = SubVT.getVectorNumElements(); - - // If index isn't constant, assume we need the original demanded base - // elements and ALL the inserted subvector elements. - APInt BaseElts = DemandedElts; - APInt SubElts = APInt::getAllOnesValue(NumSubElts); - if (isa<ConstantSDNode>(Op.getOperand(2))) { - const APInt &Idx = Op.getConstantOperandAPInt(2); - if (Idx.ule(NumElts - NumSubElts)) { - unsigned SubIdx = Idx.getZExtValue(); - SubElts = DemandedElts.extractBits(NumSubElts, SubIdx); - BaseElts.insertBits(APInt::getNullValue(NumSubElts), SubIdx); - } - } - - KnownBits KnownSub, KnownBase; - if (SimplifyDemandedBits(Sub, DemandedBits, SubElts, KnownSub, TLO, + uint64_t Idx = Op.getConstantOperandVal(2); + unsigned NumSubElts = Sub.getValueType().getVectorNumElements(); + APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx); + APInt DemandedSrcElts = DemandedElts; + DemandedSrcElts.insertBits(APInt::getNullValue(NumSubElts), Idx); + + KnownBits KnownSub, KnownSrc; + if (SimplifyDemandedBits(Sub, DemandedBits, DemandedSubElts, KnownSub, TLO, Depth + 1)) return true; - if (SimplifyDemandedBits(Base, DemandedBits, BaseElts, KnownBase, TLO, + if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, KnownSrc, TLO, Depth + 1)) return true; Known.Zero.setAllBits(); Known.One.setAllBits(); - if (!!SubElts) { - Known.One &= KnownSub.One; - Known.Zero &= KnownSub.Zero; + if (!!DemandedSubElts) { + Known.One &= KnownSub.One; + Known.Zero &= KnownSub.Zero; } - if (!!BaseElts) { - Known.One &= KnownBase.One; - Known.Zero &= KnownBase.Zero; + if (!!DemandedSrcElts) { + Known.One &= KnownSrc.One; + Known.Zero &= KnownSrc.Zero; + } + + // Attempt to avoid multi-use src if we don't need anything from it. + if (!DemandedBits.isAllOnesValue() || !DemandedSubElts.isAllOnesValue() || + !DemandedSrcElts.isAllOnesValue()) { + SDValue NewSub = SimplifyMultipleUseDemandedBits( + Sub, DemandedBits, DemandedSubElts, TLO.DAG, Depth + 1); + SDValue NewSrc = SimplifyMultipleUseDemandedBits( + Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1); + if (NewSub || NewSrc) { + NewSub = NewSub ? NewSub : Sub; + NewSrc = NewSrc ? NewSrc : Src; + SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc, NewSub, + Op.getOperand(2)); + return TLO.CombineTo(Op, NewOp); + } } break; } case ISD::EXTRACT_SUBVECTOR: { - // If index isn't constant, assume we need all the source vector elements. + // Offset the demanded elts by the subvector index. SDValue Src = Op.getOperand(0); - ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1)); + if (Src.getValueType().isScalableVector()) + break; + uint64_t Idx = Op.getConstantOperandVal(1); unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); - APInt SrcElts = APInt::getAllOnesValue(NumSrcElts); - if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) { - // Offset the demanded elts by the subvector index. - uint64_t Idx = SubIdx->getZExtValue(); - SrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); - } - if (SimplifyDemandedBits(Src, DemandedBits, SrcElts, Known, TLO, Depth + 1)) + APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); + + if (SimplifyDemandedBits(Src, DemandedBits, DemandedSrcElts, Known, TLO, + Depth + 1)) return true; + + // Attempt to avoid multi-use src if we don't need anything from it. + if (!DemandedBits.isAllOnesValue() || !DemandedSrcElts.isAllOnesValue()) { + SDValue DemandedSrc = SimplifyMultipleUseDemandedBits( + Src, DemandedBits, DemandedSrcElts, TLO.DAG, Depth + 1); + if (DemandedSrc) { + SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), dl, VT, DemandedSrc, + Op.getOperand(1)); + return TLO.CombineTo(Op, NewOp); + } + } break; } case ISD::CONCAT_VECTORS: { @@ -1069,7 +1185,8 @@ bool TargetLowering::SimplifyDemandedBits( // If any of the set bits in the RHS are known zero on the LHS, shrink // the constant. - if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, + DemandedElts, TLO)) return true; // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its @@ -1117,16 +1234,14 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero)) return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT)); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts, + TLO)) return true; // If the operation can be done in a smaller type, do so. if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; - // Output known-1 bits are only known if set in both the LHS & RHS. - Known.One &= Known2.One; - // Output known-0 are known to be clear if zero in either the LHS | RHS. - Known.Zero |= Known2.Zero; + Known &= Known2; break; } case ISD::OR: { @@ -1163,16 +1278,13 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isSubsetOf(Known.One | Known2.Zero)) return TLO.CombineTo(Op, Op1); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // If the operation can be done in a smaller type, do so. if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) return true; - // Output known-0 bits are only known if clear in both the LHS & RHS. - Known.Zero &= Known2.Zero; - // Output known-1 are known to be set if set in either the LHS | RHS. - Known.One |= Known2.One; + Known |= Known2; break; } case ISD::XOR: { @@ -1218,12 +1330,8 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero)) return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::OR, dl, VT, Op0, Op1)); - // Output known-0 bits are known if clear or set in both the LHS & RHS. - KnownOut.Zero = (Known.Zero & Known2.Zero) | (Known.One & Known2.One); - // Output known-1 are known to be set if set in only one of the LHS, RHS. - KnownOut.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero); - - if (ConstantSDNode *C = isConstOrConstSplat(Op1)) { + ConstantSDNode* C = isConstOrConstSplat(Op1, DemandedElts); + if (C) { // If one side is a constant, and all of the known set bits on the other // side are also set in the constant, turn this into an AND, as we know // the bits will be cleared. @@ -1238,19 +1346,20 @@ bool TargetLowering::SimplifyDemandedBits( // If the RHS is a constant, see if we can change it. Don't alter a -1 // constant because that's a 'not' op, and that is better for combining // and codegen. - if (!C->isAllOnesValue()) { - if (DemandedBits.isSubsetOf(C->getAPIntValue())) { - // We're flipping all demanded bits. Flip the undemanded bits too. - SDValue New = TLO.DAG.getNOT(dl, Op0, VT); - return TLO.CombineTo(Op, New); - } - // If we can't turn this into a 'not', try to shrink the constant. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) - return true; + if (!C->isAllOnesValue() && + DemandedBits.isSubsetOf(C->getAPIntValue())) { + // We're flipping all demanded bits. Flip the undemanded bits too. + SDValue New = TLO.DAG.getNOT(dl, Op0, VT); + return TLO.CombineTo(Op, New); } } - Known = std::move(KnownOut); + // If we can't turn this into a 'not', try to shrink the constant. + if (!C || !C->isAllOnesValue()) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) + return true; + + Known ^= Known2; break; } case ISD::SELECT: @@ -1264,7 +1373,7 @@ bool TargetLowering::SimplifyDemandedBits( assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // Only known if known in both the LHS and RHS. @@ -1282,7 +1391,7 @@ bool TargetLowering::SimplifyDemandedBits( assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // Only known if known in both the LHS and RHS. @@ -1320,12 +1429,10 @@ bool TargetLowering::SimplifyDemandedBits( case ISD::SHL: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { - // If the shift count is an invalid immediate, don't do anything. - if (SA->getAPIntValue().uge(BitWidth)) - break; - + if (const APInt *SA = + TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) { unsigned ShAmt = SA->getZExtValue(); if (ShAmt == 0) return TLO.CombineTo(Op, Op0); @@ -1336,37 +1443,25 @@ bool TargetLowering::SimplifyDemandedBits( // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SRL) { if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) { - if (ConstantSDNode *SA2 = - isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) { - if (SA2->getAPIntValue().ult(BitWidth)) { - unsigned C1 = SA2->getZExtValue(); - unsigned Opc = ISD::SHL; - int Diff = ShAmt - C1; - if (Diff < 0) { - Diff = -Diff; - Opc = ISD::SRL; - } - - SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType()); - return TLO.CombineTo( - Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); + if (const APInt *SA2 = + TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) { + unsigned C1 = SA2->getZExtValue(); + unsigned Opc = ISD::SHL; + int Diff = ShAmt - C1; + if (Diff < 0) { + Diff = -Diff; + Opc = ISD::SRL; } + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT); + return TLO.CombineTo( + Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } } } - if (SimplifyDemandedBits(Op0, DemandedBits.lshr(ShAmt), DemandedElts, - Known, TLO, Depth + 1)) - return true; - - // Try shrinking the operation as long as the shift amount will still be - // in range. - if ((ShAmt < DemandedBits.getActiveBits()) && - ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) - return true; - // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits // are not demanded. This will likely allow the anyext to be folded away. + // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::ANY_EXTEND) { SDValue InnerOp = Op0.getOperand(0); EVT InnerVT = InnerOp.getValueType(); @@ -1382,22 +1477,24 @@ bool TargetLowering::SimplifyDemandedBits( return TLO.CombineTo( Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl)); } + // Repeat the SHL optimization above in cases where an extension // intervenes: (shl (anyext (shr x, c1)), c2) to // (shl (anyext x), c2-c1). This requires that the bottom c1 bits // aren't demanded (as above) and that the shifted upper c1 bits of // x aren't demanded. + // TODO - support non-uniform vector amounts. if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL && InnerOp.hasOneUse()) { - if (ConstantSDNode *SA2 = - isConstOrConstSplat(InnerOp.getOperand(1))) { - unsigned InnerShAmt = SA2->getLimitedValue(InnerBits); + if (const APInt *SA2 = + TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) { + unsigned InnerShAmt = SA2->getZExtValue(); if (InnerShAmt < ShAmt && InnerShAmt < InnerBits && DemandedBits.getActiveBits() <= (InnerBits - InnerShAmt + ShAmt) && DemandedBits.countTrailingZeros() >= ShAmt) { - SDValue NewSA = TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, - Op1.getValueType()); + SDValue NewSA = + TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT); SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, InnerOp.getOperand(0)); return TLO.CombineTo( @@ -1407,60 +1504,76 @@ bool TargetLowering::SimplifyDemandedBits( } } + APInt InDemandedMask = DemandedBits.lshr(ShAmt); + if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, + Depth + 1)) + return true; + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero <<= ShAmt; Known.One <<= ShAmt; // low bits known zero. Known.Zero.setLowBits(ShAmt); + + // Try shrinking the operation as long as the shift amount will still be + // in range. + if ((ShAmt < DemandedBits.getActiveBits()) && + ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) + return true; + } + + // If we are only demanding sign bits then we can use the shift source + // directly. + if (const APInt *MaxSA = + TLO.DAG.getValidMaximumShiftAmountConstant(Op, DemandedElts)) { + unsigned ShAmt = MaxSA->getZExtValue(); + unsigned NumSignBits = + TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1); + unsigned UpperDemandedBits = BitWidth - DemandedBits.countTrailingZeros(); + if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits)) + return TLO.CombineTo(Op, Op0); } break; } case ISD::SRL: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { - // If the shift count is an invalid immediate, don't do anything. - if (SA->getAPIntValue().uge(BitWidth)) - break; - + if (const APInt *SA = + TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) { unsigned ShAmt = SA->getZExtValue(); if (ShAmt == 0) return TLO.CombineTo(Op, Op0); - EVT ShiftVT = Op1.getValueType(); - APInt InDemandedMask = (DemandedBits << ShAmt); - - // If the shift is exact, then it does demand the low bits (and knows that - // they are zero). - if (Op->getFlags().hasExact()) - InDemandedMask.setLowBits(ShAmt); - // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a // single shift. We can do this if the top bits (which are shifted out) // are never demanded. // TODO - support non-uniform vector amounts. if (Op0.getOpcode() == ISD::SHL) { - if (ConstantSDNode *SA2 = - isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) { - if (!DemandedBits.intersects( - APInt::getHighBitsSet(BitWidth, ShAmt))) { - if (SA2->getAPIntValue().ult(BitWidth)) { - unsigned C1 = SA2->getZExtValue(); - unsigned Opc = ISD::SRL; - int Diff = ShAmt - C1; - if (Diff < 0) { - Diff = -Diff; - Opc = ISD::SHL; - } - - SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT); - return TLO.CombineTo( - Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); + if (!DemandedBits.intersects(APInt::getHighBitsSet(BitWidth, ShAmt))) { + if (const APInt *SA2 = + TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) { + unsigned C1 = SA2->getZExtValue(); + unsigned Opc = ISD::SRL; + int Diff = ShAmt - C1; + if (Diff < 0) { + Diff = -Diff; + Opc = ISD::SHL; } + SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT); + return TLO.CombineTo( + Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA)); } } } + APInt InDemandedMask = (DemandedBits << ShAmt); + + // If the shift is exact, then it does demand the low bits (and knows that + // they are zero). + if (Op->getFlags().hasExact()) + InDemandedMask.setLowBits(ShAmt); + // Compute the new bits that are at the top now. if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO, Depth + 1)) @@ -1468,14 +1581,22 @@ bool TargetLowering::SimplifyDemandedBits( assert(!Known.hasConflict() && "Bits known to be one AND zero?"); Known.Zero.lshrInPlace(ShAmt); Known.One.lshrInPlace(ShAmt); - - Known.Zero.setHighBits(ShAmt); // High bits known zero. + // High bits known zero. + Known.Zero.setHighBits(ShAmt); } break; } case ISD::SRA: { SDValue Op0 = Op.getOperand(0); SDValue Op1 = Op.getOperand(1); + EVT ShiftVT = Op1.getValueType(); + + // If we only want bits that already match the signbit then we don't need + // to shift. + unsigned NumHiDemandedBits = BitWidth - DemandedBits.countTrailingZeros(); + if (TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1) >= + NumHiDemandedBits) + return TLO.CombineTo(Op, Op0); // If this is an arithmetic shift right and only the low-bit is set, we can // always convert this into a logical shr, even if the shift amount is @@ -1484,11 +1605,8 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isOneValue()) return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1)); - if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) { - // If the shift count is an invalid immediate, don't do anything. - if (SA->getAPIntValue().uge(BitWidth)) - break; - + if (const APInt *SA = + TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) { unsigned ShAmt = SA->getZExtValue(); if (ShAmt == 0) return TLO.CombineTo(Op, Op0); @@ -1525,14 +1643,23 @@ bool TargetLowering::SimplifyDemandedBits( int Log2 = DemandedBits.exactLogBase2(); if (Log2 >= 0) { // The bit must come from the sign. - SDValue NewSA = - TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, Op1.getValueType()); + SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT); return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA)); } if (Known.One[BitWidth - ShAmt - 1]) // New bits are known one. Known.One.setHighBits(ShAmt); + + // Attempt to avoid multi-use ops if we don't need anything from them. + if (!InDemandedMask.isAllOnesValue() || !DemandedElts.isAllOnesValue()) { + SDValue DemandedOp0 = SimplifyMultipleUseDemandedBits( + Op0, InDemandedMask, DemandedElts, TLO.DAG, Depth + 1); + if (DemandedOp0) { + SDValue NewOp = TLO.DAG.getNode(ISD::SRA, dl, VT, DemandedOp0, Op1); + return TLO.CombineTo(Op, NewOp); + } + } } break; } @@ -1573,6 +1700,32 @@ bool TargetLowering::SimplifyDemandedBits( Known.One |= Known2.One; Known.Zero |= Known2.Zero; } + + // For pow-2 bitwidths we only demand the bottom modulo amt bits. + if (isPowerOf2_32(BitWidth)) { + APInt DemandedAmtBits(Op2.getScalarValueSizeInBits(), BitWidth - 1); + if (SimplifyDemandedBits(Op2, DemandedAmtBits, DemandedElts, + Known2, TLO, Depth + 1)) + return true; + } + break; + } + case ISD::ROTL: + case ISD::ROTR: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + + // If we're rotating an 0/-1 value, then it stays an 0/-1 value. + if (BitWidth == TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1)) + return TLO.CombineTo(Op, Op0); + + // For pow-2 bitwidths we only demand the bottom modulo amt bits. + if (isPowerOf2_32(BitWidth)) { + APInt DemandedAmtBits(Op1.getScalarValueSizeInBits(), BitWidth - 1); + if (SimplifyDemandedBits(Op1, DemandedAmtBits, DemandedElts, Known2, TLO, + Depth + 1)) + return true; + } break; } case ISD::BITREVERSE: { @@ -1602,7 +1755,8 @@ bool TargetLowering::SimplifyDemandedBits( // If we only care about the highest bit, don't bother shifting right. if (DemandedBits.isSignMask()) { - unsigned NumSignBits = TLO.DAG.ComputeNumSignBits(Op0); + unsigned NumSignBits = + TLO.DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1); bool AlreadySignExtended = NumSignBits >= BitWidth - ExVTBits + 1; // However if the input is already sign extended we expect the sign // extension to be dropped altogether later and do not simplify. @@ -1639,8 +1793,7 @@ bool TargetLowering::SimplifyDemandedBits( // If the input sign bit is known zero, convert this into a zero extension. if (Known.Zero[ExVTBits - 1]) - return TLO.CombineTo( - Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT.getScalarType())); + return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT)); APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits); if (Known.One[ExVTBits - 1]) { // Input sign bit known set @@ -1704,7 +1857,7 @@ bool TargetLowering::SimplifyDemandedBits( return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); assert(Known.getBitWidth() == InBits && "Src width has changed?"); - Known = Known.zext(BitWidth, true /* ExtendedBitsAreKnownZero */); + Known = Known.zext(BitWidth); break; } case ISD::SIGN_EXTEND: @@ -1777,7 +1930,12 @@ bool TargetLowering::SimplifyDemandedBits( return true; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); assert(Known.getBitWidth() == InBits && "Src width has changed?"); - Known = Known.zext(BitWidth, false /* => any extend */); + Known = Known.anyext(BitWidth); + + // Attempt to avoid multi-use ops if we don't need anything from them. + if (SDValue NewSrc = SimplifyMultipleUseDemandedBits( + Src, InDemandedBits, InDemandedElts, TLO.DAG, Depth + 1)) + return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc)); break; } case ISD::TRUNCATE: { @@ -1886,7 +2044,7 @@ bool TargetLowering::SimplifyDemandedBits( Known = Known2; if (BitWidth > EltBitWidth) - Known = Known.zext(BitWidth, false /* => any extend */); + Known = Known.anyext(BitWidth); break; } case ISD::BITCAST: { @@ -2151,14 +2309,20 @@ bool TargetLowering::SimplifyDemandedVectorElts( APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const { EVT VT = Op.getValueType(); + unsigned Opcode = Op.getOpcode(); APInt DemandedElts = OriginalDemandedElts; unsigned NumElts = DemandedElts.getBitWidth(); assert(VT.isVector() && "Expected vector op"); - assert(VT.getVectorNumElements() == NumElts && - "Mask size mismatches value type element count!"); KnownUndef = KnownZero = APInt::getNullValue(NumElts); + // TODO: For now we assume we know nothing about scalable vectors. + if (VT.isScalableVector()) + return false; + + assert(VT.getVectorNumElements() == NumElts && + "Mask size mismatches value type element count!"); + // Undef operand. if (Op.isUndef()) { KnownUndef.setAllBits(); @@ -2182,7 +2346,22 @@ bool TargetLowering::SimplifyDemandedVectorElts( SDLoc DL(Op); unsigned EltSizeInBits = VT.getScalarSizeInBits(); - switch (Op.getOpcode()) { + // Helper for demanding the specified elements and all the bits of both binary + // operands. + auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) { + SDValue NewOp0 = SimplifyMultipleUseDemandedVectorElts(Op0, DemandedElts, + TLO.DAG, Depth + 1); + SDValue NewOp1 = SimplifyMultipleUseDemandedVectorElts(Op1, DemandedElts, + TLO.DAG, Depth + 1); + if (NewOp0 || NewOp1) { + SDValue NewOp = TLO.DAG.getNode( + Opcode, SDLoc(Op), VT, NewOp0 ? NewOp0 : Op0, NewOp1 ? NewOp1 : Op1); + return TLO.CombineTo(Op, NewOp); + } + return false; + }; + + switch (Opcode) { case ISD::SCALAR_TO_VECTOR: { if (!DemandedElts[0]) { KnownUndef.setAllBits(); @@ -2234,7 +2413,8 @@ bool TargetLowering::SimplifyDemandedVectorElts( } KnownBits Known; - if (SimplifyDemandedBits(Src, SrcDemandedBits, Known, TLO, Depth + 1)) + if (SimplifyDemandedBits(Src, SrcDemandedBits, SrcDemandedElts, Known, + TLO, Depth + 1)) return true; } @@ -2323,53 +2503,75 @@ bool TargetLowering::SimplifyDemandedVectorElts( break; } case ISD::INSERT_SUBVECTOR: { - if (!isa<ConstantSDNode>(Op.getOperand(2))) - break; - SDValue Base = Op.getOperand(0); + // Demand any elements from the subvector and the remainder from the src its + // inserted into. + SDValue Src = Op.getOperand(0); SDValue Sub = Op.getOperand(1); - EVT SubVT = Sub.getValueType(); - unsigned NumSubElts = SubVT.getVectorNumElements(); - const APInt &Idx = Op.getConstantOperandAPInt(2); - if (Idx.ugt(NumElts - NumSubElts)) - break; - unsigned SubIdx = Idx.getZExtValue(); - APInt SubElts = DemandedElts.extractBits(NumSubElts, SubIdx); + uint64_t Idx = Op.getConstantOperandVal(2); + unsigned NumSubElts = Sub.getValueType().getVectorNumElements(); + APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx); + APInt DemandedSrcElts = DemandedElts; + DemandedSrcElts.insertBits(APInt::getNullValue(NumSubElts), Idx); + APInt SubUndef, SubZero; - if (SimplifyDemandedVectorElts(Sub, SubElts, SubUndef, SubZero, TLO, + if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO, Depth + 1)) return true; - APInt BaseElts = DemandedElts; - BaseElts.insertBits(APInt::getNullValue(NumSubElts), SubIdx); - - // If none of the base operand elements are demanded, replace it with undef. - if (!BaseElts && !Base.isUndef()) - return TLO.CombineTo(Op, - TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, - TLO.DAG.getUNDEF(VT), - Op.getOperand(1), - Op.getOperand(2))); - - if (SimplifyDemandedVectorElts(Base, BaseElts, KnownUndef, KnownZero, TLO, - Depth + 1)) + + // If none of the src operand elements are demanded, replace it with undef. + if (!DemandedSrcElts && !Src.isUndef()) + return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, + TLO.DAG.getUNDEF(VT), Sub, + Op.getOperand(2))); + + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownUndef, KnownZero, + TLO, Depth + 1)) return true; - KnownUndef.insertBits(SubUndef, SubIdx); - KnownZero.insertBits(SubZero, SubIdx); + KnownUndef.insertBits(SubUndef, Idx); + KnownZero.insertBits(SubZero, Idx); + + // Attempt to avoid multi-use ops if we don't need anything from them. + if (!DemandedSrcElts.isAllOnesValue() || + !DemandedSubElts.isAllOnesValue()) { + SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts( + Src, DemandedSrcElts, TLO.DAG, Depth + 1); + SDValue NewSub = SimplifyMultipleUseDemandedVectorElts( + Sub, DemandedSubElts, TLO.DAG, Depth + 1); + if (NewSrc || NewSub) { + NewSrc = NewSrc ? NewSrc : Src; + NewSub = NewSub ? NewSub : Sub; + SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc, + NewSub, Op.getOperand(2)); + return TLO.CombineTo(Op, NewOp); + } + } break; } case ISD::EXTRACT_SUBVECTOR: { + // Offset the demanded elts by the subvector index. SDValue Src = Op.getOperand(0); - ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1)); + if (Src.getValueType().isScalableVector()) + break; + uint64_t Idx = Op.getConstantOperandVal(1); unsigned NumSrcElts = Src.getValueType().getVectorNumElements(); - if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) { - // Offset the demanded elts by the subvector index. - uint64_t Idx = SubIdx->getZExtValue(); - APInt SrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); - APInt SrcUndef, SrcZero; - if (SimplifyDemandedVectorElts(Src, SrcElts, SrcUndef, SrcZero, TLO, - Depth + 1)) - return true; - KnownUndef = SrcUndef.extractBits(NumElts, Idx); - KnownZero = SrcZero.extractBits(NumElts, Idx); + APInt DemandedSrcElts = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx); + + APInt SrcUndef, SrcZero; + if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, SrcUndef, SrcZero, TLO, + Depth + 1)) + return true; + KnownUndef = SrcUndef.extractBits(NumElts, Idx); + KnownZero = SrcZero.extractBits(NumElts, Idx); + + // Attempt to avoid multi-use ops if we don't need anything from them. + if (!DemandedElts.isAllOnesValue()) { + SDValue NewSrc = SimplifyMultipleUseDemandedVectorElts( + Src, DemandedSrcElts, TLO.DAG, Depth + 1); + if (NewSrc) { + SDValue NewOp = TLO.DAG.getNode(Op.getOpcode(), SDLoc(Op), VT, NewSrc, + Op.getOperand(1)); + return TLO.CombineTo(Op, NewOp); + } } break; } @@ -2538,7 +2740,7 @@ bool TargetLowering::SimplifyDemandedVectorElts( break; } - // TODO: There are more binop opcodes that could be handled here - MUL, MIN, + // TODO: There are more binop opcodes that could be handled here - MIN, // MAX, saturated math, etc. case ISD::OR: case ISD::XOR: @@ -2549,17 +2751,26 @@ bool TargetLowering::SimplifyDemandedVectorElts( case ISD::FMUL: case ISD::FDIV: case ISD::FREM: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + APInt UndefRHS, ZeroRHS; - if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, UndefRHS, - ZeroRHS, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO, + Depth + 1)) return true; APInt UndefLHS, ZeroLHS; - if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, UndefLHS, - ZeroLHS, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO, + Depth + 1)) return true; KnownZero = ZeroLHS & ZeroRHS; KnownUndef = getKnownUndefForVectorBinop(Op, TLO.DAG, UndefLHS, UndefRHS); + + // Attempt to avoid multi-use ops if we don't need anything from them. + // TODO - use KnownUndef to relax the demandedelts? + if (!DemandedElts.isAllOnesValue()) + if (SimplifyDemandedVectorEltsBinOp(Op0, Op1)) + return true; break; } case ISD::SHL: @@ -2567,27 +2778,39 @@ bool TargetLowering::SimplifyDemandedVectorElts( case ISD::SRA: case ISD::ROTL: case ISD::ROTR: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + APInt UndefRHS, ZeroRHS; - if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, UndefRHS, - ZeroRHS, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op1, DemandedElts, UndefRHS, ZeroRHS, TLO, + Depth + 1)) return true; APInt UndefLHS, ZeroLHS; - if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, UndefLHS, - ZeroLHS, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op0, DemandedElts, UndefLHS, ZeroLHS, TLO, + Depth + 1)) return true; KnownZero = ZeroLHS; KnownUndef = UndefLHS & UndefRHS; // TODO: use getKnownUndefForVectorBinop? + + // Attempt to avoid multi-use ops if we don't need anything from them. + // TODO - use KnownUndef to relax the demandedelts? + if (!DemandedElts.isAllOnesValue()) + if (SimplifyDemandedVectorEltsBinOp(Op0, Op1)) + return true; break; } case ISD::MUL: case ISD::AND: { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + APInt SrcUndef, SrcZero; - if (SimplifyDemandedVectorElts(Op.getOperand(1), DemandedElts, SrcUndef, - SrcZero, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op1, DemandedElts, SrcUndef, SrcZero, TLO, + Depth + 1)) return true; - if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, KnownUndef, - KnownZero, TLO, Depth + 1)) + if (SimplifyDemandedVectorElts(Op0, DemandedElts, KnownUndef, KnownZero, + TLO, Depth + 1)) return true; // If either side has a zero element, then the result element is zero, even @@ -2597,6 +2820,12 @@ bool TargetLowering::SimplifyDemandedVectorElts( KnownZero |= SrcZero; KnownUndef &= SrcUndef; KnownUndef &= ~KnownZero; + + // Attempt to avoid multi-use ops if we don't need anything from them. + // TODO - use KnownUndef to relax the demandedelts? + if (!DemandedElts.isAllOnesValue()) + if (SimplifyDemandedVectorEltsBinOp(Op0, Op1)) + return true; break; } case ISD::TRUNCATE: @@ -2661,17 +2890,16 @@ void TargetLowering::computeKnownBitsForTargetInstr( Known.resetAll(); } -void TargetLowering::computeKnownBitsForFrameIndex(const SDValue Op, - KnownBits &Known, - const APInt &DemandedElts, - const SelectionDAG &DAG, - unsigned Depth) const { - assert(isa<FrameIndexSDNode>(Op) && "expected FrameIndex"); +void TargetLowering::computeKnownBitsForFrameIndex( + const int FrameIdx, KnownBits &Known, const MachineFunction &MF) const { + // The low bits are known zero if the pointer is aligned. + Known.Zero.setLowBits(Log2(MF.getFrameInfo().getObjectAlign(FrameIdx))); +} - if (unsigned Align = DAG.InferPtrAlignment(Op)) { - // The low bits are known zero if the pointer is aligned. - Known.Zero.setLowBits(Log2_32(Align)); - } +Align TargetLowering::computeKnownAlignForTargetInstr( + GISelKnownBits &Analysis, Register R, const MachineRegisterInfo &MRI, + unsigned Depth) const { + return Align(1); } /// This method can be implemented by targets that want to expose additional @@ -2689,6 +2917,12 @@ unsigned TargetLowering::ComputeNumSignBitsForTargetNode(SDValue Op, return 1; } +unsigned TargetLowering::computeNumSignBitsForTargetInstr( + GISelKnownBits &Analysis, Register R, const APInt &DemandedElts, + const MachineRegisterInfo &MRI, unsigned Depth) const { + return 1; +} + bool TargetLowering::SimplifyDemandedVectorEltsForTargetNode( SDValue Op, const APInt &DemandedElts, APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth) const { @@ -3788,33 +4022,18 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1, // the comparison operands is infinity or negative infinity, convert the // condition to a less-awkward <= or >=. if (CFP->getValueAPF().isInfinity()) { - if (CFP->getValueAPF().isNegative()) { - if (Cond == ISD::SETOEQ && - isCondCodeLegal(ISD::SETOLE, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETOLE); - if (Cond == ISD::SETUEQ && - isCondCodeLegal(ISD::SETOLE, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETULE); - if (Cond == ISD::SETUNE && - isCondCodeLegal(ISD::SETUGT, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETUGT); - if (Cond == ISD::SETONE && - isCondCodeLegal(ISD::SETUGT, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETOGT); - } else { - if (Cond == ISD::SETOEQ && - isCondCodeLegal(ISD::SETOGE, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETOGE); - if (Cond == ISD::SETUEQ && - isCondCodeLegal(ISD::SETOGE, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETUGE); - if (Cond == ISD::SETUNE && - isCondCodeLegal(ISD::SETULT, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETULT); - if (Cond == ISD::SETONE && - isCondCodeLegal(ISD::SETULT, N0.getSimpleValueType())) - return DAG.getSetCC(dl, VT, N0, N1, ISD::SETOLT); + bool IsNegInf = CFP->getValueAPF().isNegative(); + ISD::CondCode NewCond = ISD::SETCC_INVALID; + switch (Cond) { + case ISD::SETOEQ: NewCond = IsNegInf ? ISD::SETOLE : ISD::SETOGE; break; + case ISD::SETUEQ: NewCond = IsNegInf ? ISD::SETULE : ISD::SETUGE; break; + case ISD::SETUNE: NewCond = IsNegInf ? ISD::SETUGT : ISD::SETULT; break; + case ISD::SETONE: NewCond = IsNegInf ? ISD::SETOGT : ISD::SETOLT; break; + default: break; } + if (NewCond != ISD::SETCC_INVALID && + isCondCodeLegal(NewCond, N0.getSimpleValueType())) + return DAG.getSetCC(dl, VT, N0, N1, NewCond); } } } @@ -4245,10 +4464,10 @@ unsigned TargetLowering::AsmOperandInfo::getMatchedOperand() const { TargetLowering::AsmOperandInfoVector TargetLowering::ParseConstraints(const DataLayout &DL, const TargetRegisterInfo *TRI, - ImmutableCallSite CS) const { + const CallBase &Call) const { /// Information about all of the constraints. AsmOperandInfoVector ConstraintOperands; - const InlineAsm *IA = cast<InlineAsm>(CS.getCalledValue()); + const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand()); unsigned maCount = 0; // Largest number of multiple alternative constraints. // Do a prepass over the constraints, canonicalizing them, and building up the @@ -4271,25 +4490,24 @@ TargetLowering::ParseConstraints(const DataLayout &DL, case InlineAsm::isOutput: // Indirect outputs just consume an argument. if (OpInfo.isIndirect) { - OpInfo.CallOperandVal = const_cast<Value *>(CS.getArgument(ArgNo++)); + OpInfo.CallOperandVal = Call.getArgOperand(ArgNo++); break; } // The return value of the call is this value. As such, there is no // corresponding argument. - assert(!CS.getType()->isVoidTy() && - "Bad inline asm!"); - if (StructType *STy = dyn_cast<StructType>(CS.getType())) { + assert(!Call.getType()->isVoidTy() && "Bad inline asm!"); + if (StructType *STy = dyn_cast<StructType>(Call.getType())) { OpInfo.ConstraintVT = getSimpleValueType(DL, STy->getElementType(ResNo)); } else { assert(ResNo == 0 && "Asm only has one result!"); - OpInfo.ConstraintVT = getSimpleValueType(DL, CS.getType()); + OpInfo.ConstraintVT = getSimpleValueType(DL, Call.getType()); } ++ResNo; break; case InlineAsm::isInput: - OpInfo.CallOperandVal = const_cast<Value *>(CS.getArgument(ArgNo++)); + OpInfo.CallOperandVal = Call.getArgOperand(ArgNo++); break; case InlineAsm::isClobber: // Nothing to do. @@ -5479,251 +5697,221 @@ verifyReturnAddressArgumentIsConstant(SDValue Op, SelectionDAG &DAG) const { return false; } -char TargetLowering::isNegatibleForFree(SDValue Op, SelectionDAG &DAG, - bool LegalOperations, bool ForCodeSize, - unsigned Depth) const { +SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, + bool LegalOps, bool OptForSize, + NegatibleCost &Cost, + unsigned Depth) const { // fneg is removable even if it has multiple uses. - if (Op.getOpcode() == ISD::FNEG) - return 2; + if (Op.getOpcode() == ISD::FNEG) { + Cost = NegatibleCost::Cheaper; + return Op.getOperand(0); + } - // Don't allow anything with multiple uses unless we know it is free. - EVT VT = Op.getValueType(); + // Don't recurse exponentially. + if (Depth > SelectionDAG::MaxRecursionDepth) + return SDValue(); + + // Pre-increment recursion depth for use in recursive calls. + ++Depth; const SDNodeFlags Flags = Op->getFlags(); const TargetOptions &Options = DAG.getTarget().Options; - if (!Op.hasOneUse() && !(Op.getOpcode() == ISD::FP_EXTEND && - isFPExtFree(VT, Op.getOperand(0).getValueType()))) - return 0; + EVT VT = Op.getValueType(); + unsigned Opcode = Op.getOpcode(); - // Don't recurse exponentially. - if (Depth > SelectionDAG::MaxRecursionDepth) - return 0; + // Don't allow anything with multiple uses unless we know it is free. + if (!Op.hasOneUse() && Opcode != ISD::ConstantFP) { + bool IsFreeExtend = Opcode == ISD::FP_EXTEND && + isFPExtFree(VT, Op.getOperand(0).getValueType()); + if (!IsFreeExtend) + return SDValue(); + } - switch (Op.getOpcode()) { - case ISD::ConstantFP: { - if (!LegalOperations) - return 1; + SDLoc DL(Op); + switch (Opcode) { + case ISD::ConstantFP: { // Don't invert constant FP values after legalization unless the target says // the negated constant is legal. - return isOperationLegal(ISD::ConstantFP, VT) || - isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT, - ForCodeSize); + bool IsOpLegal = + isOperationLegal(ISD::ConstantFP, VT) || + isFPImmLegal(neg(cast<ConstantFPSDNode>(Op)->getValueAPF()), VT, + OptForSize); + + if (LegalOps && !IsOpLegal) + break; + + APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF(); + V.changeSign(); + SDValue CFP = DAG.getConstantFP(V, DL, VT); + + // If we already have the use of the negated floating constant, it is free + // to negate it even it has multiple uses. + if (!Op.hasOneUse() && CFP.use_empty()) + break; + Cost = NegatibleCost::Neutral; + return CFP; } case ISD::BUILD_VECTOR: { // Only permit BUILD_VECTOR of constants. if (llvm::any_of(Op->op_values(), [&](SDValue N) { return !N.isUndef() && !isa<ConstantFPSDNode>(N); })) - return 0; - if (!LegalOperations) - return 1; - if (isOperationLegal(ISD::ConstantFP, VT) && - isOperationLegal(ISD::BUILD_VECTOR, VT)) - return 1; - return llvm::all_of(Op->op_values(), [&](SDValue N) { - return N.isUndef() || - isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT, - ForCodeSize); - }); + break; + + bool IsOpLegal = + (isOperationLegal(ISD::ConstantFP, VT) && + isOperationLegal(ISD::BUILD_VECTOR, VT)) || + llvm::all_of(Op->op_values(), [&](SDValue N) { + return N.isUndef() || + isFPImmLegal(neg(cast<ConstantFPSDNode>(N)->getValueAPF()), VT, + OptForSize); + }); + + if (LegalOps && !IsOpLegal) + break; + + SmallVector<SDValue, 4> Ops; + for (SDValue C : Op->op_values()) { + if (C.isUndef()) { + Ops.push_back(C); + continue; + } + APFloat V = cast<ConstantFPSDNode>(C)->getValueAPF(); + V.changeSign(); + Ops.push_back(DAG.getConstantFP(V, DL, C.getValueType())); + } + Cost = NegatibleCost::Neutral; + return DAG.getBuildVector(VT, DL, Ops); } - case ISD::FADD: + case ISD::FADD: { if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros()) - return 0; + break; // After operation legalization, it might not be legal to create new FSUBs. - if (LegalOperations && !isOperationLegalOrCustom(ISD::FSUB, VT)) - return 0; + if (LegalOps && !isOperationLegalOrCustom(ISD::FSUB, VT)) + break; + SDValue X = Op.getOperand(0), Y = Op.getOperand(1); - // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) - if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1)) - return V; - // fold (fneg (fadd A, B)) -> (fsub (fneg B), A) - return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); - case ISD::FSUB: + // fold (fneg (fadd X, Y)) -> (fsub (fneg X), Y) + NegatibleCost CostX = NegatibleCost::Expensive; + SDValue NegX = + getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth); + // fold (fneg (fadd X, Y)) -> (fsub (fneg Y), X) + NegatibleCost CostY = NegatibleCost::Expensive; + SDValue NegY = + getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth); + + // Negate the X if its cost is less or equal than Y. + if (NegX && (CostX <= CostY)) { + Cost = CostX; + return DAG.getNode(ISD::FSUB, DL, VT, NegX, Y, Flags); + } + + // Negate the Y if it is not expensive. + if (NegY) { + Cost = CostY; + return DAG.getNode(ISD::FSUB, DL, VT, NegY, X, Flags); + } + break; + } + case ISD::FSUB: { // We can't turn -(A-B) into B-A when we honor signed zeros. if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros()) - return 0; + break; - // fold (fneg (fsub A, B)) -> (fsub B, A) - return 1; + SDValue X = Op.getOperand(0), Y = Op.getOperand(1); + // fold (fneg (fsub 0, Y)) -> Y + if (ConstantFPSDNode *C = isConstOrConstSplatFP(X, /*AllowUndefs*/ true)) + if (C->isZero()) { + Cost = NegatibleCost::Cheaper; + return Y; + } + // fold (fneg (fsub X, Y)) -> (fsub Y, X) + Cost = NegatibleCost::Neutral; + return DAG.getNode(ISD::FSUB, DL, VT, Y, X, Flags); + } case ISD::FMUL: - case ISD::FDIV: - // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) or (fmul X, (fneg Y)) - if (char V = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1)) - return V; + case ISD::FDIV: { + SDValue X = Op.getOperand(0), Y = Op.getOperand(1); + + // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) + NegatibleCost CostX = NegatibleCost::Expensive; + SDValue NegX = + getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth); + // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y)) + NegatibleCost CostY = NegatibleCost::Expensive; + SDValue NegY = + getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth); + + // Negate the X if its cost is less or equal than Y. + if (NegX && (CostX <= CostY)) { + Cost = CostX; + return DAG.getNode(Opcode, DL, VT, NegX, Y, Flags); + } // Ignore X * 2.0 because that is expected to be canonicalized to X + X. if (auto *C = isConstOrConstSplatFP(Op.getOperand(1))) if (C->isExactlyValue(2.0) && Op.getOpcode() == ISD::FMUL) - return 0; - - return isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); + break; + // Negate the Y if it is not expensive. + if (NegY) { + Cost = CostY; + return DAG.getNode(Opcode, DL, VT, X, NegY, Flags); + } + break; + } case ISD::FMA: case ISD::FMAD: { if (!Options.NoSignedZerosFPMath && !Flags.hasNoSignedZeros()) - return 0; + break; + + SDValue X = Op.getOperand(0), Y = Op.getOperand(1), Z = Op.getOperand(2); + NegatibleCost CostZ = NegatibleCost::Expensive; + SDValue NegZ = + getNegatedExpression(Z, DAG, LegalOps, OptForSize, CostZ, Depth); + // Give up if fail to negate the Z. + if (!NegZ) + break; // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z)) + NegatibleCost CostX = NegatibleCost::Expensive; + SDValue NegX = + getNegatedExpression(X, DAG, LegalOps, OptForSize, CostX, Depth); // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z)) - char V2 = isNegatibleForFree(Op.getOperand(2), DAG, LegalOperations, - ForCodeSize, Depth + 1); - if (!V2) - return 0; - - // One of Op0/Op1 must be cheaply negatible, then select the cheapest. - char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); - char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); - char V01 = std::max(V0, V1); - return V01 ? std::max(V01, V2) : 0; - } - - case ISD::FP_EXTEND: - case ISD::FP_ROUND: - case ISD::FSIN: - return isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); - } - - return 0; -} - -SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, - bool LegalOperations, - bool ForCodeSize, - unsigned Depth) const { - // fneg is removable even if it has multiple uses. - if (Op.getOpcode() == ISD::FNEG) - return Op.getOperand(0); + NegatibleCost CostY = NegatibleCost::Expensive; + SDValue NegY = + getNegatedExpression(Y, DAG, LegalOps, OptForSize, CostY, Depth); - assert(Depth <= SelectionDAG::MaxRecursionDepth && - "getNegatedExpression doesn't match isNegatibleForFree"); - const SDNodeFlags Flags = Op->getFlags(); - - switch (Op.getOpcode()) { - case ISD::ConstantFP: { - APFloat V = cast<ConstantFPSDNode>(Op)->getValueAPF(); - V.changeSign(); - return DAG.getConstantFP(V, SDLoc(Op), Op.getValueType()); - } - case ISD::BUILD_VECTOR: { - SmallVector<SDValue, 4> Ops; - for (SDValue C : Op->op_values()) { - if (C.isUndef()) { - Ops.push_back(C); - continue; - } - APFloat V = cast<ConstantFPSDNode>(C)->getValueAPF(); - V.changeSign(); - Ops.push_back(DAG.getConstantFP(V, SDLoc(Op), C.getValueType())); + // Negate the X if its cost is less or equal than Y. + if (NegX && (CostX <= CostY)) { + Cost = std::min(CostX, CostZ); + return DAG.getNode(Opcode, DL, VT, NegX, Y, NegZ, Flags); } - return DAG.getBuildVector(Op.getValueType(), SDLoc(Op), Ops); - } - case ISD::FADD: - assert((DAG.getTarget().Options.NoSignedZerosFPMath || - Flags.hasNoSignedZeros()) && - "Expected NSZ fp-flag"); - - // fold (fneg (fadd A, B)) -> (fsub (fneg A), B) - if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, - Depth + 1)) - return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), - getNegatedExpression(Op.getOperand(0), DAG, - LegalOperations, ForCodeSize, - Depth + 1), - Op.getOperand(1), Flags); - // fold (fneg (fadd A, B)) -> (fsub (fneg B), A) - return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), - getNegatedExpression(Op.getOperand(1), DAG, - LegalOperations, ForCodeSize, - Depth + 1), - Op.getOperand(0), Flags); - case ISD::FSUB: - // fold (fneg (fsub 0, B)) -> B - if (ConstantFPSDNode *N0CFP = - isConstOrConstSplatFP(Op.getOperand(0), /*AllowUndefs*/ true)) - if (N0CFP->isZero()) - return Op.getOperand(1); - - // fold (fneg (fsub A, B)) -> (fsub B, A) - return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), - Op.getOperand(1), Op.getOperand(0), Flags); - - case ISD::FMUL: - case ISD::FDIV: - // fold (fneg (fmul X, Y)) -> (fmul (fneg X), Y) - if (isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, ForCodeSize, - Depth + 1)) - return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), - getNegatedExpression(Op.getOperand(0), DAG, - LegalOperations, ForCodeSize, - Depth + 1), - Op.getOperand(1), Flags); - - // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y)) - return DAG.getNode( - Op.getOpcode(), SDLoc(Op), Op.getValueType(), Op.getOperand(0), - getNegatedExpression(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1), - Flags); - case ISD::FMA: - case ISD::FMAD: { - assert((DAG.getTarget().Options.NoSignedZerosFPMath || - Flags.hasNoSignedZeros()) && - "Expected NSZ fp-flag"); - - SDValue Neg2 = getNegatedExpression(Op.getOperand(2), DAG, LegalOperations, - ForCodeSize, Depth + 1); - - char V0 = isNegatibleForFree(Op.getOperand(0), DAG, LegalOperations, - ForCodeSize, Depth + 1); - char V1 = isNegatibleForFree(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); - // TODO: This is a hack. It is possible that costs have changed between now - // and the initial calls to isNegatibleForFree(). That is because we - // are rewriting the expression, and that may change the number of - // uses (and therefore the cost) of values. If the negation costs are - // equal, only negate this value if it is a constant. Otherwise, try - // operand 1. A better fix would eliminate uses as a cost factor or - // track the change in uses as we rewrite the expression. - if (V0 > V1 || (V0 == V1 && isa<ConstantFPSDNode>(Op.getOperand(0)))) { - // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z)) - SDValue Neg0 = getNegatedExpression( - Op.getOperand(0), DAG, LegalOperations, ForCodeSize, Depth + 1); - return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), Neg0, - Op.getOperand(1), Neg2, Flags); + // Negate the Y if it is not expensive. + if (NegY) { + Cost = std::min(CostY, CostZ); + return DAG.getNode(Opcode, DL, VT, X, NegY, NegZ, Flags); } - - // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z)) - SDValue Neg1 = getNegatedExpression(Op.getOperand(1), DAG, LegalOperations, - ForCodeSize, Depth + 1); - return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), - Op.getOperand(0), Neg1, Neg2, Flags); + break; } case ISD::FP_EXTEND: case ISD::FSIN: - return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), - getNegatedExpression(Op.getOperand(0), DAG, - LegalOperations, ForCodeSize, - Depth + 1)); + if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps, + OptForSize, Cost, Depth)) + return DAG.getNode(Opcode, DL, VT, NegV); + break; case ISD::FP_ROUND: - return DAG.getNode(ISD::FP_ROUND, SDLoc(Op), Op.getValueType(), - getNegatedExpression(Op.getOperand(0), DAG, - LegalOperations, ForCodeSize, - Depth + 1), - Op.getOperand(1)); + if (SDValue NegV = getNegatedExpression(Op.getOperand(0), DAG, LegalOps, + OptForSize, Cost, Depth)) + return DAG.getNode(ISD::FP_ROUND, DL, VT, NegV, Op.getOperand(1)); + break; } - llvm_unreachable("Unknown code"); + return SDValue(); } //===----------------------------------------------------------------------===// @@ -5929,6 +6117,14 @@ bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT, return Ok; } +// Check that (every element of) Z is undef or not an exact multiple of BW. +static bool isNonZeroModBitWidth(SDValue Z, unsigned BW) { + return ISD::matchUnaryPredicate( + Z, + [=](ConstantSDNode *C) { return !C || C->getAPIntValue().urem(BW) != 0; }, + true); +} + bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const { EVT VT = Node->getValueType(0); @@ -5939,41 +6135,54 @@ bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result, !isOperationLegalOrCustomOrPromote(ISD::OR, VT))) return false; - // fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW))) - // fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SDValue X = Node->getOperand(0); SDValue Y = Node->getOperand(1); SDValue Z = Node->getOperand(2); - unsigned EltSizeInBits = VT.getScalarSizeInBits(); + unsigned BW = VT.getScalarSizeInBits(); bool IsFSHL = Node->getOpcode() == ISD::FSHL; SDLoc DL(SDValue(Node, 0)); EVT ShVT = Z.getValueType(); - SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT); - SDValue Zero = DAG.getConstant(0, DL, ShVT); - SDValue ShAmt; - if (isPowerOf2_32(EltSizeInBits)) { - SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, ShVT); - ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask); - } else { + SDValue ShX, ShY; + SDValue ShAmt, InvShAmt; + if (isNonZeroModBitWidth(Z, BW)) { + // fshl: X << C | Y >> (BW - C) + // fshr: X << (BW - C) | Y >> C + // where C = Z % BW is not zero + SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT); ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC); - } - - SDValue InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt); - SDValue ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt); - SDValue ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt); - SDValue Or = DAG.getNode(ISD::OR, DL, VT, ShX, ShY); - - // If (Z % BW == 0), then the opposite direction shift is shift-by-bitwidth, - // and that is undefined. We must compare and select to avoid UB. - EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), ShVT); + InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt); + ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt); + ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt); + } else { + // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW)) + // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW) + SDValue Mask = DAG.getConstant(BW - 1, DL, ShVT); + if (isPowerOf2_32(BW)) { + // Z % BW -> Z & (BW - 1) + ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask); + // (BW - 1) - (Z % BW) -> ~Z & (BW - 1) + InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask); + } else { + SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT); + ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC); + InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt); + } - // For fshl, 0-shift returns the 1st arg (X). - // For fshr, 0-shift returns the 2nd arg (Y). - SDValue IsZeroShift = DAG.getSetCC(DL, CCVT, ShAmt, Zero, ISD::SETEQ); - Result = DAG.getSelect(DL, VT, IsZeroShift, IsFSHL ? X : Y, Or); + SDValue One = DAG.getConstant(1, DL, ShVT); + if (IsFSHL) { + ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt); + SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One); + ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt); + } else { + SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One); + ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt); + ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt); + } + } + Result = DAG.getNode(ISD::OR, DL, VT, ShX, ShY); return true; } @@ -5988,12 +6197,15 @@ bool TargetLowering::expandROT(SDNode *Node, SDValue &Result, SDLoc DL(SDValue(Node, 0)); EVT ShVT = Op1.getValueType(); - SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT); + SDValue Zero = DAG.getConstant(0, DL, ShVT); - // If a rotate in the other direction is legal, use it. + assert(isPowerOf2_32(EltSizeInBits) && EltSizeInBits > 1 && + "Expecting the type bitwidth to be a power of 2"); + + // If a rotate in the other direction is supported, use it. unsigned RevRot = IsLeft ? ISD::ROTR : ISD::ROTL; - if (isOperationLegal(RevRot, VT)) { - SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, Op1); + if (isOperationLegalOrCustom(RevRot, VT)) { + SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1); Result = DAG.getNode(RevRot, DL, VT, Op0, Sub); return true; } @@ -6006,15 +6218,13 @@ bool TargetLowering::expandROT(SDNode *Node, SDValue &Result, return false; // Otherwise, - // (rotl x, c) -> (or (shl x, (and c, w-1)), (srl x, (and w-c, w-1))) - // (rotr x, c) -> (or (srl x, (and c, w-1)), (shl x, (and w-c, w-1))) + // (rotl x, c) -> (or (shl x, (and c, w-1)), (srl x, (and -c, w-1))) + // (rotr x, c) -> (or (srl x, (and c, w-1)), (shl x, (and -c, w-1))) // - assert(isPowerOf2_32(EltSizeInBits) && EltSizeInBits > 1 && - "Expecting the type bitwidth to be a power of 2"); unsigned ShOpc = IsLeft ? ISD::SHL : ISD::SRL; unsigned HsOpc = IsLeft ? ISD::SRL : ISD::SHL; SDValue BitWidthMinusOneC = DAG.getConstant(EltSizeInBits - 1, DL, ShVT); - SDValue NegOp1 = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, Op1); + SDValue NegOp1 = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Op1); SDValue And0 = DAG.getNode(ISD::AND, DL, ShVT, Op1, BitWidthMinusOneC); SDValue And1 = DAG.getNode(ISD::AND, DL, ShVT, NegOp1, BitWidthMinusOneC); Result = DAG.getNode(ISD::OR, DL, VT, DAG.getNode(ShOpc, DL, VT, Op0, And0), @@ -6198,114 +6408,50 @@ bool TargetLowering::expandUINT_TO_FP(SDNode *Node, SDValue &Result, EVT SrcVT = Src.getValueType(); EVT DstVT = Node->getValueType(0); - if (SrcVT.getScalarType() != MVT::i64) + if (SrcVT.getScalarType() != MVT::i64 || DstVT.getScalarType() != MVT::f64) + return false; + + // Only expand vector types if we have the appropriate vector bit operations. + if (SrcVT.isVector() && (!isOperationLegalOrCustom(ISD::SRL, SrcVT) || + !isOperationLegalOrCustom(ISD::FADD, DstVT) || + !isOperationLegalOrCustom(ISD::FSUB, DstVT) || + !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) || + !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT))) return false; SDLoc dl(SDValue(Node, 0)); EVT ShiftVT = getShiftAmountTy(SrcVT, DAG.getDataLayout()); - if (DstVT.getScalarType() == MVT::f32) { - // Only expand vector types if we have the appropriate vector bit - // operations. - if (SrcVT.isVector() && - (!isOperationLegalOrCustom(ISD::SRL, SrcVT) || - !isOperationLegalOrCustom(ISD::FADD, DstVT) || - !isOperationLegalOrCustom(ISD::SINT_TO_FP, SrcVT) || - !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) || - !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT))) - return false; - - // For unsigned conversions, convert them to signed conversions using the - // algorithm from the x86_64 __floatundisf in compiler_rt. - - // TODO: This really should be implemented using a branch rather than a - // select. We happen to get lucky and machinesink does the right - // thing most of the time. This would be a good candidate for a - // pseudo-op, or, even better, for whole-function isel. - EVT SetCCVT = - getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT); - - SDValue SignBitTest = DAG.getSetCC( - dl, SetCCVT, Src, DAG.getConstant(0, dl, SrcVT), ISD::SETLT); - - SDValue ShiftConst = DAG.getConstant(1, dl, ShiftVT); - SDValue Shr = DAG.getNode(ISD::SRL, dl, SrcVT, Src, ShiftConst); - SDValue AndConst = DAG.getConstant(1, dl, SrcVT); - SDValue And = DAG.getNode(ISD::AND, dl, SrcVT, Src, AndConst); - SDValue Or = DAG.getNode(ISD::OR, dl, SrcVT, And, Shr); - - SDValue Slow, Fast; - if (Node->isStrictFPOpcode()) { - // In strict mode, we must avoid spurious exceptions, and therefore - // must make sure to only emit a single STRICT_SINT_TO_FP. - SDValue InCvt = DAG.getSelect(dl, SrcVT, SignBitTest, Or, Src); - Fast = DAG.getNode(ISD::STRICT_SINT_TO_FP, dl, { DstVT, MVT::Other }, - { Node->getOperand(0), InCvt }); - Slow = DAG.getNode(ISD::STRICT_FADD, dl, { DstVT, MVT::Other }, - { Fast.getValue(1), Fast, Fast }); - Chain = Slow.getValue(1); - // The STRICT_SINT_TO_FP inherits the exception mode from the - // incoming STRICT_UINT_TO_FP node; the STRICT_FADD node can - // never raise any exception. - SDNodeFlags Flags; - Flags.setNoFPExcept(Node->getFlags().hasNoFPExcept()); - Fast->setFlags(Flags); - Flags.setNoFPExcept(true); - Slow->setFlags(Flags); - } else { - SDValue SignCvt = DAG.getNode(ISD::SINT_TO_FP, dl, DstVT, Or); - Slow = DAG.getNode(ISD::FADD, dl, DstVT, SignCvt, SignCvt); - Fast = DAG.getNode(ISD::SINT_TO_FP, dl, DstVT, Src); - } - - Result = DAG.getSelect(dl, DstVT, SignBitTest, Slow, Fast); - return true; - } - - if (DstVT.getScalarType() == MVT::f64) { - // Only expand vector types if we have the appropriate vector bit - // operations. - if (SrcVT.isVector() && - (!isOperationLegalOrCustom(ISD::SRL, SrcVT) || - !isOperationLegalOrCustom(ISD::FADD, DstVT) || - !isOperationLegalOrCustom(ISD::FSUB, DstVT) || - !isOperationLegalOrCustomOrPromote(ISD::OR, SrcVT) || - !isOperationLegalOrCustomOrPromote(ISD::AND, SrcVT))) - return false; - - // Implementation of unsigned i64 to f64 following the algorithm in - // __floatundidf in compiler_rt. This implementation has the advantage - // of performing rounding correctly, both in the default rounding mode - // and in all alternate rounding modes. - SDValue TwoP52 = DAG.getConstant(UINT64_C(0x4330000000000000), dl, SrcVT); - SDValue TwoP84PlusTwoP52 = DAG.getConstantFP( - BitsToDouble(UINT64_C(0x4530000000100000)), dl, DstVT); - SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT); - SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT); - SDValue HiShift = DAG.getConstant(32, dl, ShiftVT); - - SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask); - SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift); - SDValue LoOr = DAG.getNode(ISD::OR, dl, SrcVT, Lo, TwoP52); - SDValue HiOr = DAG.getNode(ISD::OR, dl, SrcVT, Hi, TwoP84); - SDValue LoFlt = DAG.getBitcast(DstVT, LoOr); - SDValue HiFlt = DAG.getBitcast(DstVT, HiOr); - if (Node->isStrictFPOpcode()) { - SDValue HiSub = - DAG.getNode(ISD::STRICT_FSUB, dl, {DstVT, MVT::Other}, - {Node->getOperand(0), HiFlt, TwoP84PlusTwoP52}); - Result = DAG.getNode(ISD::STRICT_FADD, dl, {DstVT, MVT::Other}, - {HiSub.getValue(1), LoFlt, HiSub}); - Chain = Result.getValue(1); - } else { - SDValue HiSub = - DAG.getNode(ISD::FSUB, dl, DstVT, HiFlt, TwoP84PlusTwoP52); - Result = DAG.getNode(ISD::FADD, dl, DstVT, LoFlt, HiSub); - } - return true; + // Implementation of unsigned i64 to f64 following the algorithm in + // __floatundidf in compiler_rt. This implementation has the advantage + // of performing rounding correctly, both in the default rounding mode + // and in all alternate rounding modes. + SDValue TwoP52 = DAG.getConstant(UINT64_C(0x4330000000000000), dl, SrcVT); + SDValue TwoP84PlusTwoP52 = DAG.getConstantFP( + BitsToDouble(UINT64_C(0x4530000000100000)), dl, DstVT); + SDValue TwoP84 = DAG.getConstant(UINT64_C(0x4530000000000000), dl, SrcVT); + SDValue LoMask = DAG.getConstant(UINT64_C(0x00000000FFFFFFFF), dl, SrcVT); + SDValue HiShift = DAG.getConstant(32, dl, ShiftVT); + + SDValue Lo = DAG.getNode(ISD::AND, dl, SrcVT, Src, LoMask); + SDValue Hi = DAG.getNode(ISD::SRL, dl, SrcVT, Src, HiShift); + SDValue LoOr = DAG.getNode(ISD::OR, dl, SrcVT, Lo, TwoP52); + SDValue HiOr = DAG.getNode(ISD::OR, dl, SrcVT, Hi, TwoP84); + SDValue LoFlt = DAG.getBitcast(DstVT, LoOr); + SDValue HiFlt = DAG.getBitcast(DstVT, HiOr); + if (Node->isStrictFPOpcode()) { + SDValue HiSub = + DAG.getNode(ISD::STRICT_FSUB, dl, {DstVT, MVT::Other}, + {Node->getOperand(0), HiFlt, TwoP84PlusTwoP52}); + Result = DAG.getNode(ISD::STRICT_FADD, dl, {DstVT, MVT::Other}, + {HiSub.getValue(1), LoFlt, HiSub}); + Chain = Result.getValue(1); + } else { + SDValue HiSub = + DAG.getNode(ISD::FSUB, dl, DstVT, HiFlt, TwoP84PlusTwoP52); + Result = DAG.getNode(ISD::FADD, dl, DstVT, LoFlt, HiSub); } - - return false; + return true; } SDValue TargetLowering::expandFMINNUM_FMAXNUM(SDNode *Node, @@ -6564,12 +6710,61 @@ TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, SDValue Chain = LD->getChain(); SDValue BasePTR = LD->getBasePtr(); EVT SrcVT = LD->getMemoryVT(); + EVT DstVT = LD->getValueType(0); ISD::LoadExtType ExtType = LD->getExtensionType(); unsigned NumElem = SrcVT.getVectorNumElements(); EVT SrcEltVT = SrcVT.getScalarType(); - EVT DstEltVT = LD->getValueType(0).getScalarType(); + EVT DstEltVT = DstVT.getScalarType(); + + // A vector must always be stored in memory as-is, i.e. without any padding + // between the elements, since various code depend on it, e.g. in the + // handling of a bitcast of a vector type to int, which may be done with a + // vector store followed by an integer load. A vector that does not have + // elements that are byte-sized must therefore be stored as an integer + // built out of the extracted vector elements. + if (!SrcEltVT.isByteSized()) { + unsigned NumLoadBits = SrcVT.getStoreSizeInBits(); + EVT LoadVT = EVT::getIntegerVT(*DAG.getContext(), NumLoadBits); + + unsigned NumSrcBits = SrcVT.getSizeInBits(); + EVT SrcIntVT = EVT::getIntegerVT(*DAG.getContext(), NumSrcBits); + + unsigned SrcEltBits = SrcEltVT.getSizeInBits(); + SDValue SrcEltBitMask = DAG.getConstant( + APInt::getLowBitsSet(NumLoadBits, SrcEltBits), SL, LoadVT); + + // Load the whole vector and avoid masking off the top bits as it makes + // the codegen worse. + SDValue Load = + DAG.getExtLoad(ISD::EXTLOAD, SL, LoadVT, Chain, BasePTR, + LD->getPointerInfo(), SrcIntVT, LD->getAlignment(), + LD->getMemOperand()->getFlags(), LD->getAAInfo()); + + SmallVector<SDValue, 8> Vals; + for (unsigned Idx = 0; Idx < NumElem; ++Idx) { + unsigned ShiftIntoIdx = + (DAG.getDataLayout().isBigEndian() ? (NumElem - 1) - Idx : Idx); + SDValue ShiftAmount = + DAG.getShiftAmountConstant(ShiftIntoIdx * SrcEltVT.getSizeInBits(), + LoadVT, SL, /*LegalTypes=*/false); + SDValue ShiftedElt = DAG.getNode(ISD::SRL, SL, LoadVT, Load, ShiftAmount); + SDValue Elt = + DAG.getNode(ISD::AND, SL, LoadVT, ShiftedElt, SrcEltBitMask); + SDValue Scalar = DAG.getNode(ISD::TRUNCATE, SL, SrcEltVT, Elt); + + if (ExtType != ISD::NON_EXTLOAD) { + unsigned ExtendOp = ISD::getExtForLoadExtType(false, ExtType); + Scalar = DAG.getNode(ExtendOp, SL, DstEltVT, Scalar); + } + + Vals.push_back(Scalar); + } + + SDValue Value = DAG.getBuildVector(DstVT, SL, Vals); + return std::make_pair(Value, Load.getValue(1)); + } unsigned Stride = SrcEltVT.getSizeInBits() / 8; assert(SrcEltVT.isByteSized()); @@ -6591,7 +6786,7 @@ TargetLowering::scalarizeVectorLoad(LoadSDNode *LD, } SDValue NewChain = DAG.getNode(ISD::TokenFactor, SL, MVT::Other, LoadChains); - SDValue Value = DAG.getBuildVector(LD->getValueType(0), SL, Vals); + SDValue Value = DAG.getBuildVector(DstVT, SL, Vals); return std::make_pair(Value, NewChain); } @@ -6612,7 +6807,6 @@ SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, // The type of data as saved in memory. EVT MemSclVT = StVT.getScalarType(); - EVT IdxVT = getVectorIdxTy(DAG.getDataLayout()); unsigned NumElem = StVT.getVectorNumElements(); // A vector must always be stored in memory as-is, i.e. without any padding @@ -6629,7 +6823,7 @@ SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, for (unsigned Idx = 0; Idx < NumElem; ++Idx) { SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, RegSclVT, Value, - DAG.getConstant(Idx, SL, IdxVT)); + DAG.getVectorIdxConstant(Idx, SL)); SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, MemSclVT, Elt); SDValue ExtElt = DAG.getNode(ISD::ZERO_EXTEND, SL, IntVT, Trunc); unsigned ShiftIntoIdx = @@ -6654,7 +6848,7 @@ SDValue TargetLowering::scalarizeVectorStore(StoreSDNode *ST, SmallVector<SDValue, 8> Stores; for (unsigned Idx = 0; Idx < NumElem; ++Idx) { SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, RegSclVT, Value, - DAG.getConstant(Idx, SL, IdxVT)); + DAG.getVectorIdxConstant(Idx, SL)); SDValue Ptr = DAG.getObjectPtrOffset(SL, BasePtr, Idx * Stride); @@ -7313,12 +7507,13 @@ SDValue TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl, SDValue LHS, SDValue RHS, unsigned Scale, SelectionDAG &DAG) const { - assert((Opcode == ISD::SDIVFIX || - Opcode == ISD::UDIVFIX) && + assert((Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT || + Opcode == ISD::UDIVFIX || Opcode == ISD::UDIVFIXSAT) && "Expected a fixed point division opcode"); EVT VT = LHS.getValueType(); - bool Signed = Opcode == ISD::SDIVFIX; + bool Signed = Opcode == ISD::SDIVFIX || Opcode == ISD::SDIVFIXSAT; + bool Saturating = Opcode == ISD::SDIVFIXSAT || Opcode == ISD::UDIVFIXSAT; EVT BoolVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); // If there is enough room in the type to upscale the LHS or downscale the @@ -7330,7 +7525,15 @@ TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl, : DAG.computeKnownBits(LHS).countMinLeadingZeros(); unsigned RHSTrail = DAG.computeKnownBits(RHS).countMinTrailingZeros(); - if (LHSLead + RHSTrail < Scale) + // For signed saturating operations, we need to be able to detect true integer + // division overflow; that is, when you have MIN / -EPS. However, this + // is undefined behavior and if we emit divisions that could take such + // values it may cause undesired behavior (arithmetic exceptions on x86, for + // example). + // Avoid this by requiring an extra bit so that we never get this case. + // FIXME: This is a bit unfortunate as it means that for an 8-bit 7-scale + // signed saturating division, we need to emit a whopping 32-bit division. + if (LHSLead + RHSTrail < Scale + (unsigned)(Saturating && Signed)) return SDValue(); unsigned LHSShift = std::min(LHSLead, Scale); @@ -7384,8 +7587,6 @@ TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl, Quot = DAG.getNode(ISD::UDIV, dl, VT, LHS, RHS); - // TODO: Saturation. - return Quot; } @@ -7659,3 +7860,26 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res); return Res; } + +bool TargetLowering::expandREM(SDNode *Node, SDValue &Result, + SelectionDAG &DAG) const { + EVT VT = Node->getValueType(0); + SDLoc dl(Node); + bool isSigned = Node->getOpcode() == ISD::SREM; + unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV; + unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM; + SDValue Dividend = Node->getOperand(0); + SDValue Divisor = Node->getOperand(1); + if (isOperationLegalOrCustom(DivRemOpc, VT)) { + SDVTList VTs = DAG.getVTList(VT, VT); + Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1); + return true; + } else if (isOperationLegalOrCustom(DivOpc, VT)) { + // X % Y -> X-X/Y*Y + SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor); + SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor); + Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul); + return true; + } + return false; +} |