diff options
Diffstat (limited to 'contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td')
-rw-r--r-- | contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td | 124 |
1 files changed, 117 insertions, 7 deletions
diff --git a/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index b51828b54679..91b45583c848 100644 --- a/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -42,14 +42,16 @@ class VOP3_VOP3PInst<string OpName, VOPProfile P, bit UseTiedOutput = 0, } let isCommutable = 1 in { -def V_PK_FMA_F16 : VOP3PInst<"v_pk_fma_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16_V2F16>, fma>; def V_PK_MAD_I16 : VOP3PInst<"v_pk_mad_i16", VOP3_Profile<VOP_V2I16_V2I16_V2I16_V2I16>>; def V_PK_MAD_U16 : VOP3PInst<"v_pk_mad_u16", VOP3_Profile<VOP_V2I16_V2I16_V2I16_V2I16>>; +let FPDPRounding = 1 in { +def V_PK_FMA_F16 : VOP3PInst<"v_pk_fma_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16_V2F16>, fma>; def V_PK_ADD_F16 : VOP3PInst<"v_pk_add_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fadd>; def V_PK_MUL_F16 : VOP3PInst<"v_pk_mul_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmul>; -def V_PK_MAX_F16 : VOP3PInst<"v_pk_max_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmaxnum>; -def V_PK_MIN_F16 : VOP3PInst<"v_pk_min_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fminnum>; +} // End FPDPRounding = 1 +def V_PK_MAX_F16 : VOP3PInst<"v_pk_max_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmaxnum_like>; +def V_PK_MIN_F16 : VOP3PInst<"v_pk_min_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fminnum_like>; def V_PK_ADD_U16 : VOP3PInst<"v_pk_add_u16", VOP3_Profile<VOP_V2I16_V2I16_V2I16>, add>; def V_PK_ADD_I16 : VOP3PInst<"v_pk_add_i16", VOP3_Profile<VOP_V2I16_V2I16_V2I16>>; @@ -137,12 +139,14 @@ let SubtargetPredicate = HasMadMixInsts in { let isCommutable = 1 in { def V_MAD_MIX_F32 : VOP3_VOP3PInst<"v_mad_mix_f32", VOP3_Profile<VOP_F32_F16_F16_F16, VOP3_OPSEL>>; +let FPDPRounding = 1 in { // Clamp modifier is applied after conversion to f16. def V_MAD_MIXLO_F16 : VOP3_VOP3PInst<"v_mad_mixlo_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>; let ClampLo = 0, ClampHi = 1 in { def V_MAD_MIXHI_F16 : VOP3_VOP3PInst<"v_mad_mixhi_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>; } +} // End FPDPRounding = 1 } defm : MadFmaMixPats<fmad, V_MAD_MIX_F32, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>; @@ -154,18 +158,99 @@ let SubtargetPredicate = HasFmaMixInsts in { let isCommutable = 1 in { def V_FMA_MIX_F32 : VOP3_VOP3PInst<"v_fma_mix_f32", VOP3_Profile<VOP_F32_F16_F16_F16, VOP3_OPSEL>>; +let FPDPRounding = 1 in { // Clamp modifier is applied after conversion to f16. def V_FMA_MIXLO_F16 : VOP3_VOP3PInst<"v_fma_mixlo_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>; let ClampLo = 0, ClampHi = 1 in { def V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>; } +} // End FPDPRounding = 1 } defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>; } -let SubtargetPredicate = HasDLInsts in { +// Defines patterns that extract signed 4bit from each Idx[0]. +foreach Idx = [[0,28],[4,24],[8,20],[12,16],[16,12],[20,8],[24,4]] in + def ExtractSigned4bit_#Idx[0] : PatFrag<(ops node:$src), + (sra (shl node:$src, (i32 Idx[1])), (i32 28))>; + +// Defines code pattern that extracts U(unsigned/signed) 4/8bit from FromBitIndex. +class Extract<int FromBitIndex, int BitMask, bit U>: PatFrag< + (ops node:$src), + !if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)), !eq (FromBitIndex, 28)), // last element + !if (U, (srl node:$src, (i32 FromBitIndex)), (sra node:$src, (i32 FromBitIndex))), + !if (!eq (FromBitIndex, 0), // first element + !if (U, (and node:$src, (i32 BitMask)), + !if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src), + (sext_inreg node:$src, i8))), + !if (U, (and (srl node:$src, (i32 FromBitIndex)), (i32 BitMask)), + !if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src), + (sext_inreg (srl node:$src, (i32 FromBitIndex)), i8)))))>; + + +foreach Type = ["I", "U"] in + foreach Index = 0-3 in { + // Defines patterns that extract each Index'ed 8bit from an unsigned + // 32bit scalar value; + def #Type#Index#"_8bit" : Extract<!shl(Index, 3), 255, !if (!eq (Type, "U"), 1, 0)>; + + // Defines multiplication patterns where the multiplication is happening on each + // Index'ed 8bit of a 32bit scalar value. + + def Mul#Type#_Elt#Index : PatFrag< + (ops node:$src0, node:$src1), + (!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), AMDGPUmul_i24_oneuse, AMDGPUmul_u24_oneuse)) + (!cast<Extract>(#Type#Index#"_8bit") node:$src0), + (!cast<Extract>(#Type#Index#"_8bit") node:$src1))>; + } + +// Different variants of dot8 patterns cause a huge increase in the compile time. +// Define non-associative/commutative add/mul to prevent permutation in the dot8 +// pattern. +def NonACAdd : SDNode<"ISD::ADD" , SDTIntBinOp>; +def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>; + +def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>; +def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>; + +def NonACAMDGPUmul_i24 : SDNode<"AMDGPUISD::MUL_I24" , SDTIntBinOp>; +def NonACAMDGPUmul_i24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_i24>; + +foreach Type = ["I", "U"] in + foreach Index = 0-7 in { + // Defines patterns that extract each Index'ed 4bit from an unsigned + // 32bit scalar value; + def #Type#Index#"_4bit" : Extract<!shl(Index, 2), 15, !if (!eq (Type, "U"), 1, 0)>; + + // Defines multiplication patterns where the multiplication is happening on each + // Index'ed 8bit of a 32bit scalar value. + def Mul#Type#Index#"_4bit" : PatFrag< + (ops node:$src0, node:$src1), + (!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), NonACAMDGPUmul_i24_oneuse, NonACAMDGPUmul_u24_oneuse)) + (!cast<Extract>(#Type#Index#"_4bit") node:$src0), + (!cast<Extract>(#Type#Index#"_4bit") node:$src1))>; + } + +class UDot2Pat<Instruction Inst> : GCNPat < + (add (add_oneuse (AMDGPUmul_u24_oneuse (srl i32:$src0, (i32 16)), + (srl i32:$src1, (i32 16))), i32:$src2), + (AMDGPUmul_u24_oneuse (and i32:$src0, (i32 65535)), + (and i32:$src1, (i32 65535))) + ), + (Inst (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0)) +>; + +class SDot2Pat<Instruction Inst> : GCNPat < + (add (add_oneuse (AMDGPUmul_i24_oneuse (sra i32:$src0, (i32 16)), + (sra i32:$src1, (i32 16))), i32:$src2), + (AMDGPUmul_i24_oneuse (sext_inreg i32:$src0, i16), + (sext_inreg i32:$src1, i16))), + (Inst (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0)) +>; + +let SubtargetPredicate = HasDotInsts in { def V_DOT2_F32_F16 : VOP3PInst<"v_dot2_f32_f16", VOP3_Profile<VOP_F32_V2F16_V2F16_F32>>; def V_DOT2_I32_I16 : VOP3PInst<"v_dot2_i32_i16", VOP3_Profile<VOP_I32_V2I16_V2I16_I32>>; @@ -192,7 +277,32 @@ defm : DotPats<int_amdgcn_udot4, V_DOT4_U32_U8>; defm : DotPats<int_amdgcn_sdot8, V_DOT8_I32_I4>; defm : DotPats<int_amdgcn_udot8, V_DOT8_U32_U4>; -} // End SubtargetPredicate = HasDLInsts +def : UDot2Pat<V_DOT2_U32_U16>; +def : SDot2Pat<V_DOT2_I32_I16>; + +foreach Type = ["U", "I"] in + def : GCNPat < + !cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y, + (add_oneuse lhs, (!cast<PatFrag>("Mul"#Type#"_Elt"#y) i32:$src0, i32:$src1)))), + (!cast<VOP3PInst>("V_DOT4_"#Type#"32_"#Type#8) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>; + +foreach Type = ["U", "I"] in + def : GCNPat < + !cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)), + [1, 2, 3, 4, 5, 6, 7], lhs, y, + (NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))), + (!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>; + +// Different variants of dot8 code-gen dag patterns are not generated through table-gen due to a huge increase +// in the compile time. Directly handle the pattern generated by the FE here. +foreach Type = ["U", "I"] in + def : GCNPat < + !cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)), + [7, 1, 2, 3, 4, 5, 6], lhs, y, + (NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))), + (!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>; + +} // End SubtargetPredicate = HasDotInsts multiclass VOP3P_Real_vi<bits<10> op> { def _vi : VOP3P_Real<!cast<VOP3_Pseudo>(NAME), SIEncodingFamily.VI>, @@ -242,7 +352,7 @@ defm V_FMA_MIXHI_F16 : VOP3P_Real_vi <0x3a2>; } -let SubtargetPredicate = HasDLInsts in { +let SubtargetPredicate = HasDotInsts in { defm V_DOT2_F32_F16 : VOP3P_Real_vi <0x3a3>; defm V_DOT2_I32_I16 : VOP3P_Real_vi <0x3a6>; @@ -252,4 +362,4 @@ defm V_DOT4_U32_U8 : VOP3P_Real_vi <0x3a9>; defm V_DOT8_I32_I4 : VOP3P_Real_vi <0x3aa>; defm V_DOT8_U32_U4 : VOP3P_Real_vi <0x3ab>; -} // End SubtargetPredicate = HasDLInsts +} // End SubtargetPredicate = HasDotInsts |