Diffstat (limited to 'llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp')
-rw-r--r-- | llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 71
1 file changed, 70 insertions, 1 deletion
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 66a34d73dd37..b24eb5f7bbf4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -718,6 +718,71 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     break;
   }
+  case ISD::MUL: {
+    // Special case for calculating (mul (and X, C2), C1) where the full product
+    // fits in XLen bits. We can shift X left by the number of leading zeros in
+    // C2 and shift C1 left by XLen-lzcnt(C2). This will ensure the final
+    // product has XLen trailing zeros, putting it in the output of MULHU. This
+    // can avoid materializing a constant in a register for C2.
+
+    // RHS should be a constant.
+    auto *N1C = dyn_cast<ConstantSDNode>(Node->getOperand(1));
+    if (!N1C || !N1C->hasOneUse())
+      break;
+
+    // LHS should be an AND with constant.
+    SDValue N0 = Node->getOperand(0);
+    if (N0.getOpcode() != ISD::AND || !isa<ConstantSDNode>(N0.getOperand(1)))
+      break;
+
+    uint64_t C2 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue();
+
+    // Constant should be a mask.
+    if (!isMask_64(C2))
+      break;
+
+    // This should be the only use of the AND unless we will use
+    // (SRLI (SLLI X, 32), 32). We don't use a shift pair for other AND
+    // constants.
+    if (!N0.hasOneUse() && C2 != UINT64_C(0xFFFFFFFF))
+      break;
+
+    // If this can be an ANDI, ZEXT.H or ZEXT.W, we don't need to do this
+    // optimization.
+    if (isInt<12>(C2) ||
+        (C2 == UINT64_C(0xFFFF) &&
+         (Subtarget->hasStdExtZbb() || Subtarget->hasStdExtZbp())) ||
+        (C2 == UINT64_C(0xFFFFFFFF) && Subtarget->hasStdExtZba()))
+      break;
+
+    // We need to shift left the AND input and C1 by a total of XLen bits.
+
+    // How far left do we need to shift the AND input?
+    unsigned XLen = Subtarget->getXLen();
+    unsigned LeadingZeros = XLen - (64 - countLeadingZeros(C2));
+
+    // The constant gets shifted by the remaining amount unless that would
+    // shift bits out.
+    uint64_t C1 = N1C->getZExtValue();
+    unsigned ConstantShift = XLen - LeadingZeros;
+    if (ConstantShift > (XLen - (64 - countLeadingZeros(C1))))
+      break;
+
+    uint64_t ShiftedC1 = C1 << ConstantShift;
+    // If this is RV32, we need to sign extend the constant.
+    if (XLen == 32)
+      ShiftedC1 = SignExtend64(ShiftedC1, 32);
+
+    // Create (mulhu (slli X, lzcnt(C2)), C1 << (XLen - lzcnt(C2))).
+    SDNode *Imm = selectImm(CurDAG, DL, ShiftedC1, *Subtarget);
+    SDNode *SLLI =
+        CurDAG->getMachineNode(RISCV::SLLI, DL, VT, N0.getOperand(0),
+                               CurDAG->getTargetConstant(LeadingZeros, DL, VT));
+    SDNode *MULHU = CurDAG->getMachineNode(RISCV::MULHU, DL, VT,
+                                           SDValue(SLLI, 0), SDValue(Imm, 0));
+    ReplaceNode(Node, MULHU);
+    return;
+  }
   case ISD::INTRINSIC_WO_CHAIN: {
     unsigned IntNo = Node->getConstantOperandVal(0);
     switch (IntNo) {
@@ -1450,6 +1515,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     ReplaceNode(Node, Extract.getNode());
     return;
   }
+  case ISD::SPLAT_VECTOR:
   case RISCVISD::VMV_V_X_VL:
   case RISCVISD::VFMV_V_F_VL: {
     // Try to match splat of a scalar load to a strided load with stride of x0.
@@ -1466,7 +1532,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
       break;

     SDValue VL;
-    selectVLOp(Node->getOperand(1), VL);
+    if (Node->getOpcode() == ISD::SPLAT_VECTOR)
+      VL = CurDAG->getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT);
+    else
+      selectVLOp(Node->getOperand(1), VL);

     unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
     SDValue SEW = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
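To make the new ISD::MUL transform concrete, here is a minimal standalone sketch of the arithmetic it relies on. This is illustration only, not LLVM code: it assumes XLen = 64, emulates RISC-V's MULHU with GCC/Clang's unsigned __int128 extension, and uses made-up values for X, C1, and C2 chosen so the patch's guards hold (C2 is a 34-bit mask and the full product fits in 64 bits).

#include <cassert>
#include <cstdint>

// Emulated MULHU: the upper 64 bits of an unsigned 64x64->128-bit multiply.
static uint64_t mulhu(uint64_t a, uint64_t b) {
  return (uint64_t)(((unsigned __int128)a * b) >> 64);
}

int main() {
  const uint64_t X  = 0xDEADBEEFCAFEF00D;       // arbitrary input
  const uint64_t C2 = (UINT64_C(1) << 34) - 1;  // mask; countLeadingZeros(C2) == 30
  const uint64_t C1 = 0x123;                    // (X & C2) * C1 fits in 64 bits

  const unsigned LeadingZeros  = 30;       // XLen - (64 - countLeadingZeros(C2))
  const unsigned ConstantShift = 64 - 30;  // XLen - LeadingZeros

  // Shifting X left by LeadingZeros discards exactly the bits that
  // (and X, C2) would clear, so C2 never needs to be materialized.
  uint64_t ShiftedX = X << LeadingZeros;

  // The ConstantShift guard in the patch ensures no set bit of C1 is
  // shifted out the top here.
  uint64_t ShiftedC1 = C1 << ConstantShift;

  // The two shift amounts sum to XLen, so the product lands exactly in
  // the high half of the widened multiply, i.e. in the MULHU output.
  assert(mulhu(ShiftedX, ShiftedC1) == (X & C2) * C1);
  return 0;
}

The net effect on RV64 is that (X & 0x3FFFFFFFF) * 0x123 can be selected as SLLI plus MULHU with a single materialized constant, instead of materializing both C2 and C1. The SignExtend64 step in the patch has no analogue above; per the patch's own comment it only adjusts how the shifted constant is represented for selectImm on RV32.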