//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the targeting of the Machinelegalizer class for SPIR-V. // //===----------------------------------------------------------------------===// #include "SPIRVLegalizerInfo.h" #include "SPIRV.h" #include "SPIRVGlobalRegistry.h" #include "SPIRVSubtarget.h" #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/TargetOpcodes.h" using namespace llvm; using namespace llvm::LegalizeActions; using namespace llvm::LegalityPredicates; static const std::set TypeFoldingSupportingOpcs = { TargetOpcode::G_ADD, TargetOpcode::G_FADD, TargetOpcode::G_SUB, TargetOpcode::G_FSUB, TargetOpcode::G_MUL, TargetOpcode::G_FMUL, TargetOpcode::G_SDIV, TargetOpcode::G_UDIV, TargetOpcode::G_FDIV, TargetOpcode::G_SREM, TargetOpcode::G_UREM, TargetOpcode::G_FREM, TargetOpcode::G_FNEG, TargetOpcode::G_CONSTANT, TargetOpcode::G_FCONSTANT, TargetOpcode::G_AND, TargetOpcode::G_OR, TargetOpcode::G_XOR, TargetOpcode::G_SHL, TargetOpcode::G_ASHR, TargetOpcode::G_LSHR, TargetOpcode::G_SELECT, TargetOpcode::G_EXTRACT_VECTOR_ELT, }; bool isTypeFoldingSupported(unsigned Opcode) { return TypeFoldingSupportingOpcs.count(Opcode) > 0; } SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { using namespace TargetOpcode; this->ST = &ST; GR = ST.getSPIRVGlobalRegistry(); const LLT s1 = LLT::scalar(1); const LLT s8 = LLT::scalar(8); const LLT s16 = LLT::scalar(16); const LLT s32 = LLT::scalar(32); const LLT s64 = LLT::scalar(64); const LLT v16s64 = LLT::fixed_vector(16, 64); const LLT v16s32 = LLT::fixed_vector(16, 32); const LLT v16s16 = LLT::fixed_vector(16, 16); const LLT v16s8 = LLT::fixed_vector(16, 8); const LLT v16s1 = LLT::fixed_vector(16, 1); const LLT v8s64 = LLT::fixed_vector(8, 64); const LLT v8s32 = LLT::fixed_vector(8, 32); const LLT v8s16 = LLT::fixed_vector(8, 16); const LLT v8s8 = LLT::fixed_vector(8, 8); const LLT v8s1 = LLT::fixed_vector(8, 1); const LLT v4s64 = LLT::fixed_vector(4, 64); const LLT v4s32 = LLT::fixed_vector(4, 32); const LLT v4s16 = LLT::fixed_vector(4, 16); const LLT v4s8 = LLT::fixed_vector(4, 8); const LLT v4s1 = LLT::fixed_vector(4, 1); const LLT v3s64 = LLT::fixed_vector(3, 64); const LLT v3s32 = LLT::fixed_vector(3, 32); const LLT v3s16 = LLT::fixed_vector(3, 16); const LLT v3s8 = LLT::fixed_vector(3, 8); const LLT v3s1 = LLT::fixed_vector(3, 1); const LLT v2s64 = LLT::fixed_vector(2, 64); const LLT v2s32 = LLT::fixed_vector(2, 32); const LLT v2s16 = LLT::fixed_vector(2, 16); const LLT v2s8 = LLT::fixed_vector(2, 8); const LLT v2s1 = LLT::fixed_vector(2, 1); const unsigned PSize = ST.getPointerSize(); const LLT p0 = LLT::pointer(0, PSize); // Function const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup const LLT p2 = LLT::pointer(2, PSize); // UniformConstant const LLT p3 = LLT::pointer(3, PSize); // Workgroup const LLT p4 = LLT::pointer(4, PSize); // Generic const LLT p5 = LLT::pointer(5, PSize); // Input // TODO: remove copy-pasting here by using concatenation in some way. auto allPtrsScalarsAndVectors = { p0, p1, p2, p3, p4, p5, s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; auto allScalarsAndVectors = { s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; auto allIntScalars = {s8, s16, s32, s64}; auto allFloatScalarsAndVectors = { s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; auto allFloatAndIntScalars = allIntScalars; auto allPtrs = {p0, p1, p2, p3, p4, p5}; auto allWritablePtrs = {p0, p1, p3, p4}; for (auto Opc : TypeFoldingSupportingOpcs) getActionDefinitionsBuilder(Opc).custom(); getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); // TODO: add proper rules for vectors legalization. getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); getActionDefinitionsBuilder(G_ADDRSPACE_CAST) .legalForCartesianProduct(allPtrs, allPtrs); getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors); getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) .legalForCartesianProduct(allIntScalarsAndVectors, allFloatScalarsAndVectors); getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) .legalForCartesianProduct(allFloatScalarsAndVectors, allScalarsAndVectors); getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) .legalFor(allIntScalarsAndVectors); getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct( allIntScalarsAndVectors, allIntScalarsAndVectors); getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors); getActionDefinitionsBuilder(G_BITCAST).legalIf(all( typeInSet(0, allPtrsScalarsAndVectors), typeInSet(1, allPtrsScalarsAndVectors), LegalityPredicate(([=](const LegalityQuery &Query) { return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits(); })))); getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal(); getActionDefinitionsBuilder(G_INTTOPTR) .legalForCartesianProduct(allPtrs, allIntScalars); getActionDefinitionsBuilder(G_PTRTOINT) .legalForCartesianProduct(allIntScalars, allPtrs); getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct( allPtrs, allIntScalars); // ST.canDirectlyComparePointers() for pointer args is supported in // legalizeCustom(). getActionDefinitionsBuilder(G_ICMP).customIf( all(typeInSet(0, allBoolScalarsAndVectors), typeInSet(1, allPtrsScalarsAndVectors))); getActionDefinitionsBuilder(G_FCMP).legalIf( all(typeInSet(0, allBoolScalarsAndVectors), typeInSet(1, allFloatScalarsAndVectors))); getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) .legalForCartesianProduct(allIntScalars, allWritablePtrs); getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs); getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); // TODO: add proper legalization rules. getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO}) .alwaysLegal(); // Extensions. getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) .legalForCartesianProduct(allScalarsAndVectors); // FP conversions. getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) .legalForCartesianProduct(allFloatScalarsAndVectors); // Pointer-handling. getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); // Control-flow. getActionDefinitionsBuilder(G_BRCOND).legalFor({s1}); getActionDefinitionsBuilder({G_FPOW, G_FEXP, G_FEXP2, G_FLOG, G_FLOG2, G_FABS, G_FMINNUM, G_FMAXNUM, G_FCEIL, G_FCOS, G_FSIN, G_FSQRT, G_FFLOOR, G_FRINT, G_FNEARBYINT, G_INTRINSIC_ROUND, G_INTRINSIC_TRUNC, G_FMINIMUM, G_FMAXIMUM, G_INTRINSIC_ROUNDEVEN}) .legalFor(allFloatScalarsAndVectors); getActionDefinitionsBuilder(G_FCOPYSIGN) .legalForCartesianProduct(allFloatScalarsAndVectors, allFloatScalarsAndVectors); getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( allFloatScalarsAndVectors, allIntScalarsAndVectors); getLegacyLegalizerInfo().computeTables(); verify(*ST.getInstrInfo()); } static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, LegalizerHelper &Helper, MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) { Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF()); Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) .addDef(ConvReg) .addUse(Reg); return ConvReg; } bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const { auto Opc = MI.getOpcode(); MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); if (!isTypeFoldingSupported(Opc)) { assert(Opc == TargetOpcode::G_ICMP); assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); auto &Op0 = MI.getOperand(2); auto &Op1 = MI.getOperand(3); Register Reg0 = Op0.getReg(); Register Reg1 = Op1.getReg(); CmpInst::Predicate Cond = static_cast(MI.getOperand(1).getPredicate()); if ((!ST->canDirectlyComparePointers() || (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { LLT ConvT = LLT::scalar(ST->getPointerSize()); Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), ST->getPointerSize()); SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); } return true; } // TODO: implement legalization for other opcodes. return true; }