diff options
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64FrameLowering.cpp | 47 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64FrameLowering.h | 6 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp | 131 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 23 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrFormats.td | 5 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 29 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp | 28 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 22 | ||||
-rw-r--r-- | llvm/lib/Target/AArch64/SVEInstrFormats.td | 10 | ||||
-rw-r--r-- | llvm/lib/Target/PowerPC/PPCISelLowering.cpp | 27 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 324 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h | 9 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 27 | ||||
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVInstrInfoB.td | 429 | ||||
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 131 |
15 files changed, 1101 insertions, 147 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp index efa3fd5ca9ce..4789a9f02937 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1192,7 +1192,7 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF, // Process the SVE callee-saves to determine what space needs to be // allocated. - if (AFI->getSVECalleeSavedStackSize()) { + if (int64_t CalleeSavedSize = AFI->getSVECalleeSavedStackSize()) { // Find callee save instructions in frame. CalleeSavesBegin = MBBI; assert(IsSVECalleeSave(CalleeSavesBegin) && "Unexpected instruction"); @@ -1200,11 +1200,7 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF, ++MBBI; CalleeSavesEnd = MBBI; - int64_t OffsetToFirstCalleeSaveFromSP = - MFI.getObjectOffset(AFI->getMaxSVECSFrameIndex()); - StackOffset OffsetToCalleeSavesFromSP = - StackOffset(OffsetToFirstCalleeSaveFromSP, MVT::nxv1i8) + SVEStackSize; - AllocateBefore -= OffsetToCalleeSavesFromSP; + AllocateBefore = {CalleeSavedSize, MVT::nxv1i8}; AllocateAfter = SVEStackSize - AllocateBefore; } @@ -1582,7 +1578,7 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF, // deallocated. StackOffset DeallocateBefore = {}, DeallocateAfter = SVEStackSize; MachineBasicBlock::iterator RestoreBegin = LastPopI, RestoreEnd = LastPopI; - if (AFI->getSVECalleeSavedStackSize()) { + if (int64_t CalleeSavedSize = AFI->getSVECalleeSavedStackSize()) { RestoreBegin = std::prev(RestoreEnd);; while (IsSVECalleeSave(RestoreBegin) && RestoreBegin != MBB.begin()) @@ -1592,23 +1588,21 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF, assert(IsSVECalleeSave(RestoreBegin) && IsSVECalleeSave(std::prev(RestoreEnd)) && "Unexpected instruction"); - int64_t OffsetToFirstCalleeSaveFromSP = - MFI.getObjectOffset(AFI->getMaxSVECSFrameIndex()); - StackOffset OffsetToCalleeSavesFromSP = - StackOffset(OffsetToFirstCalleeSaveFromSP, MVT::nxv1i8) + SVEStackSize; - DeallocateBefore = OffsetToCalleeSavesFromSP; - DeallocateAfter = SVEStackSize - DeallocateBefore; + StackOffset CalleeSavedSizeAsOffset = {CalleeSavedSize, MVT::nxv1i8}; + DeallocateBefore = SVEStackSize - CalleeSavedSizeAsOffset; + DeallocateAfter = CalleeSavedSizeAsOffset; } // Deallocate the SVE area. if (SVEStackSize) { if (AFI->isStackRealigned()) { - if (AFI->getSVECalleeSavedStackSize()) - // Set SP to start of SVE area, from which the callee-save reloads - // can be done. The code below will deallocate the stack space + if (int64_t CalleeSavedSize = AFI->getSVECalleeSavedStackSize()) + // Set SP to start of SVE callee-save area from which they can + // be reloaded. The code below will deallocate the stack space // space by moving FP -> SP. emitFrameOffset(MBB, RestoreBegin, DL, AArch64::SP, AArch64::FP, - -SVEStackSize, TII, MachineInstr::FrameDestroy); + {-CalleeSavedSize, MVT::nxv1i8}, TII, + MachineInstr::FrameDestroy); } else { if (AFI->getSVECalleeSavedStackSize()) { // Deallocate the non-SVE locals first before we can deallocate (and @@ -2595,25 +2589,23 @@ static int64_t determineSVEStackObjectOffsets(MachineFrameInfo &MFI, int &MinCSFrameIndex, int &MaxCSFrameIndex, bool AssignOffsets) { +#ifndef NDEBUG // First process all fixed stack objects. - int64_t Offset = 0; for (int I = MFI.getObjectIndexBegin(); I != 0; ++I) - if (MFI.getStackID(I) == TargetStackID::SVEVector) { - int64_t FixedOffset = -MFI.getObjectOffset(I); - if (FixedOffset > Offset) - Offset = FixedOffset; - } + assert(MFI.getStackID(I) != TargetStackID::SVEVector && + "SVE vectors should never be passed on the stack by value, only by " + "reference."); +#endif auto Assign = [&MFI](int FI, int64_t Offset) { LLVM_DEBUG(dbgs() << "alloc FI(" << FI << ") at SP[" << Offset << "]\n"); MFI.setObjectOffset(FI, Offset); }; + int64_t Offset = 0; + // Then process all callee saved slots. if (getSVECalleeSaveSlotRange(MFI, MinCSFrameIndex, MaxCSFrameIndex)) { - // Make sure to align the last callee save slot. - MFI.setObjectAlignment(MaxCSFrameIndex, Align(16)); - // Assign offsets to the callee save slots. for (int I = MinCSFrameIndex; I <= MaxCSFrameIndex; ++I) { Offset += MFI.getObjectSize(I); @@ -2623,6 +2615,9 @@ static int64_t determineSVEStackObjectOffsets(MachineFrameInfo &MFI, } } + // Ensure that the Callee-save area is aligned to 16bytes. + Offset = alignTo(Offset, Align(16U)); + // Create a buffer of SVE objects to allocate and sort it. SmallVector<int, 8> ObjectsToAllocate; for (int I = 0, E = MFI.getObjectIndexEnd(); I != E; ++I) { diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.h b/llvm/lib/Target/AArch64/AArch64FrameLowering.h index 9d0a6d9eaf25..444740cb50ab 100644 --- a/llvm/lib/Target/AArch64/AArch64FrameLowering.h +++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.h @@ -105,6 +105,12 @@ public: } } + bool isStackIdSafeForLocalArea(unsigned StackId) const override { + // We don't support putting SVE objects into the pre-allocated local + // frame block at the moment. + return StackId != TargetStackID::SVEVector; + } + private: bool shouldCombineCSRLocalStackBump(MachineFunction &MF, uint64_t StackBumpBytes) const; diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index 10c477853353..7799ebfbd68e 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -245,7 +245,8 @@ public: unsigned SubRegIdx); void SelectLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectPostLoadLane(SDNode *N, unsigned NumVecs, unsigned Opc); - void SelectPredicatedLoad(SDNode *N, unsigned NumVecs, const unsigned Opc); + void SelectPredicatedLoad(SDNode *N, unsigned NumVecs, unsigned Scale, + unsigned Opc_rr, unsigned Opc_ri); bool SelectAddrModeFrameIndexSVE(SDValue N, SDValue &Base, SDValue &OffImm); /// SVE Reg+Imm addressing mode. @@ -262,14 +263,12 @@ public: void SelectPostStore(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc); void SelectPostStoreLane(SDNode *N, unsigned NumVecs, unsigned Opc); - template <unsigned Scale> - void SelectPredicatedStore(SDNode *N, unsigned NumVecs, const unsigned Opc_rr, - const unsigned Opc_ri); - template <unsigned Scale> + void SelectPredicatedStore(SDNode *N, unsigned NumVecs, unsigned Scale, + unsigned Opc_rr, unsigned Opc_ri); std::tuple<unsigned, SDValue, SDValue> - findAddrModeSVELoadStore(SDNode *N, const unsigned Opc_rr, - const unsigned Opc_ri, const SDValue &OldBase, - const SDValue &OldOffset); + findAddrModeSVELoadStore(SDNode *N, unsigned Opc_rr, unsigned Opc_ri, + const SDValue &OldBase, const SDValue &OldOffset, + unsigned Scale); bool tryBitfieldExtractOp(SDNode *N); bool tryBitfieldExtractOpFromSExt(SDNode *N); @@ -1414,12 +1413,12 @@ void AArch64DAGToDAGISel::SelectPostLoad(SDNode *N, unsigned NumVecs, /// Optimize \param OldBase and \param OldOffset selecting the best addressing /// mode. Returns a tuple consisting of an Opcode, an SDValue representing the /// new Base and an SDValue representing the new offset. -template <unsigned Scale> std::tuple<unsigned, SDValue, SDValue> -AArch64DAGToDAGISel::findAddrModeSVELoadStore(SDNode *N, const unsigned Opc_rr, - const unsigned Opc_ri, +AArch64DAGToDAGISel::findAddrModeSVELoadStore(SDNode *N, unsigned Opc_rr, + unsigned Opc_ri, const SDValue &OldBase, - const SDValue &OldOffset) { + const SDValue &OldOffset, + unsigned Scale) { SDValue NewBase = OldBase; SDValue NewOffset = OldOffset; // Detect a possible Reg+Imm addressing mode. @@ -1429,21 +1428,30 @@ AArch64DAGToDAGISel::findAddrModeSVELoadStore(SDNode *N, const unsigned Opc_rr, // Detect a possible reg+reg addressing mode, but only if we haven't already // detected a Reg+Imm one. const bool IsRegReg = - !IsRegImm && SelectSVERegRegAddrMode<Scale>(OldBase, NewBase, NewOffset); + !IsRegImm && SelectSVERegRegAddrMode(OldBase, Scale, NewBase, NewOffset); // Select the instruction. return std::make_tuple(IsRegReg ? Opc_rr : Opc_ri, NewBase, NewOffset); } void AArch64DAGToDAGISel::SelectPredicatedLoad(SDNode *N, unsigned NumVecs, - const unsigned Opc) { + unsigned Scale, unsigned Opc_ri, + unsigned Opc_rr) { + assert(Scale < 4 && "Invalid scaling value."); SDLoc DL(N); EVT VT = N->getValueType(0); SDValue Chain = N->getOperand(0); + // Optimize addressing mode. + SDValue Base, Offset; + unsigned Opc; + std::tie(Opc, Base, Offset) = findAddrModeSVELoadStore( + N, Opc_rr, Opc_ri, N->getOperand(2), + CurDAG->getTargetConstant(0, DL, MVT::i64), Scale); + SDValue Ops[] = {N->getOperand(1), // Predicate - N->getOperand(2), // Memory operand - CurDAG->getTargetConstant(0, DL, MVT::i64), Chain}; + Base, // Memory operand + Offset, Chain}; const EVT ResTys[] = {MVT::Untyped, MVT::Other}; @@ -1479,10 +1487,9 @@ void AArch64DAGToDAGISel::SelectStore(SDNode *N, unsigned NumVecs, ReplaceNode(N, St); } -template <unsigned Scale> void AArch64DAGToDAGISel::SelectPredicatedStore(SDNode *N, unsigned NumVecs, - const unsigned Opc_rr, - const unsigned Opc_ri) { + unsigned Scale, unsigned Opc_rr, + unsigned Opc_ri) { SDLoc dl(N); // Form a REG_SEQUENCE to force register allocation. @@ -1492,9 +1499,9 @@ void AArch64DAGToDAGISel::SelectPredicatedStore(SDNode *N, unsigned NumVecs, // Optimize addressing mode. unsigned Opc; SDValue Offset, Base; - std::tie(Opc, Base, Offset) = findAddrModeSVELoadStore<Scale>( + std::tie(Opc, Base, Offset) = findAddrModeSVELoadStore( N, Opc_rr, Opc_ri, N->getOperand(NumVecs + 3), - CurDAG->getTargetConstant(0, dl, MVT::i64)); + CurDAG->getTargetConstant(0, dl, MVT::i64), Scale); SDValue Ops[] = {RegSeq, N->getOperand(NumVecs + 2), // predicate Base, // address @@ -4085,63 +4092,51 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case Intrinsic::aarch64_sve_st2: { if (VT == MVT::nxv16i8) { - SelectPredicatedStore</*Scale=*/0>(Node, 2, AArch64::ST2B, - AArch64::ST2B_IMM); + SelectPredicatedStore(Node, 2, 0, AArch64::ST2B, AArch64::ST2B_IMM); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedStore</*Scale=*/1>(Node, 2, AArch64::ST2H, - AArch64::ST2H_IMM); + SelectPredicatedStore(Node, 2, 1, AArch64::ST2H, AArch64::ST2H_IMM); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedStore</*Scale=*/2>(Node, 2, AArch64::ST2W, - AArch64::ST2W_IMM); + SelectPredicatedStore(Node, 2, 2, AArch64::ST2W, AArch64::ST2W_IMM); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedStore</*Scale=*/3>(Node, 2, AArch64::ST2D, - AArch64::ST2D_IMM); + SelectPredicatedStore(Node, 2, 3, AArch64::ST2D, AArch64::ST2D_IMM); return; } break; } case Intrinsic::aarch64_sve_st3: { if (VT == MVT::nxv16i8) { - SelectPredicatedStore</*Scale=*/0>(Node, 3, AArch64::ST3B, - AArch64::ST3B_IMM); + SelectPredicatedStore(Node, 3, 0, AArch64::ST3B, AArch64::ST3B_IMM); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedStore</*Scale=*/1>(Node, 3, AArch64::ST3H, - AArch64::ST3H_IMM); + SelectPredicatedStore(Node, 3, 1, AArch64::ST3H, AArch64::ST3H_IMM); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedStore</*Scale=*/2>(Node, 3, AArch64::ST3W, - AArch64::ST3W_IMM); + SelectPredicatedStore(Node, 3, 2, AArch64::ST3W, AArch64::ST3W_IMM); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedStore</*Scale=*/3>(Node, 3, AArch64::ST3D, - AArch64::ST3D_IMM); + SelectPredicatedStore(Node, 3, 3, AArch64::ST3D, AArch64::ST3D_IMM); return; } break; } case Intrinsic::aarch64_sve_st4: { if (VT == MVT::nxv16i8) { - SelectPredicatedStore</*Scale=*/0>(Node, 4, AArch64::ST4B, - AArch64::ST4B_IMM); + SelectPredicatedStore(Node, 4, 0, AArch64::ST4B, AArch64::ST4B_IMM); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedStore</*Scale=*/1>(Node, 4, AArch64::ST4H, - AArch64::ST4H_IMM); + SelectPredicatedStore(Node, 4, 1, AArch64::ST4H, AArch64::ST4H_IMM); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedStore</*Scale=*/2>(Node, 4, AArch64::ST4W, - AArch64::ST4W_IMM); + SelectPredicatedStore(Node, 4, 2, AArch64::ST4W, AArch64::ST4W_IMM); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedStore</*Scale=*/3>(Node, 4, AArch64::ST4D, - AArch64::ST4D_IMM); + SelectPredicatedStore(Node, 4, 3, AArch64::ST4D, AArch64::ST4D_IMM); return; } break; @@ -4741,51 +4736,51 @@ void AArch64DAGToDAGISel::Select(SDNode *Node) { } case AArch64ISD::SVE_LD2_MERGE_ZERO: { if (VT == MVT::nxv16i8) { - SelectPredicatedLoad(Node, 2, AArch64::LD2B_IMM); + SelectPredicatedLoad(Node, 2, 0, AArch64::LD2B_IMM, AArch64::LD2B); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedLoad(Node, 2, AArch64::LD2H_IMM); + SelectPredicatedLoad(Node, 2, 1, AArch64::LD2H_IMM, AArch64::LD2H); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedLoad(Node, 2, AArch64::LD2W_IMM); + SelectPredicatedLoad(Node, 2, 2, AArch64::LD2W_IMM, AArch64::LD2W); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedLoad(Node, 2, AArch64::LD2D_IMM); + SelectPredicatedLoad(Node, 2, 3, AArch64::LD2D_IMM, AArch64::LD2D); return; } break; } case AArch64ISD::SVE_LD3_MERGE_ZERO: { if (VT == MVT::nxv16i8) { - SelectPredicatedLoad(Node, 3, AArch64::LD3B_IMM); + SelectPredicatedLoad(Node, 3, 0, AArch64::LD3B_IMM, AArch64::LD3B); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedLoad(Node, 3, AArch64::LD3H_IMM); + SelectPredicatedLoad(Node, 3, 1, AArch64::LD3H_IMM, AArch64::LD3H); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedLoad(Node, 3, AArch64::LD3W_IMM); + SelectPredicatedLoad(Node, 3, 2, AArch64::LD3W_IMM, AArch64::LD3W); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedLoad(Node, 3, AArch64::LD3D_IMM); + SelectPredicatedLoad(Node, 3, 3, AArch64::LD3D_IMM, AArch64::LD3D); return; } break; } case AArch64ISD::SVE_LD4_MERGE_ZERO: { if (VT == MVT::nxv16i8) { - SelectPredicatedLoad(Node, 4, AArch64::LD4B_IMM); + SelectPredicatedLoad(Node, 4, 0, AArch64::LD4B_IMM, AArch64::LD4B); return; } else if (VT == MVT::nxv8i16 || VT == MVT::nxv8f16 || (VT == MVT::nxv8bf16 && Subtarget->hasBF16())) { - SelectPredicatedLoad(Node, 4, AArch64::LD4H_IMM); + SelectPredicatedLoad(Node, 4, 1, AArch64::LD4H_IMM, AArch64::LD4H); return; } else if (VT == MVT::nxv4i32 || VT == MVT::nxv4f32) { - SelectPredicatedLoad(Node, 4, AArch64::LD4W_IMM); + SelectPredicatedLoad(Node, 4, 2, AArch64::LD4W_IMM, AArch64::LD4W); return; } else if (VT == MVT::nxv2i64 || VT == MVT::nxv2f64) { - SelectPredicatedLoad(Node, 4, AArch64::LD4D_IMM); + SelectPredicatedLoad(Node, 4, 3, AArch64::LD4D_IMM, AArch64::LD4D); return; } break; @@ -4805,10 +4800,14 @@ FunctionPass *llvm::createAArch64ISelDag(AArch64TargetMachine &TM, /// When \p PredVT is a scalable vector predicate in the form /// MVT::nx<M>xi1, it builds the correspondent scalable vector of -/// integers MVT::nx<M>xi<bits> s.t. M x bits = 128. If the input +/// integers MVT::nx<M>xi<bits> s.t. M x bits = 128. When targeting +/// structured vectors (NumVec >1), the output data type is +/// MVT::nx<M*NumVec>xi<bits> s.t. M x bits = 128. If the input /// PredVT is not in the form MVT::nx<M>xi1, it returns an invalid /// EVT. -static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT) { +static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT, + unsigned NumVec) { + assert(NumVec > 0 && NumVec < 5 && "Invalid number of vectors."); if (!PredVT.isScalableVector() || PredVT.getVectorElementType() != MVT::i1) return EVT(); @@ -4818,7 +4817,8 @@ static EVT getPackedVectorTypeFromPredicateType(LLVMContext &Ctx, EVT PredVT) { ElementCount EC = PredVT.getVectorElementCount(); EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.Min); - EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC); + EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC * NumVec); + return MemVT; } @@ -4842,6 +4842,15 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) { return cast<VTSDNode>(Root->getOperand(3))->getVT(); case AArch64ISD::ST1_PRED: return cast<VTSDNode>(Root->getOperand(4))->getVT(); + case AArch64ISD::SVE_LD2_MERGE_ZERO: + return getPackedVectorTypeFromPredicateType( + Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/2); + case AArch64ISD::SVE_LD3_MERGE_ZERO: + return getPackedVectorTypeFromPredicateType( + Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/3); + case AArch64ISD::SVE_LD4_MERGE_ZERO: + return getPackedVectorTypeFromPredicateType( + Ctx, Root->getOperand(1)->getValueType(0), /*NumVec=*/4); default: break; } @@ -4857,7 +4866,7 @@ static EVT getMemVTFromNode(LLVMContext &Ctx, SDNode *Root) { // We are using an SVE prefetch intrinsic. Type must be inferred // from the width of the predicate. return getPackedVectorTypeFromPredicateType( - Ctx, Root->getOperand(2)->getValueType(0)); + Ctx, Root->getOperand(2)->getValueType(0), /*NumVec=*/1); } /// SelectAddrModeIndexedSVE - Attempt selection of the addressing mode: diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 85db14ab66fe..1500da2fdfc7 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -932,8 +932,11 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::SRL, VT, Custom); setOperationAction(ISD::SRA, VT, Custom); - if (VT.getScalarType() == MVT::i1) + if (VT.getScalarType() == MVT::i1) { setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction(ISD::CONCAT_VECTORS, VT, Legal); + } } } @@ -8858,6 +8861,16 @@ SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); + if (VT.getScalarType() == MVT::i1) { + // Lower i1 truncate to `(x & 1) != 0`. + SDLoc dl(Op); + EVT OpVT = Op.getOperand(0).getValueType(); + SDValue Zero = DAG.getConstant(0, dl, OpVT); + SDValue One = DAG.getConstant(1, dl, OpVT); + SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One); + return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE); + } + if (!VT.isVector() || VT.isScalableVector()) return Op; @@ -12288,6 +12301,9 @@ static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) { "Unsupported opcode."); SDLoc DL(N); EVT VT = N->getValueType(0); + if (VT == MVT::nxv8bf16 && + !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16()) + return SDValue(); EVT LoadVT = VT; if (VT.isFloatingPoint()) @@ -14909,6 +14925,11 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const { if (isa<ScalableVectorType>(Inst.getOperand(i)->getType())) return true; + if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) { + if (isa<ScalableVectorType>(AI->getAllocatedType())) + return true; + } + return false; } diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 6df7970f4d82..4f4ba692c2db 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -495,6 +495,9 @@ def SImmS4XForm : SDNodeXForm<imm, [{ def SImmS16XForm : SDNodeXForm<imm, [{ return CurDAG->getTargetConstant(N->getSExtValue() / 16, SDLoc(N), MVT::i64); }]>; +def SImmS32XForm : SDNodeXForm<imm, [{ + return CurDAG->getTargetConstant(N->getSExtValue() / 32, SDLoc(N), MVT::i64); +}]>; // simm6sN predicate - True if the immediate is a multiple of N in the range // [-32 * N, 31 * N]. @@ -546,7 +549,7 @@ def simm4s16 : Operand<i64>, ImmLeaf<i64, let DecoderMethod = "DecodeSImm<4>"; } def simm4s32 : Operand<i64>, ImmLeaf<i64, -[{ return Imm >=-256 && Imm <= 224 && (Imm % 32) == 0x0; }]> { +[{ return Imm >=-256 && Imm <= 224 && (Imm % 32) == 0x0; }], SImmS32XForm> { let PrintMethod = "printImmScale<32>"; let ParserMatchClass = SImm4s32Operand; let DecoderMethod = "DecodeSImm<4>"; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index 5139ae5ccaf1..08f80c9aa361 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -2744,6 +2744,35 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB, return; } + // Copy a Z register pair by copying the individual sub-registers. + if (AArch64::ZPR2RegClass.contains(DestReg) && + AArch64::ZPR2RegClass.contains(SrcReg)) { + static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1}; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + + // Copy a Z register triple by copying the individual sub-registers. + if (AArch64::ZPR3RegClass.contains(DestReg) && + AArch64::ZPR3RegClass.contains(SrcReg)) { + static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1, + AArch64::zsub2}; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + + // Copy a Z register quad by copying the individual sub-registers. + if (AArch64::ZPR4RegClass.contains(DestReg) && + AArch64::ZPR4RegClass.contains(SrcReg)) { + static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1, + AArch64::zsub2, AArch64::zsub3}; + copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ, + Indices); + return; + } + if (AArch64::GPR64spRegClass.contains(DestReg) && (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) { if (DestReg == AArch64::SP || SrcReg == AArch64::SP) { diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp index 886158ca4490..83a488afc797 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -40,6 +40,14 @@ AArch64RegisterInfo::AArch64RegisterInfo(const Triple &TT) AArch64_MC::initLLVMToCVRegMapping(this); } +static bool hasSVEArgsOrReturn(const MachineFunction *MF) { + const Function &F = MF->getFunction(); + return isa<ScalableVectorType>(F.getReturnType()) || + any_of(F.args(), [](const Argument &Arg) { + return isa<ScalableVectorType>(Arg.getType()); + }); +} + const MCPhysReg * AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { assert(MF && "Invalid MachineFunction pointer."); @@ -75,6 +83,8 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { // This is for OSes other than Windows; Windows is a separate case further // above. return CSR_AArch64_AAPCS_X18_SaveList; + if (hasSVEArgsOrReturn(MF)) + return CSR_AArch64_SVE_AAPCS_SaveList; return CSR_AArch64_AAPCS_SaveList; } @@ -343,6 +353,15 @@ bool AArch64RegisterInfo::hasBasePointer(const MachineFunction &MF) const { if (MFI.hasVarSizedObjects() || MF.hasEHFunclets()) { if (needsStackRealignment(MF)) return true; + + if (MF.getSubtarget<AArch64Subtarget>().hasSVE()) { + const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); + // Frames that have variable sized objects and scalable SVE objects, + // should always use a basepointer. + if (!AFI->hasCalculatedStackSizeSVE() || AFI->getStackSizeSVE()) + return true; + } + // Conservatively estimate whether the negative offset from the frame // pointer will be sufficient to reach. If a function has a smallish // frame, it's less likely to have lots of spills and callee saved @@ -379,8 +398,15 @@ AArch64RegisterInfo::useFPForScavengingIndex(const MachineFunction &MF) const { // (closer to SP). // // The beginning works most reliably if we have a frame pointer. + // In the presence of any non-constant space between FP and locals, + // (e.g. in case of stack realignment or a scalable SVE area), it is + // better to use SP or BP. const AArch64FrameLowering &TFI = *getFrameLowering(MF); - return TFI.hasFP(MF); + const AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>(); + assert((!MF.getSubtarget<AArch64Subtarget>().hasSVE() || + AFI->hasCalculatedStackSizeSVE()) && + "Expected SVE area to be calculated by this point"); + return TFI.hasFP(MF) && !needsStackRealignment(MF) && !AFI->getStackSizeSVE(); } bool AArch64RegisterInfo::requiresFrameIndexScavenging( diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 28a54e6f7d79..3449a8bd16d2 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1109,6 +1109,28 @@ multiclass sve_prefetch<SDPatternOperator prefetch, ValueType PredTy, Instructio defm TRN1_PPP : sve_int_perm_bin_perm_pp<0b100, "trn1", AArch64trn1>; defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>; + // Extract lo/hi halves of legal predicate types. + def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_S PPR:$Ps, (PFALSE))>; + def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))), + (ZIP2_PPP_S PPR:$Ps, (PFALSE))>; + def : Pat<(nxv4i1 (extract_subvector (nxv8i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_H PPR:$Ps, (PFALSE))>; + def : Pat<(nxv4i1 (extract_subvector (nxv8i1 PPR:$Ps), (i64 4))), + (ZIP2_PPP_H PPR:$Ps, (PFALSE))>; + def : Pat<(nxv8i1 (extract_subvector (nxv16i1 PPR:$Ps), (i64 0))), + (ZIP1_PPP_B PPR:$Ps, (PFALSE))>; + def : Pat<(nxv8i1 (extract_subvector (nxv16i1 PPR:$Ps), (i64 8))), + (ZIP2_PPP_B PPR:$Ps, (PFALSE))>; + + // Concatenate two predicates. + def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)), + (UZP1_PPP_S $p1, $p2)>; + def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)), + (UZP1_PPP_H $p1, $p2)>; + def : Pat<(nxv16i1 (concat_vectors nxv8i1:$p1, nxv8i1:$p2)), + (UZP1_PPP_B $p1, $p2)>; + defm CMPHS_PPzZZ : sve_int_cmp_0<0b000, "cmphs", SETUGE, SETULE>; defm CMPHI_PPzZZ : sve_int_cmp_0<0b001, "cmphi", SETUGT, SETULT>; defm CMPGE_PPzZZ : sve_int_cmp_0<0b100, "cmpge", SETGE, SETLE>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index a005d1e65abe..c56a65b9e212 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -7718,9 +7718,13 @@ multiclass sve_mem_ldor_si<bits<2> sz, string asm, RegisterOperand listty, (!cast<Instruction>(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s32:$imm4), 0>; // Base addressing mode - def : Pat<(Ty (Ld1ro (PredTy PPR3bAny:$gp), GPR64sp:$base)), - (!cast<Instruction>(NAME) PPR3bAny:$gp, GPR64sp:$base, (i64 0))>; - + def : Pat<(Ty (Ld1ro (PredTy PPR3bAny:$Pg), GPR64sp:$base)), + (!cast<Instruction>(NAME) PPR3bAny:$Pg, GPR64sp:$base, (i64 0))>; + let AddedComplexity = 2 in { + // Reg + Imm addressing mode + def : Pat<(Ty (Ld1ro (PredTy PPR3bAny:$Pg), (add GPR64:$base, (i64 simm4s32:$imm)))), + (!cast<Instruction>(NAME) $Pg, $base, simm4s32:$imm)>; + } } class sve_mem_ldor_ss<bits<2> sz, string asm, RegisterOperand VecList, diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp index 11454841cab7..5c1a4cb16568 100644 --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -9111,13 +9111,15 @@ SDValue PPCTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { Op0.getOperand(1)); } -static const SDValue *getNormalLoadInput(const SDValue &Op) { +static const SDValue *getNormalLoadInput(const SDValue &Op, bool &IsPermuted) { const SDValue *InputLoad = &Op; if (InputLoad->getOpcode() == ISD::BITCAST) InputLoad = &InputLoad->getOperand(0); if (InputLoad->getOpcode() == ISD::SCALAR_TO_VECTOR || - InputLoad->getOpcode() == PPCISD::SCALAR_TO_VECTOR_PERMUTED) + InputLoad->getOpcode() == PPCISD::SCALAR_TO_VECTOR_PERMUTED) { + IsPermuted = InputLoad->getOpcode() == PPCISD::SCALAR_TO_VECTOR_PERMUTED; InputLoad = &InputLoad->getOperand(0); + } if (InputLoad->getOpcode() != ISD::LOAD) return nullptr; LoadSDNode *LD = cast<LoadSDNode>(*InputLoad); @@ -9289,7 +9291,9 @@ SDValue PPCTargetLowering::LowerBUILD_VECTOR(SDValue Op, if (!BVNIsConstantSplat || SplatBitSize > 32) { - const SDValue *InputLoad = getNormalLoadInput(Op.getOperand(0)); + bool IsPermutedLoad = false; + const SDValue *InputLoad = + getNormalLoadInput(Op.getOperand(0), IsPermutedLoad); // Handle load-and-splat patterns as we have instructions that will do this // in one go. if (InputLoad && DAG.isSplatValue(Op, true)) { @@ -9912,7 +9916,8 @@ SDValue PPCTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, // If this is a load-and-splat, we can do that with a single instruction // in some cases. However if the load has multiple uses, we don't want to // combine it because that will just produce multiple loads. - const SDValue *InputLoad = getNormalLoadInput(V1); + bool IsPermutedLoad = false; + const SDValue *InputLoad = getNormalLoadInput(V1, IsPermutedLoad); if (InputLoad && Subtarget.hasVSX() && V2.isUndef() && (PPC::isSplatShuffleMask(SVOp, 4) || PPC::isSplatShuffleMask(SVOp, 8)) && InputLoad->hasOneUse()) { @@ -9920,6 +9925,16 @@ SDValue PPCTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, int SplatIdx = PPC::getSplatIdxForPPCMnemonics(SVOp, IsFourByte ? 4 : 8, DAG); + // The splat index for permuted loads will be in the left half of the vector + // which is strictly wider than the loaded value by 8 bytes. So we need to + // adjust the splat index to point to the correct address in memory. + if (IsPermutedLoad) { + assert(isLittleEndian && "Unexpected permuted load on big endian target"); + SplatIdx += IsFourByte ? 2 : 1; + assert((SplatIdx < (IsFourByte ? 4 : 2)) && + "Splat of a value outside of the loaded memory"); + } + LoadSDNode *LD = cast<LoadSDNode>(*InputLoad); // For 4-byte load-and-splat, we need Power9. if ((IsFourByte && Subtarget.hasP9Vector()) || !IsFourByte) { @@ -9929,10 +9944,6 @@ SDValue PPCTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, else Offset = isLittleEndian ? (1 - SplatIdx) * 8 : SplatIdx * 8; - // If we are loading a partial vector, it does not make sense to adjust - // the base pointer. This happens with (splat (s_to_v_permuted (ld))). - if (LD->getMemoryVT().getSizeInBits() == (IsFourByte ? 32 : 64)) - Offset = 0; SDValue BasePtr = LD->getBasePtr(); if (Offset != 0) BasePtr = DAG.getNode(ISD::ADD, dl, getPointerTy(DAG.getDataLayout()), diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index a0ae05081adc..7570385e38e3 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -184,6 +184,330 @@ bool RISCVDAGToDAGISel::SelectAddrFI(SDValue Addr, SDValue &Base) { return false; } +// Check that it is a SLOI (Shift Left Ones Immediate). We first check that +// it is the right node tree: +// +// (OR (SHL RS1, VC2), VC1) +// +// and then we check that VC1, the mask used to fill with ones, is compatible +// with VC2, the shamt: +// +// VC1 == maskTrailingOnes<uint64_t>(VC2) + +bool RISCVDAGToDAGISel::SelectSLOI(SDValue N, SDValue &RS1, SDValue &Shamt) { + MVT XLenVT = Subtarget->getXLenVT(); + if (N.getOpcode() == ISD::OR) { + SDValue Or = N; + if (Or.getOperand(0).getOpcode() == ISD::SHL) { + SDValue Shl = Or.getOperand(0); + if (isa<ConstantSDNode>(Shl.getOperand(1)) && + isa<ConstantSDNode>(Or.getOperand(1))) { + if (XLenVT == MVT::i64) { + uint64_t VC1 = Or.getConstantOperandVal(1); + uint64_t VC2 = Shl.getConstantOperandVal(1); + if (VC1 == maskTrailingOnes<uint64_t>(VC2)) { + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Shl.getOperand(1).getValueType()); + return true; + } + } + if (XLenVT == MVT::i32) { + uint32_t VC1 = Or.getConstantOperandVal(1); + uint32_t VC2 = Shl.getConstantOperandVal(1); + if (VC1 == maskTrailingOnes<uint32_t>(VC2)) { + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Shl.getOperand(1).getValueType()); + return true; + } + } + } + } + } + return false; +} + +// Check that it is a SROI (Shift Right Ones Immediate). We first check that +// it is the right node tree: +// +// (OR (SRL RS1, VC2), VC1) +// +// and then we check that VC1, the mask used to fill with ones, is compatible +// with VC2, the shamt: +// +// VC1 == maskLeadingOnes<uint64_t>(VC2) + +bool RISCVDAGToDAGISel::SelectSROI(SDValue N, SDValue &RS1, SDValue &Shamt) { + MVT XLenVT = Subtarget->getXLenVT(); + if (N.getOpcode() == ISD::OR) { + SDValue Or = N; + if (Or.getOperand(0).getOpcode() == ISD::SRL) { + SDValue Srl = Or.getOperand(0); + if (isa<ConstantSDNode>(Srl.getOperand(1)) && + isa<ConstantSDNode>(Or.getOperand(1))) { + if (XLenVT == MVT::i64) { + uint64_t VC1 = Or.getConstantOperandVal(1); + uint64_t VC2 = Srl.getConstantOperandVal(1); + if (VC1 == maskLeadingOnes<uint64_t>(VC2)) { + RS1 = Srl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; + } + } + if (XLenVT == MVT::i32) { + uint32_t VC1 = Or.getConstantOperandVal(1); + uint32_t VC2 = Srl.getConstantOperandVal(1); + if (VC1 == maskLeadingOnes<uint32_t>(VC2)) { + RS1 = Srl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; + } + } + } + } + } + return false; +} + +// Check that it is a RORI (Rotate Right Immediate). We first check that +// it is the right node tree: +// +// (ROTL RS1, VC) +// +// The compiler translates immediate rotations to the right given by the call +// to the rotateright32/rotateright64 intrinsics as rotations to the left. +// Since the rotation to the left can be easily emulated as a rotation to the +// right by negating the constant, there is no encoding for ROLI. +// We then select the immediate left rotations as RORI by the complementary +// constant: +// +// Shamt == XLen - VC + +bool RISCVDAGToDAGISel::SelectRORI(SDValue N, SDValue &RS1, SDValue &Shamt) { + MVT XLenVT = Subtarget->getXLenVT(); + if (N.getOpcode() == ISD::ROTL) { + if (isa<ConstantSDNode>(N.getOperand(1))) { + if (XLenVT == MVT::i64) { + uint64_t VC = N.getConstantOperandVal(1); + Shamt = CurDAG->getTargetConstant((64 - VC), SDLoc(N), + N.getOperand(1).getValueType()); + RS1 = N.getOperand(0); + return true; + } + if (XLenVT == MVT::i32) { + uint32_t VC = N.getConstantOperandVal(1); + Shamt = CurDAG->getTargetConstant((32 - VC), SDLoc(N), + N.getOperand(1).getValueType()); + RS1 = N.getOperand(0); + return true; + } + } + } + return false; +} + + +// Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32 +// on RV64). +// SLLIUW is the same as SLLI except for the fact that it clears the bits +// XLEN-1:32 of the input RS1 before shifting. +// We first check that it is the right node tree: +// +// (AND (SHL RS1, VC2), VC1) +// +// We check that VC2, the shamt is less than 32, otherwise the pattern is +// exactly the same as SLLI and we give priority to that. +// Eventually we check that that VC1, the mask used to clear the upper 32 bits +// of RS1, is correct: +// +// VC1 == (0xFFFFFFFF << VC2) + +bool RISCVDAGToDAGISel::SelectSLLIUW(SDValue N, SDValue &RS1, SDValue &Shamt) { + if (N.getOpcode() == ISD::AND && Subtarget->getXLenVT() == MVT::i64) { + SDValue And = N; + if (And.getOperand(0).getOpcode() == ISD::SHL) { + SDValue Shl = And.getOperand(0); + if (isa<ConstantSDNode>(Shl.getOperand(1)) && + isa<ConstantSDNode>(And.getOperand(1))) { + uint64_t VC1 = And.getConstantOperandVal(1); + uint64_t VC2 = Shl.getConstantOperandVal(1); + if (VC2 < 32 && VC1 == ((uint64_t)0xFFFFFFFF << VC2)) { + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Shl.getOperand(1).getValueType()); + return true; + } + } + } + } + return false; +} + +// Check that it is a SLOIW (Shift Left Ones Immediate i32 on RV64). +// We first check that it is the right node tree: +// +// (SIGN_EXTEND_INREG (OR (SHL RS1, VC2), VC1)) +// +// and then we check that VC1, the mask used to fill with ones, is compatible +// with VC2, the shamt: +// +// VC1 == maskTrailingOnes<uint32_t>(VC2) + +bool RISCVDAGToDAGISel::SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt) { + if (Subtarget->getXLenVT() == MVT::i64 && + N.getOpcode() == ISD::SIGN_EXTEND_INREG && + cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) { + if (N.getOperand(0).getOpcode() == ISD::OR) { + SDValue Or = N.getOperand(0); + if (Or.getOperand(0).getOpcode() == ISD::SHL) { + SDValue Shl = Or.getOperand(0); + if (isa<ConstantSDNode>(Shl.getOperand(1)) && + isa<ConstantSDNode>(Or.getOperand(1))) { + uint32_t VC1 = Or.getConstantOperandVal(1); + uint32_t VC2 = Shl.getConstantOperandVal(1); + if (VC1 == maskTrailingOnes<uint32_t>(VC2)) { + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Shl.getOperand(1).getValueType()); + return true; + } + } + } + } + } + return false; +} + +// Check that it is a SROIW (Shift Right Ones Immediate i32 on RV64). +// We first check that it is the right node tree: +// +// (OR (SHL RS1, VC2), VC1) +// +// and then we check that VC1, the mask used to fill with ones, is compatible +// with VC2, the shamt: +// +// VC1 == maskLeadingOnes<uint32_t>(VC2) + +bool RISCVDAGToDAGISel::SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt) { + if (N.getOpcode() == ISD::OR && Subtarget->getXLenVT() == MVT::i64) { + SDValue Or = N; + if (Or.getOperand(0).getOpcode() == ISD::SRL) { + SDValue Srl = Or.getOperand(0); + if (isa<ConstantSDNode>(Srl.getOperand(1)) && + isa<ConstantSDNode>(Or.getOperand(1))) { + uint32_t VC1 = Or.getConstantOperandVal(1); + uint32_t VC2 = Srl.getConstantOperandVal(1); + if (VC1 == maskLeadingOnes<uint32_t>(VC2)) { + RS1 = Srl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; + } + } + } + } + return false; +} + +// Check that it is a RORIW (i32 Right Rotate Immediate on RV64). +// We first check that it is the right node tree: +// +// (SIGN_EXTEND_INREG (OR (SHL (AsserSext RS1, i32), VC2), +// (SRL (AND (AssertSext RS2, i32), VC3), VC1))) +// +// Then we check that the constant operands respect these constraints: +// +// VC2 == 32 - VC1 +// VC3 == maskLeadingOnes<uint32_t>(VC2) +// +// being VC1 the Shamt we need, VC2 the complementary of Shamt over 32 +// and VC3 a 32 bit mask of (32 - VC1) leading ones. + +bool RISCVDAGToDAGISel::SelectRORIW(SDValue N, SDValue &RS1, SDValue &Shamt) { + if (N.getOpcode() == ISD::SIGN_EXTEND_INREG && + Subtarget->getXLenVT() == MVT::i64 && + cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) { + if (N.getOperand(0).getOpcode() == ISD::OR) { + SDValue Or = N.getOperand(0); + if (Or.getOperand(0).getOpcode() == ISD::SHL && + Or.getOperand(1).getOpcode() == ISD::SRL) { + SDValue Shl = Or.getOperand(0); + SDValue Srl = Or.getOperand(1); + if (Srl.getOperand(0).getOpcode() == ISD::AND) { + SDValue And = Srl.getOperand(0); + if (isa<ConstantSDNode>(Srl.getOperand(1)) && + isa<ConstantSDNode>(Shl.getOperand(1)) && + isa<ConstantSDNode>(And.getOperand(1))) { + uint32_t VC1 = Srl.getConstantOperandVal(1); + uint32_t VC2 = Shl.getConstantOperandVal(1); + uint32_t VC3 = And.getConstantOperandVal(1); + if (VC2 == (32 - VC1) && + VC3 == maskLeadingOnes<uint32_t>(VC2)) { + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC1, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; + } + } + } + } + } + } + return false; +} + +// Check that it is a FSRIW (i32 Funnel Shift Right Immediate on RV64). +// We first check that it is the right node tree: +// +// (SIGN_EXTEND_INREG (OR (SHL (AsserSext RS1, i32), VC2), +// (SRL (AND (AssertSext RS2, i32), VC3), VC1))) +// +// Then we check that the constant operands respect these constraints: +// +// VC2 == 32 - VC1 +// VC3 == maskLeadingOnes<uint32_t>(VC2) +// +// being VC1 the Shamt we need, VC2 the complementary of Shamt over 32 +// and VC3 a 32 bit mask of (32 - VC1) leading ones. + +bool RISCVDAGToDAGISel::SelectFSRIW(SDValue N, SDValue &RS1, SDValue &RS2, + SDValue &Shamt) { + if (N.getOpcode() == ISD::SIGN_EXTEND_INREG && + Subtarget->getXLenVT() == MVT::i64 && + cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) { + if (N.getOperand(0).getOpcode() == ISD::OR) { + SDValue Or = N.getOperand(0); + if (Or.getOperand(0).getOpcode() == ISD::SHL && + Or.getOperand(1).getOpcode() == ISD::SRL) { + SDValue Shl = Or.getOperand(0); + SDValue Srl = Or.getOperand(1); + if (Srl.getOperand(0).getOpcode() == ISD::AND) { + SDValue And = Srl.getOperand(0); + if (isa<ConstantSDNode>(Srl.getOperand(1)) && + isa<ConstantSDNode>(Shl.getOperand(1)) && + isa<ConstantSDNode>(And.getOperand(1))) { + uint32_t VC1 = Srl.getConstantOperandVal(1); + uint32_t VC2 = Shl.getConstantOperandVal(1); + uint32_t VC3 = And.getConstantOperandVal(1); + if (VC2 == (32 - VC1) && + VC3 == maskLeadingOnes<uint32_t>(VC2)) { + RS1 = Shl.getOperand(0); + RS2 = And.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC1, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; + } + } + } + } + } + } + return false; +} + // Merge an ADDI into the offset of a load/store instruction where possible. // (load (addi base, off1), off2) -> (load base, off1+off2) // (store val, (addi base, off1), off2) -> (store val, base, off1+off2) diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h index dcf733ec3675..0ca12510a230 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h @@ -45,6 +45,15 @@ public: bool SelectAddrFI(SDValue Addr, SDValue &Base); + bool SelectSLOI(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectSROI(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectRORI(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectSLLIUW(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectRORIW(SDValue N, SDValue &RS1, SDValue &Shamt); + bool SelectFSRIW(SDValue N, SDValue &RS1, SDValue &RS2, SDValue &Shamt); + // Include the pieces autogenerated from the target description. #include "RISCVGenDAGISel.inc" diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 91fc69b5bc10..03d9eefd59d0 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -149,12 +149,27 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SRL_PARTS, XLenVT, Custom); setOperationAction(ISD::SRA_PARTS, XLenVT, Custom); - setOperationAction(ISD::ROTL, XLenVT, Expand); - setOperationAction(ISD::ROTR, XLenVT, Expand); - setOperationAction(ISD::BSWAP, XLenVT, Expand); - setOperationAction(ISD::CTTZ, XLenVT, Expand); - setOperationAction(ISD::CTLZ, XLenVT, Expand); - setOperationAction(ISD::CTPOP, XLenVT, Expand); + if (!(Subtarget.hasStdExtZbb() || Subtarget.hasStdExtZbp())) { + setOperationAction(ISD::ROTL, XLenVT, Expand); + setOperationAction(ISD::ROTR, XLenVT, Expand); + } + + if (!Subtarget.hasStdExtZbp()) + setOperationAction(ISD::BSWAP, XLenVT, Expand); + + if (!Subtarget.hasStdExtZbb()) { + setOperationAction(ISD::CTTZ, XLenVT, Expand); + setOperationAction(ISD::CTLZ, XLenVT, Expand); + setOperationAction(ISD::CTPOP, XLenVT, Expand); + } + + if (Subtarget.hasStdExtZbp()) + setOperationAction(ISD::BITREVERSE, XLenVT, Legal); + + if (Subtarget.hasStdExtZbt()) { + setOperationAction(ISD::FSHL, XLenVT, Legal); + setOperationAction(ISD::FSHR, XLenVT, Legal); + } ISD::CondCode FPCCToExtend[] = { ISD::SETOGT, ISD::SETOGE, ISD::SETONE, ISD::SETUEQ, ISD::SETUGT, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td index 34a463626e29..afac509f743d 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td @@ -632,3 +632,432 @@ let Predicates = [HasStdExtZbproposedc, HasStdExtZbbOrZbp, HasStdExtC, IsRV64] i def : CompressPat<(PACK GPRC:$rs1, GPRC:$rs1, X0), (C_ZEXTW GPRC:$rs1)>; } // Predicates = [HasStdExtZbproposedc, HasStdExtC, IsRV64] + +//===----------------------------------------------------------------------===// +// Codegen patterns +//===----------------------------------------------------------------------===// +def SLOIPat : ComplexPattern<XLenVT, 2, "SelectSLOI", [or]>; +def SROIPat : ComplexPattern<XLenVT, 2, "SelectSROI", [or]>; +def RORIPat : ComplexPattern<XLenVT, 2, "SelectRORI", [rotl]>; +def SLLIUWPat : ComplexPattern<i64, 2, "SelectSLLIUW", [and]>; +def SLOIWPat : ComplexPattern<i64, 2, "SelectSLOIW", [sext_inreg]>; +def SROIWPat : ComplexPattern<i64, 2, "SelectSROIW", [or]>; +def RORIWPat : ComplexPattern<i64, 2, "SelectRORIW", [sext_inreg]>; +def FSRIWPat : ComplexPattern<i64, 3, "SelectFSRIW", [sext_inreg]>; + +let Predicates = [HasStdExtZbbOrZbp] in { +def : Pat<(and GPR:$rs1, (not GPR:$rs2)), (ANDN GPR:$rs1, GPR:$rs2)>; +def : Pat<(or GPR:$rs1, (not GPR:$rs2)), (ORN GPR:$rs1, GPR:$rs2)>; +def : Pat<(xor GPR:$rs1, (not GPR:$rs2)), (XNOR GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbbOrZbp] + +let Predicates = [HasStdExtZbb] in { +def : Pat<(xor (shl (xor GPR:$rs1, -1), GPR:$rs2), -1), + (SLO GPR:$rs1, GPR:$rs2)>; +def : Pat<(xor (srl (xor GPR:$rs1, -1), GPR:$rs2), -1), + (SRO GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbb] + +let Predicates = [HasStdExtZbbOrZbp] in { +def : Pat<(rotl GPR:$rs1, GPR:$rs2), (ROL GPR:$rs1, GPR:$rs2)>; +def : Pat<(fshl GPR:$rs1, GPR:$rs1, GPR:$rs2), (ROL GPR:$rs1, GPR:$rs2)>; +def : Pat<(rotr GPR:$rs1, GPR:$rs2), (ROR GPR:$rs1, GPR:$rs2)>; +def : Pat<(fshr GPR:$rs1, GPR:$rs1, GPR:$rs2), (ROR GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbbOrZbp] + +let Predicates = [HasStdExtZbs, IsRV32] in +def : Pat<(and (xor (shl 1, (and GPR:$rs2, 31)), -1), GPR:$rs1), + (SBCLR GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbs, IsRV64] in +def : Pat<(and (xor (shl 1, (and GPR:$rs2, 63)), -1), GPR:$rs1), + (SBCLR GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbs] in +def : Pat<(and (rotl -2, GPR:$rs2), GPR:$rs1), (SBCLR GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbs, IsRV32] in +def : Pat<(or (shl 1, (and GPR:$rs2, 31)), GPR:$rs1), + (SBSET GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbs, IsRV64] in +def : Pat<(or (shl 1, (and GPR:$rs2, 63)), GPR:$rs1), + (SBSET GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbs, IsRV32] in +def : Pat<(xor (shl 1, (and GPR:$rs2, 31)), GPR:$rs1), + (SBINV GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbs, IsRV64] in +def : Pat<(xor (shl 1, (and GPR:$rs2, 63)), GPR:$rs1), + (SBINV GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbs, IsRV32] in +def : Pat<(and (srl GPR:$rs1, (and GPR:$rs2, 31)), 1), + (SBEXT GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbs, IsRV64] in +def : Pat<(and (srl GPR:$rs1, (and GPR:$rs2, 63)), 1), + (SBEXT GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbb] in { +def : Pat<(SLOIPat GPR:$rs1, uimmlog2xlen:$shamt), + (SLOI GPR:$rs1, uimmlog2xlen:$shamt)>; +def : Pat<(SROIPat GPR:$rs1, uimmlog2xlen:$shamt), + (SROI GPR:$rs1, uimmlog2xlen:$shamt)>; +} // Predicates = [HasStdExtZbb] + +// There's no encoding for roli in the current version of the 'B' extension +// (v0.92) as it can be implemented with rori by negating the immediate. +// For this reason we pattern-match only against rori[w]. +let Predicates = [HasStdExtZbbOrZbp] in +def : Pat<(RORIPat GPR:$rs1, uimmlog2xlen:$shamt), + (RORI GPR:$rs1, uimmlog2xlen:$shamt)>; + +// We don't pattern-match sbclri[w], sbseti[w], sbinvi[w] because they are +// pattern-matched by simple andi, ori, and xori. +let Predicates = [HasStdExtZbs] in +def : Pat<(and (srl GPR:$rs1, uimmlog2xlen:$shamt), (XLenVT 1)), + (SBEXTI GPR:$rs1, uimmlog2xlen:$shamt)>; + +let Predicates = [HasStdExtZbp, IsRV32] in { +def : Pat<(or (or (and (srl GPR:$rs1, (i32 1)), (i32 0x55555555)), GPR:$rs1), + (and (shl GPR:$rs1, (i32 1)), (i32 0xAAAAAAAA))), + (GORCI GPR:$rs1, (i32 1))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i32 2)), (i32 0x33333333)), GPR:$rs1), + (and (shl GPR:$rs1, (i32 2)), (i32 0xCCCCCCCC))), + (GORCI GPR:$rs1, (i32 2))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i32 4)), (i32 0x0F0F0F0F)), GPR:$rs1), + (and (shl GPR:$rs1, (i32 4)), (i32 0xF0F0F0F0))), + (GORCI GPR:$rs1, (i32 4))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i32 8)), (i32 0x00FF00FF)), GPR:$rs1), + (and (shl GPR:$rs1, (i32 8)), (i32 0xFF00FF00))), + (GORCI GPR:$rs1, (i32 8))>; +def : Pat<(or (or (srl GPR:$rs1, (i32 16)), GPR:$rs1), + (shl GPR:$rs1, (i32 16))), + (GORCI GPR:$rs1, (i32 16))>; +} // Predicates = [HasStdExtZbp, IsRV32] + +let Predicates = [HasStdExtZbp, IsRV64] in { +def : Pat<(or (or (and (srl GPR:$rs1, (i64 1)), (i64 0x5555555555555555)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAAAAAAAAAA))), + (GORCI GPR:$rs1, (i64 1))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i64 2)), (i64 0x3333333333333333)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCCCCCCCCCC))), + (GORCI GPR:$rs1, (i64 2))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F0F0F0F0F)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0F0F0F0F0))), + (GORCI GPR:$rs1, (i64 4))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF00FF00FF)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00FF00FF00))), + (GORCI GPR:$rs1, (i64 8))>; +def : Pat<(or (or (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF0000FFFF)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000FFFF0000))), + (GORCI GPR:$rs1, (i64 16))>; +def : Pat<(or (or (srl GPR:$rs1, (i64 32)), GPR:$rs1), + (shl GPR:$rs1, (i64 32))), + (GORCI GPR:$rs1, (i64 32))>; +} // Predicates = [HasStdExtZbp, IsRV64] + +let Predicates = [HasStdExtZbp, IsRV32] in { +def : Pat<(or (and (shl GPR:$rs1, (i32 1)), (i32 0xAAAAAAAA)), + (and (srl GPR:$rs1, (i32 1)), (i32 0x55555555))), + (GREVI GPR:$rs1, (i32 1))>; +def : Pat<(or (and (shl GPR:$rs1, (i32 2)), (i32 0xCCCCCCCC)), + (and (srl GPR:$rs1, (i32 2)), (i32 0x33333333))), + (GREVI GPR:$rs1, (i32 2))>; +def : Pat<(or (and (shl GPR:$rs1, (i32 4)), (i32 0xF0F0F0F0)), + (and (srl GPR:$rs1, (i32 4)), (i32 0x0F0F0F0F))), + (GREVI GPR:$rs1, (i32 4))>; +def : Pat<(or (and (shl GPR:$rs1, (i32 8)), (i32 0xFF00FF00)), + (and (srl GPR:$rs1, (i32 8)), (i32 0x00FF00FF))), + (GREVI GPR:$rs1, (i32 8))>; +def : Pat<(rotr (bswap GPR:$rs1), (i32 16)), (GREVI GPR:$rs1, (i32 8))>; +def : Pat<(or (shl GPR:$rs1, (i32 16)), (srl GPR:$rs1, (i32 16))), + (GREVI GPR:$rs1, (i32 16))>; +def : Pat<(rotl GPR:$rs1, (i32 16)), (GREVI GPR:$rs1, (i32 16))>; +def : Pat<(bswap GPR:$rs1), (GREVI GPR:$rs1, (i32 24))>; +def : Pat<(bitreverse GPR:$rs1), (GREVI GPR:$rs1, (i32 31))>; +} // Predicates = [HasStdExtZbp, IsRV32] + +let Predicates = [HasStdExtZbp, IsRV64] in { +def : Pat<(or (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAAAAAAAAAA)), + (and (srl GPR:$rs1, (i64 1)), (i64 0x5555555555555555))), + (GREVI GPR:$rs1, (i64 1))>; +def : Pat<(or (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCCCCCCCCCC)), + (and (srl GPR:$rs1, (i64 2)), (i64 0x3333333333333333))), + (GREVI GPR:$rs1, (i64 2))>; +def : Pat<(or (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0F0F0F0F0)), + (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F0F0F0F0F))), + (GREVI GPR:$rs1, (i64 4))>; +def : Pat<(or (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00FF00FF00)), + (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF00FF00FF))), + (GREVI GPR:$rs1, (i64 8))>; +def : Pat<(or (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000FFFF0000)), + (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF0000FFFF))), + (GREVI GPR:$rs1, (i64 16))>; +def : Pat<(or (shl GPR:$rs1, (i64 32)), (srl GPR:$rs1, (i64 32))), + (GREVI GPR:$rs1, (i64 32))>; +def : Pat<(rotl GPR:$rs1, (i64 32)), (GREVI GPR:$rs1, (i64 32))>; +def : Pat<(bswap GPR:$rs1), (GREVI GPR:$rs1, (i64 56))>; +def : Pat<(bitreverse GPR:$rs1), (GREVI GPR:$rs1, (i64 63))>; +} // Predicates = [HasStdExtZbp, IsRV64] + +let Predicates = [HasStdExtZbt] in { +def : Pat<(or (and (xor GPR:$rs2, -1), GPR:$rs3), (and GPR:$rs2, GPR:$rs1)), + (CMIX GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(riscv_selectcc GPR:$rs2, (XLenVT 0), (XLenVT 17), GPR:$rs3, GPR:$rs1), + (CMOV GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(fshl GPR:$rs1, GPR:$rs2, GPR:$rs3), + (FSL GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(fshr GPR:$rs1, GPR:$rs2, GPR:$rs3), + (FSR GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(fshr GPR:$rs1, GPR:$rs2, uimmlog2xlen:$shamt), + (FSRI GPR:$rs1, GPR:$rs2, uimmlog2xlen:$shamt)>; +} // Predicates = [HasStdExtZbt] + +let Predicates = [HasStdExtZbb] in { +def : Pat<(ctlz GPR:$rs1), (CLZ GPR:$rs1)>; +def : Pat<(cttz GPR:$rs1), (CTZ GPR:$rs1)>; +def : Pat<(ctpop GPR:$rs1), (PCNT GPR:$rs1)>; +} // Predicates = [HasStdExtZbb] + +let Predicates = [HasStdExtZbb, IsRV32] in +def : Pat<(sra (shl GPR:$rs1, (i32 24)), (i32 24)), (SEXTB GPR:$rs1)>; +let Predicates = [HasStdExtZbb, IsRV64] in +def : Pat<(sra (shl GPR:$rs1, (i64 56)), (i64 56)), (SEXTB GPR:$rs1)>; + +let Predicates = [HasStdExtZbb, IsRV32] in +def : Pat<(sra (shl GPR:$rs1, (i32 16)), (i32 16)), (SEXTH GPR:$rs1)>; +let Predicates = [HasStdExtZbb, IsRV64] in +def : Pat<(sra (shl GPR:$rs1, (i64 48)), (i64 48)), (SEXTH GPR:$rs1)>; + +let Predicates = [HasStdExtZbb] in { +def : Pat<(smin GPR:$rs1, GPR:$rs2), (MIN GPR:$rs1, GPR:$rs2)>; +def : Pat<(riscv_selectcc GPR:$rs1, GPR:$rs2, (XLenVT 20), GPR:$rs1, GPR:$rs2), + (MIN GPR:$rs1, GPR:$rs2)>; +def : Pat<(smax GPR:$rs1, GPR:$rs2), (MAX GPR:$rs1, GPR:$rs2)>; +def : Pat<(riscv_selectcc GPR:$rs2, GPR:$rs1, (XLenVT 20), GPR:$rs1, GPR:$rs2), + (MAX GPR:$rs1, GPR:$rs2)>; +def : Pat<(umin GPR:$rs1, GPR:$rs2), (MINU GPR:$rs1, GPR:$rs2)>; +def : Pat<(riscv_selectcc GPR:$rs1, GPR:$rs2, (XLenVT 12), GPR:$rs1, GPR:$rs2), + (MINU GPR:$rs1, GPR:$rs2)>; +def : Pat<(umax GPR:$rs1, GPR:$rs2), (MAXU GPR:$rs1, GPR:$rs2)>; +def : Pat<(riscv_selectcc GPR:$rs2, GPR:$rs1, (XLenVT 12), GPR:$rs1, GPR:$rs2), + (MAXU GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbb] + +let Predicates = [HasStdExtZbbOrZbp, IsRV32] in +def : Pat<(or (and GPR:$rs1, 0x0000FFFF), (shl GPR:$rs2, (i32 16))), + (PACK GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbbOrZbp, IsRV64] in +def : Pat<(or (and GPR:$rs1, 0x00000000FFFFFFFF), (shl GPR:$rs2, (i64 32))), + (PACK GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbbOrZbp, IsRV32] in +def : Pat<(or (and GPR:$rs2, 0xFFFF0000), (srl GPR:$rs1, (i32 16))), + (PACKU GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbbOrZbp, IsRV64] in +def : Pat<(or (and GPR:$rs2, 0xFFFFFFFF00000000), (srl GPR:$rs1, (i64 32))), + (PACKU GPR:$rs1, GPR:$rs2)>; +let Predicates = [HasStdExtZbbOrZbp] in +def : Pat<(or (and (shl GPR:$rs2, (XLenVT 8)), 0xFF00), + (and GPR:$rs1, 0x00FF)), + (PACKH GPR:$rs1, GPR:$rs2)>; + +let Predicates = [HasStdExtZbp, IsRV32] in { +def : Pat<(or (or (and (shl GPR:$rs1, (i32 8)), (i32 0x00FF0000)), + (and GPR:$rs1, (i32 0xFF0000FF))), + (and (srl GPR:$rs1, (i32 8)), (i32 0x0000FF00))), + (SHFLI GPR:$rs1, (i32 8))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i32 4)), (i32 0x0F000F00)), + (and GPR:$rs1, (i32 0xF00FF00F))), + (and (srl GPR:$rs1, (i32 4)), (i32 0x00F000F0))), + (SHFLI GPR:$rs1, (i32 4))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i32 2)), (i32 0x30303030)), + (and GPR:$rs1, (i32 0xC3C3C3C3))), + (and (srl GPR:$rs1, (i32 2)), (i32 0x0C0C0C0C))), + (SHFLI GPR:$rs1, (i32 2))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i32 1)), (i32 0x44444444)), + (and GPR:$rs1, (i32 0x99999999))), + (and (srl GPR:$rs1, (i32 1)), (i32 0x22222222))), + (SHFLI GPR:$rs1, (i32 1))>; +} // Predicates = [HasStdExtZbp, IsRV32] + +let Predicates = [HasStdExtZbp, IsRV64] in { +def : Pat<(or (or (and (shl GPR:$rs1, (i64 16)), (i64 0x0000FFFF00000000)), + (and GPR:$rs1, (i64 0xFFFF00000000FFFF))), + (and (srl GPR:$rs1, (i64 16)), (i64 0x00000000FFFF0000))), + (SHFLI GPR:$rs1, (i64 16))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i64 8)), (i64 0x00FF000000FF0000)), + (and GPR:$rs1, (i64 0xFF0000FFFF0000FF))), + (and (srl GPR:$rs1, (i64 8)), (i64 0x0000FF000000FF00))), + (SHFLI GPR:$rs1, (i64 8))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i64 4)), (i64 0x0F000F000F000F00)), + (and GPR:$rs1, (i64 0xF00FF00FF00FF00F))), + (and (srl GPR:$rs1, (i64 4)), (i64 0x00F000F000F000F0))), + (SHFLI GPR:$rs1, (i64 4))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i64 2)), (i64 0x3030303030303030)), + (and GPR:$rs1, (i64 0xC3C3C3C3C3C3C3C3))), + (and (srl GPR:$rs1, (i64 2)), (i64 0x0C0C0C0C0C0C0C0C))), + (SHFLI GPR:$rs1, (i64 2))>; +def : Pat<(or (or (and (shl GPR:$rs1, (i64 1)), (i64 0x4444444444444444)), + (and GPR:$rs1, (i64 0x9999999999999999))), + (and (srl GPR:$rs1, (i64 1)), (i64 0x2222222222222222))), + (SHFLI GPR:$rs1, (i64 1))>; +} // Predicates = [HasStdExtZbp, IsRV64] + +let Predicates = [HasStdExtZbb, IsRV64] in { +def : Pat<(and (add GPR:$rs, simm12:$simm12), (i64 0xFFFFFFFF)), + (ADDIWU GPR:$rs, simm12:$simm12)>; +def : Pat<(SLLIUWPat GPR:$rs1, uimmlog2xlen:$shamt), + (SLLIUW GPR:$rs1, uimmlog2xlen:$shamt)>; +def : Pat<(and (add GPR:$rs1, GPR:$rs2), (i64 0xFFFFFFFF)), + (ADDWU GPR:$rs1, GPR:$rs2)>; +def : Pat<(and (sub GPR:$rs1, GPR:$rs2), (i64 0xFFFFFFFF)), + (SUBWU GPR:$rs1, GPR:$rs2)>; +def : Pat<(add GPR:$rs1, (and GPR:$rs2, (i64 0xFFFFFFFF))), + (ADDUW GPR:$rs1, GPR:$rs2)>; +def : Pat<(sub GPR:$rs1, (and GPR:$rs2, (i64 0xFFFFFFFF))), + (SUBUW GPR:$rs1, GPR:$rs2)>; +def : Pat<(xor (riscv_sllw (xor GPR:$rs1, -1), GPR:$rs2), -1), + (SLOW GPR:$rs1, GPR:$rs2)>; +def : Pat<(xor (riscv_srlw (xor GPR:$rs1, -1), GPR:$rs2), -1), + (SROW GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbb, IsRV64] + +let Predicates = [HasStdExtZbbOrZbp, IsRV64] in { +def : Pat<(or (riscv_sllw (assertsexti32 GPR:$rs1), (assertsexti32 GPR:$rs2)), + (riscv_srlw (assertsexti32 GPR:$rs1), + (sub (i64 0), (assertsexti32 GPR:$rs2)))), + (ROLW GPR:$rs1, GPR:$rs2)>; +def : Pat<(or (riscv_sllw (assertsexti32 GPR:$rs1), + (sub (i64 0), (assertsexti32 GPR:$rs2))), + (riscv_srlw (assertsexti32 GPR:$rs1), (assertsexti32 GPR:$rs2))), + (RORW GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbbOrZbp, IsRV64] + +let Predicates = [HasStdExtZbs, IsRV64] in { +def : Pat<(and (xor (riscv_sllw 1, (assertsexti32 GPR:$rs2)), -1), + (assertsexti32 GPR:$rs1)), + (SBCLRW GPR:$rs1, GPR:$rs2)>; +def : Pat<(or (riscv_sllw 1, (assertsexti32 GPR:$rs2)), + (assertsexti32 GPR:$rs1)), + (SBSETW GPR:$rs1, GPR:$rs2)>; +def : Pat<(xor (riscv_sllw 1, (assertsexti32 GPR:$rs2)), + (assertsexti32 GPR:$rs1)), + (SBINVW GPR:$rs1, GPR:$rs2)>; +def : Pat<(and (riscv_srlw (assertsexti32 GPR:$rs1), (assertsexti32 GPR:$rs2)), + 1), + (SBEXTW GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbs, IsRV64] + +let Predicates = [HasStdExtZbb, IsRV64] in { +def : Pat<(SLOIWPat GPR:$rs1, uimmlog2xlen:$shamt), + (SLOIW GPR:$rs1, uimmlog2xlen:$shamt)>; +def : Pat<(SROIWPat GPR:$rs1, uimmlog2xlen:$shamt), + (SROIW GPR:$rs1, uimmlog2xlen:$shamt)>; +} // Predicates = [HasStdExtZbb, IsRV64] + +let Predicates = [HasStdExtZbbOrZbp, IsRV64] in +def : Pat<(RORIWPat GPR:$rs1, uimmlog2xlen:$shamt), + (RORIW GPR:$rs1, uimmlog2xlen:$shamt)>; + +let Predicates = [HasStdExtZbp, IsRV64] in { +def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 1)), (i64 0x55555555)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAA))), + i32), + (GORCIW GPR:$rs1, (i64 1))>; +def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 2)), (i64 0x33333333)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCC))), + i32), + (GORCIW GPR:$rs1, (i64 2))>; +def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0))), + i32), + (GORCIW GPR:$rs1, (i64 4))>; +def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00))), + i32), + (GORCIW GPR:$rs1, (i64 8))>; +def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF)), + GPR:$rs1), + (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000))), + i32), + (GORCIW GPR:$rs1, (i64 16))>; +def : Pat<(sext_inreg (or (or (srl (and GPR:$rs1, (i64 0xFFFF0000)), (i64 16)), + GPR:$rs1), + (shl GPR:$rs1, (i64 16))), i32), + (GORCIW GPR:$rs1, (i64 16))>; + +def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAA)), + (and (srl GPR:$rs1, (i64 1)), (i64 0x55555555))), + i32), + (GREVIW GPR:$rs1, (i64 1))>; +def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCC)), + (and (srl GPR:$rs1, (i64 2)), (i64 0x33333333))), + i32), + (GREVIW GPR:$rs1, (i64 2))>; +def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0)), + (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F))), + i32), + (GREVIW GPR:$rs1, (i64 4))>; +def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00)), + (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF))), + i32), + (GREVIW GPR:$rs1, (i64 8))>; +def : Pat<(sext_inreg (or (shl GPR:$rs1, (i64 16)), + (srl (and GPR:$rs1, 0xFFFF0000), (i64 16))), i32), + (GREVIW GPR:$rs1, (i64 16))>; +def : Pat<(sra (bswap GPR:$rs1), (i64 32)), (GREVIW GPR:$rs1, (i64 24))>; +def : Pat<(sra (bitreverse GPR:$rs1), (i64 32)), (GREVIW GPR:$rs1, (i64 31))>; +} // Predicates = [HasStdExtZbp, IsRV64] + +let Predicates = [HasStdExtZbt, IsRV64] in { +def : Pat<(riscv_selectcc (and (assertsexti32 GPR:$rs3), 31), + (i64 0), + (i64 17), + (assertsexti32 GPR:$rs1), + (or (riscv_sllw (assertsexti32 GPR:$rs1), + (and (assertsexti32 GPR:$rs3), 31)), + (riscv_srlw (assertsexti32 GPR:$rs2), + (sub (i64 32), + (assertsexti32 GPR:$rs3))))), + (FSLW GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(riscv_selectcc (and (assertsexti32 GPR:$rs3), 31), + (i64 0), + (i64 17), + (assertsexti32 GPR:$rs2), + (or (riscv_sllw (assertsexti32 GPR:$rs1), + (sub (i64 32), + (assertsexti32 GPR:$rs3))), + (riscv_srlw (assertsexti32 GPR:$rs2), + (and (assertsexti32 GPR:$rs3), 31)))), + (FSRW GPR:$rs1, GPR:$rs2, GPR:$rs3)>; +def : Pat<(FSRIWPat GPR:$rs1, GPR:$rs2, uimmlog2xlen:$shamt), + (FSRIW GPR:$rs1, GPR:$rs2, uimmlog2xlen:$shamt)>; +} // Predicates = [HasStdExtZbt, IsRV64] + +let Predicates = [HasStdExtZbb, IsRV64] in { +def : Pat<(add (ctlz (and GPR:$rs1, (i64 0xFFFFFFFF))), (i64 -32)), + (CLZW GPR:$rs1)>; +// We don't pattern-match CTZW here as it has the same pattern and result as +// RV64 CTZ +def : Pat<(ctpop (and GPR:$rs1, (i64 0xFFFFFFFF))), (PCNTW GPR:$rs1)>; +} // Predicates = [HasStdExtZbb, IsRV64] + +let Predicates = [HasStdExtZbbOrZbp, IsRV64] in { +def : Pat<(sext_inreg (or (shl (assertsexti32 GPR:$rs2), (i64 16)), + (and (assertsexti32 GPR:$rs1), 0x000000000000FFFF)), + i32), + (PACKW GPR:$rs1, GPR:$rs2)>; +def : Pat<(or (and (assertsexti32 GPR:$rs2), 0xFFFFFFFFFFFF0000), + (srl (and (assertsexti32 GPR:$rs1), 0x00000000FFFF0000), + (i64 16))), + (PACKUW GPR:$rs1, GPR:$rs2)>; +} // Predicates = [HasStdExtZbbOrZbp, IsRV64] diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index f8b6b7eb3aff..86aa85e965f6 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -30953,6 +30953,34 @@ bool X86TargetLowering::areJTsAllowed(const Function *Fn) const { // X86 Scheduler Hooks //===----------------------------------------------------------------------===// +// Returns true if EFLAG is consumed after this iterator in the rest of the +// basic block or any successors of the basic block. +static bool isEFLAGSLiveAfter(MachineBasicBlock::iterator Itr, + MachineBasicBlock *BB) { + // Scan forward through BB for a use/def of EFLAGS. + for (MachineBasicBlock::iterator miI = std::next(Itr), miE = BB->end(); + miI != miE; ++miI) { + const MachineInstr& mi = *miI; + if (mi.readsRegister(X86::EFLAGS)) + return true; + // If we found a def, we can stop searching. + if (mi.definesRegister(X86::EFLAGS)) + return false; + } + + // If we hit the end of the block, check whether EFLAGS is live into a + // successor. + for (MachineBasicBlock::succ_iterator sItr = BB->succ_begin(), + sEnd = BB->succ_end(); + sItr != sEnd; ++sItr) { + MachineBasicBlock* succ = *sItr; + if (succ->isLiveIn(X86::EFLAGS)) + return true; + } + + return false; +} + /// Utility function to emit xbegin specifying the start of an RTM region. static MachineBasicBlock *emitXBegin(MachineInstr &MI, MachineBasicBlock *MBB, const TargetInstrInfo *TII) { @@ -30985,6 +31013,12 @@ static MachineBasicBlock *emitXBegin(MachineInstr &MI, MachineBasicBlock *MBB, MF->insert(I, fallMBB); MF->insert(I, sinkMBB); + if (isEFLAGSLiveAfter(MI, MBB)) { + mainMBB->addLiveIn(X86::EFLAGS); + fallMBB->addLiveIn(X86::EFLAGS); + sinkMBB->addLiveIn(X86::EFLAGS); + } + // Transfer the remainder of BB and its successor edges to sinkMBB. sinkMBB->splice(sinkMBB->begin(), MBB, std::next(MachineBasicBlock::iterator(MI)), MBB->end()); @@ -31373,27 +31407,8 @@ MachineBasicBlock *X86TargetLowering::EmitVAStartSaveXMMRegsWithCustomInserter( static bool checkAndUpdateEFLAGSKill(MachineBasicBlock::iterator SelectItr, MachineBasicBlock* BB, const TargetRegisterInfo* TRI) { - // Scan forward through BB for a use/def of EFLAGS. - MachineBasicBlock::iterator miI(std::next(SelectItr)); - for (MachineBasicBlock::iterator miE = BB->end(); miI != miE; ++miI) { - const MachineInstr& mi = *miI; - if (mi.readsRegister(X86::EFLAGS)) - return false; - if (mi.definesRegister(X86::EFLAGS)) - break; // Should have kill-flag - update below. - } - - // If we hit the end of the block, check whether EFLAGS is live into a - // successor. - if (miI == BB->end()) { - for (MachineBasicBlock::succ_iterator sItr = BB->succ_begin(), - sEnd = BB->succ_end(); - sItr != sEnd; ++sItr) { - MachineBasicBlock* succ = *sItr; - if (succ->isLiveIn(X86::EFLAGS)) - return false; - } - } + if (isEFLAGSLiveAfter(SelectItr, BB)) + return false; // We found a def, or hit the end of the basic block and EFLAGS wasn't live // out. SelectMI should have a kill flag on EFLAGS. @@ -44349,8 +44364,8 @@ static SDValue combineVEXTRACT_STORE(SDNode *N, SelectionDAG &DAG, /// A horizontal-op B, for some already available A and B, and if so then LHS is /// set to A, RHS to B, and the routine returns 'true'. static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, SelectionDAG &DAG, - const X86Subtarget &Subtarget, - bool IsCommutative) { + const X86Subtarget &Subtarget, bool IsCommutative, + SmallVectorImpl<int> &PostShuffleMask) { // If either operand is undef, bail out. The binop should be simplified. if (LHS.isUndef() || RHS.isUndef()) return false; @@ -44443,6 +44458,12 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, SelectionDAG &DAG, RMask.push_back(i); } + // Avoid 128-bit lane crossing if pre-AVX2 and FP (integer will split). + if (!Subtarget.hasAVX2() && VT.isFloatingPoint() && + (isLaneCrossingShuffleMask(128, VT.getScalarSizeInBits(), LMask) || + isLaneCrossingShuffleMask(128, VT.getScalarSizeInBits(), RMask))) + return false; + // If A and B occur in reverse order in RHS, then canonicalize by commuting // RHS operands and shuffle mask. if (A != C) { @@ -44453,6 +44474,9 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, SelectionDAG &DAG, if (!(A == C && B == D)) return false; + PostShuffleMask.clear(); + PostShuffleMask.append(NumElts, SM_SentinelUndef); + // LHS and RHS are now: // LHS = shuffle A, B, LMask // RHS = shuffle A, B, RMask @@ -44461,6 +44485,7 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, SelectionDAG &DAG, // so we just repeat the inner loop if this is a 256-bit op. unsigned Num128BitChunks = VT.getSizeInBits() / 128; unsigned NumEltsPer128BitChunk = NumElts / Num128BitChunks; + unsigned NumEltsPer64BitChunk = NumEltsPer128BitChunk / 2; assert((NumEltsPer128BitChunk % 2 == 0) && "Vector type should have an even number of elements in each lane"); for (unsigned j = 0; j != NumElts; j += NumEltsPer128BitChunk) { @@ -44472,25 +44497,40 @@ static bool isHorizontalBinOp(SDValue &LHS, SDValue &RHS, SelectionDAG &DAG, (!B.getNode() && (LIdx >= (int)NumElts || RIdx >= (int)NumElts))) continue; + // Check that successive odd/even elements are being operated on. If not, + // this is not a horizontal operation. + if (!((RIdx & 1) == 1 && (LIdx + 1) == RIdx) && + !((LIdx & 1) == 1 && (RIdx + 1) == LIdx && IsCommutative)) + return false; + + // Compute the post-shuffle mask index based on where the element + // is stored in the HOP result, and where it needs to be moved to. + int Base = LIdx & ~1u; + int Index = ((Base % NumEltsPer128BitChunk) / 2) + + ((Base % NumElts) & ~(NumEltsPer128BitChunk - 1)); + // The low half of the 128-bit result must choose from A. // The high half of the 128-bit result must choose from B, // unless B is undef. In that case, we are always choosing from A. - unsigned NumEltsPer64BitChunk = NumEltsPer128BitChunk / 2; - unsigned Src = B.getNode() ? i >= NumEltsPer64BitChunk : 0; - - // Check that successive elements are being operated on. If not, this is - // not a horizontal operation. - int Index = 2 * (i % NumEltsPer64BitChunk) + NumElts * Src + j; - if (!(LIdx == Index && RIdx == Index + 1) && - !(IsCommutative && LIdx == Index + 1 && RIdx == Index)) - return false; + if ((B && Base >= (int)NumElts) || (!B && i >= NumEltsPer64BitChunk)) + Index += NumEltsPer64BitChunk; + PostShuffleMask[i + j] = Index; } } LHS = A.getNode() ? A : B; // If A is 'UNDEF', use B for it. RHS = B.getNode() ? B : A; // If B is 'UNDEF', use A for it. - if (!shouldUseHorizontalOp(LHS == RHS && NumShuffles < 2, DAG, Subtarget)) + bool IsIdentityPostShuffle = + isSequentialOrUndefInRange(PostShuffleMask, 0, NumElts, 0); + if (IsIdentityPostShuffle) + PostShuffleMask.clear(); + + // Assume a SingleSource HOP if we only shuffle one input and don't need to + // shuffle the result. + if (!shouldUseHorizontalOp(LHS == RHS && + (NumShuffles < 2 || !IsIdentityPostShuffle), + DAG, Subtarget)) return false; LHS = DAG.getBitcast(VT, LHS); @@ -44509,10 +44549,16 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG, assert((IsFadd || N->getOpcode() == ISD::FSUB) && "Wrong opcode"); // Try to synthesize horizontal add/sub from adds/subs of shuffles. + SmallVector<int, 8> PostShuffleMask; if (((Subtarget.hasSSE3() && (VT == MVT::v4f32 || VT == MVT::v2f64)) || (Subtarget.hasAVX() && (VT == MVT::v8f32 || VT == MVT::v4f64))) && - isHorizontalBinOp(LHS, RHS, DAG, Subtarget, IsFadd)) - return DAG.getNode(HorizOpcode, SDLoc(N), VT, LHS, RHS); + isHorizontalBinOp(LHS, RHS, DAG, Subtarget, IsFadd, PostShuffleMask)) { + SDValue HorizBinOp = DAG.getNode(HorizOpcode, SDLoc(N), VT, LHS, RHS); + if (!PostShuffleMask.empty()) + HorizBinOp = DAG.getVectorShuffle(VT, SDLoc(HorizBinOp), HorizBinOp, + DAG.getUNDEF(VT), PostShuffleMask); + return HorizBinOp; + } // NOTE: isHorizontalBinOp may have changed LHS/RHS variables. @@ -47605,17 +47651,22 @@ static SDValue combineAddOrSubToHADDorHSUB(SDNode *N, SelectionDAG &DAG, bool IsAdd = N->getOpcode() == ISD::ADD; assert((IsAdd || N->getOpcode() == ISD::SUB) && "Wrong opcode"); + SmallVector<int, 8> PostShuffleMask; if ((VT == MVT::v8i16 || VT == MVT::v4i32 || VT == MVT::v16i16 || VT == MVT::v8i32) && Subtarget.hasSSSE3() && - isHorizontalBinOp(Op0, Op1, DAG, Subtarget, IsAdd)) { + isHorizontalBinOp(Op0, Op1, DAG, Subtarget, IsAdd, PostShuffleMask)) { auto HOpBuilder = [IsAdd](SelectionDAG &DAG, const SDLoc &DL, ArrayRef<SDValue> Ops) { - return DAG.getNode(IsAdd ? X86ISD::HADD : X86ISD::HSUB, - DL, Ops[0].getValueType(), Ops); + return DAG.getNode(IsAdd ? X86ISD::HADD : X86ISD::HSUB, DL, + Ops[0].getValueType(), Ops); }; - return SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1}, - HOpBuilder); + SDValue HorizBinOp = + SplitOpsAndApply(DAG, Subtarget, SDLoc(N), VT, {Op0, Op1}, HOpBuilder); + if (!PostShuffleMask.empty()) + HorizBinOp = DAG.getVectorShuffle(VT, SDLoc(HorizBinOp), HorizBinOp, + DAG.getUNDEF(VT), PostShuffleMask); + return HorizBinOp; } return SDValue(); |