aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/VE/VVPISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/VE/VVPISelLowering.cpp')
-rw-r--r--llvm/lib/Target/VE/VVPISelLowering.cpp443
1 files changed, 443 insertions, 0 deletions
diff --git a/llvm/lib/Target/VE/VVPISelLowering.cpp b/llvm/lib/Target/VE/VVPISelLowering.cpp
new file mode 100644
index 000000000000..330eef4c7c2b
--- /dev/null
+++ b/llvm/lib/Target/VE/VVPISelLowering.cpp
@@ -0,0 +1,443 @@
+//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering and legalization of vector instructions to
+// VVP_*layer SDNodes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "VECustomDAG.h"
+#include "VEISelLowering.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "ve-lower"
+
+SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
+ SelectionDAG &DAG) const {
+ VECustomDAG CDAG(DAG, Op);
+ SDValue AVL =
+ CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
+ SDValue A = Op->getOperand(0);
+ SDValue B = Op->getOperand(1);
+ SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
+ SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
+ SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
+ SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
+ unsigned Opc = Op.getOpcode();
+ auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
+ auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
+ return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
+}
+
+SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
+ // Can we represent this as a VVP node.
+ const unsigned Opcode = Op->getOpcode();
+ auto VVPOpcodeOpt = getVVPOpcode(Opcode);
+ if (!VVPOpcodeOpt)
+ return SDValue();
+ unsigned VVPOpcode = VVPOpcodeOpt.getValue();
+ const bool FromVP = ISD::isVPOpcode(Opcode);
+
+ // The representative and legalized vector type of this operation.
+ VECustomDAG CDAG(DAG, Op);
+ // Dispatch to complex lowering functions.
+ switch (VVPOpcode) {
+ case VEISD::VVP_LOAD:
+ case VEISD::VVP_STORE:
+ return lowerVVP_LOAD_STORE(Op, CDAG);
+ case VEISD::VVP_GATHER:
+ case VEISD::VVP_SCATTER:
+ return lowerVVP_GATHER_SCATTER(Op, CDAG);
+ }
+
+ EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
+ EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
+ auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
+
+ SDValue AVL;
+ SDValue Mask;
+
+ if (FromVP) {
+ // All upstream VP SDNodes always have a mask and avl.
+ auto MaskIdx = ISD::getVPMaskIdx(Opcode);
+ auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
+ if (MaskIdx)
+ Mask = Op->getOperand(*MaskIdx);
+ if (AVLIdx)
+ AVL = Op->getOperand(*AVLIdx);
+ }
+
+ // Materialize default mask and avl.
+ if (!AVL)
+ AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
+ if (!Mask)
+ Mask = CDAG.getConstantMask(Packing, true);
+
+ assert(LegalVecVT.isSimple());
+ if (isVVPUnaryOp(VVPOpcode))
+ return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
+ if (isVVPBinaryOp(VVPOpcode))
+ return CDAG.getNode(VVPOpcode, LegalVecVT,
+ {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
+ if (isVVPReductionOp(VVPOpcode)) {
+ auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
+ SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
+ SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
+ return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
+ VectorV, Mask, AVL, Op->getFlags());
+ }
+
+ switch (VVPOpcode) {
+ default:
+ llvm_unreachable("lowerToVVP called for unexpected SDNode.");
+ case VEISD::VVP_FFMA: {
+ // VE has a swizzled operand order in FMA (compared to LLVM IR and
+ // SDNodes).
+ auto X = Op->getOperand(2);
+ auto Y = Op->getOperand(0);
+ auto Z = Op->getOperand(1);
+ return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
+ }
+ case VEISD::VVP_SELECT: {
+ auto Mask = Op->getOperand(0);
+ auto OnTrue = Op->getOperand(1);
+ auto OnFalse = Op->getOperand(2);
+ return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
+ }
+ case VEISD::VVP_SETCC: {
+ EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
+ auto LHS = Op->getOperand(0);
+ auto RHS = Op->getOperand(1);
+ auto Pred = Op->getOperand(2);
+ return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
+ }
+ }
+}
+
+SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
+ VECustomDAG &CDAG) const {
+ auto VVPOpc = *getVVPOpcode(Op->getOpcode());
+ const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
+
+ // Shares.
+ SDValue BasePtr = getMemoryPtr(Op);
+ SDValue Mask = getNodeMask(Op);
+ SDValue Chain = getNodeChain(Op);
+ SDValue AVL = getNodeAVL(Op);
+ // Store specific.
+ SDValue Data = getStoredValue(Op);
+ // Load specific.
+ SDValue PassThru = getNodePassthru(Op);
+
+ SDValue StrideV = getLoadStoreStride(Op, CDAG);
+
+ auto DataVT = *getIdiomaticVectorType(Op.getNode());
+ auto Packing = getTypePacking(DataVT);
+
+ // TODO: Infer lower AVL from mask.
+ if (!AVL)
+ AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
+
+ // Default to the all-true mask.
+ if (!Mask)
+ Mask = CDAG.getConstantMask(Packing, true);
+
+ if (IsLoad) {
+ MVT LegalDataVT = getLegalVectorType(
+ Packing, DataVT.getVectorElementType().getSimpleVT());
+
+ auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
+ {Chain, BasePtr, StrideV, Mask, AVL});
+
+ if (!PassThru || PassThru->isUndef())
+ return NewLoadV;
+
+ // Convert passthru to an explicit select node.
+ SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
+ {NewLoadV, PassThru, Mask, AVL});
+ SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
+
+ // Merge them back into one node.
+ return CDAG.getMergeValues({DataV, NewLoadChainV});
+ }
+
+ // VVP_STORE
+ assert(VVPOpc == VEISD::VVP_STORE);
+ return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
+ {Chain, Data, BasePtr, StrideV, Mask, AVL});
+}
+
+SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
+ VECustomDAG &CDAG) const {
+ auto VVPOC = *getVVPOpcode(Op.getOpcode());
+ assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
+
+ MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
+ assert(getTypePacking(DataVT) == Packing::Dense &&
+ "Can only split packed load/store");
+ MVT SplitDataVT = splitVectorType(DataVT);
+
+ assert(!getNodePassthru(Op) &&
+ "Should have been folded in lowering to VVP layer");
+
+ // Analyze the operation
+ SDValue PackedMask = getNodeMask(Op);
+ SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
+ SDValue PackPtr = getMemoryPtr(Op);
+ SDValue PackData = getStoredValue(Op);
+ SDValue PackStride = getLoadStoreStride(Op, CDAG);
+
+ unsigned ChainResIdx = PackData ? 0 : 1;
+
+ SDValue PartOps[2];
+
+ SDValue UpperPartAVL; // we will use this for packing things back together
+ for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
+ // VP ops already have an explicit mask and AVL. When expanding from non-VP
+ // attach those additional inputs here.
+ auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
+
+ // Keep track of the (higher) lvl.
+ if (Part == PackElem::Hi)
+ UpperPartAVL = SplitTM.AVL;
+
+ // Attach non-predicating value operands
+ SmallVector<SDValue, 4> OpVec;
+
+ // Chain
+ OpVec.push_back(getNodeChain(Op));
+
+ // Data
+ if (PackData) {
+ SDValue PartData =
+ CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
+ OpVec.push_back(PartData);
+ }
+
+ // Ptr & Stride
+ // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
+ // Stride info
+ // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
+ OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
+ OpVec.push_back(CDAG.getSplitPtrStride(PackStride));
+
+ // Add predicating args and generate part node
+ OpVec.push_back(SplitTM.Mask);
+ OpVec.push_back(SplitTM.AVL);
+
+ if (PackData) {
+ // Store
+ PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
+ } else {
+ // Load
+ PartOps[(int)Part] =
+ CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
+ }
+ }
+
+ // Merge the chains
+ SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
+ SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
+ SDValue FusedChains =
+ CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});
+
+ // Chain only [store]
+ if (PackData)
+ return FusedChains;
+
+ // Re-pack into full packed vector result
+ MVT PackedVT =
+ getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
+ SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
+ PartOps[(int)PackElem::Hi], UpperPartAVL);
+
+ return CDAG.getMergeValues({PackedVals, FusedChains});
+}
+
+SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
+ VECustomDAG &CDAG) const {
+ EVT DataVT = *getIdiomaticVectorType(Op.getNode());
+ auto Packing = getTypePacking(DataVT);
+ MVT LegalDataVT =
+ getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());
+
+ SDValue AVL = getAnnotatedNodeAVL(Op).first;
+ SDValue Index = getGatherScatterIndex(Op);
+ SDValue BasePtr = getMemoryPtr(Op);
+ SDValue Mask = getNodeMask(Op);
+ SDValue Chain = getNodeChain(Op);
+ SDValue Scale = getGatherScatterScale(Op);
+ SDValue PassThru = getNodePassthru(Op);
+ SDValue StoredValue = getStoredValue(Op);
+ if (PassThru && PassThru->isUndef())
+ PassThru = SDValue();
+
+ bool IsScatter = (bool)StoredValue;
+
+ // TODO: Infer lower AVL from mask.
+ if (!AVL)
+ AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
+
+ // Default to the all-true mask.
+ if (!Mask)
+ Mask = CDAG.getConstantMask(Packing, true);
+
+ SDValue AddressVec =
+ CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
+ if (IsScatter)
+ return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
+ {Chain, StoredValue, AddressVec, Mask, AVL});
+
+ // Gather.
+ SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
+ {Chain, AddressVec, Mask, AVL});
+
+ if (!PassThru)
+ return NewLoadV;
+
+ // TODO: Use vvp_select
+ SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
+ {NewLoadV, PassThru, Mask, AVL});
+ SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
+ return CDAG.getMergeValues({DataV, NewLoadChainV});
+}
+
+SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
+ VECustomDAG &CDAG) const {
+ LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
+ MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
+
+ // TODO: Recognize packable load,store.
+ if (isPackedVectorType(DataVT))
+ return splitPackedLoadStore(Op, CDAG);
+
+ return legalizePackedAVL(Op, CDAG);
+}
+
+SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
+ SelectionDAG &DAG) const {
+ LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
+ VECustomDAG CDAG(DAG, Op);
+
+ // Dispatch to specialized legalization functions.
+ switch (Op->getOpcode()) {
+ case VEISD::VVP_LOAD:
+ case VEISD::VVP_STORE:
+ return legalizeInternalLoadStoreOp(Op, CDAG);
+ }
+
+ EVT IdiomVT = Op.getValueType();
+ if (isPackedVectorType(IdiomVT) &&
+ !supportsPackedMode(Op.getOpcode(), IdiomVT))
+ return splitVectorOp(Op, CDAG);
+
+ // TODO: Implement odd/even splitting.
+ return legalizePackedAVL(Op, CDAG);
+}
+
+SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
+ MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());
+
+ auto AVLPos = getAVLPos(Op->getOpcode());
+ auto MaskPos = getMaskPos(Op->getOpcode());
+
+ SDValue PackedMask = getNodeMask(Op);
+ auto AVLPair = getAnnotatedNodeAVL(Op);
+ SDValue PackedAVL = AVLPair.first;
+ assert(!AVLPair.second && "Expecting non pack-legalized oepration");
+
+ // request the parts
+ SDValue PartOps[2];
+
+ SDValue UpperPartAVL; // we will use this for packing things back together
+ for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
+ // VP ops already have an explicit mask and AVL. When expanding from non-VP
+ // attach those additional inputs here.
+ auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
+
+ if (Part == PackElem::Hi)
+ UpperPartAVL = SplitTM.AVL;
+
+ // Attach non-predicating value operands
+ SmallVector<SDValue, 4> OpVec;
+ for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
+ if (AVLPos && ((int)i) == *AVLPos)
+ continue;
+ if (MaskPos && ((int)i) == *MaskPos)
+ continue;
+
+ // Value operand
+ auto PackedOperand = Op.getOperand(i);
+ auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
+ SDValue PartV =
+ CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
+ OpVec.push_back(PartV);
+ }
+
+ // Add predicating args and generate part node.
+ OpVec.push_back(SplitTM.Mask);
+ OpVec.push_back(SplitTM.AVL);
+ // Emit legal VVP nodes.
+ PartOps[(int)Part] =
+ CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
+ }
+
+ // Re-package vectors.
+ return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
+ PartOps[(int)PackElem::Hi], UpperPartAVL);
+}
+
+SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
+ VECustomDAG &CDAG) const {
+ LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
+ // Only required for VEC and VVP ops.
+ if (!isVVPOrVEC(Op->getOpcode()))
+ return Op;
+
+ // Operation already has a legal AVL.
+ auto AVL = getNodeAVL(Op);
+ if (isLegalAVL(AVL))
+ return Op;
+
+ // Half and round up EVL for 32bit element types.
+ SDValue LegalAVL = AVL;
+ MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
+ if (isPackedVectorType(IdiomVT)) {
+ assert(maySafelyIgnoreMask(Op) &&
+ "TODO Shift predication from EVL into Mask");
+
+ if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
+ LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
+ } else {
+ auto ConstOne = CDAG.getConstant(1, MVT::i32);
+ auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
+ LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
+ }
+ }
+
+ SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
+
+ // Copy the operand list.
+ int NumOp = Op->getNumOperands();
+ auto AVLPos = getAVLPos(Op->getOpcode());
+ std::vector<SDValue> FixedOperands;
+ for (int i = 0; i < NumOp; ++i) {
+ if (AVLPos && (i == *AVLPos)) {
+ FixedOperands.push_back(AnnotatedLegalAVL);
+ continue;
+ }
+ FixedOperands.push_back(Op->getOperand(i));
+ }
+
+ // Clone the operation with fixed operands.
+ auto Flags = Op->getFlags();
+ SDValue NewN =
+ CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
+ return NewN;
+}