aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td')
-rw-r--r--contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td124
1 files changed, 117 insertions, 7 deletions
diff --git a/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index b51828b54679..91b45583c848 100644
--- a/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/contrib/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -42,14 +42,16 @@ class VOP3_VOP3PInst<string OpName, VOPProfile P, bit UseTiedOutput = 0,
}
let isCommutable = 1 in {
-def V_PK_FMA_F16 : VOP3PInst<"v_pk_fma_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16_V2F16>, fma>;
def V_PK_MAD_I16 : VOP3PInst<"v_pk_mad_i16", VOP3_Profile<VOP_V2I16_V2I16_V2I16_V2I16>>;
def V_PK_MAD_U16 : VOP3PInst<"v_pk_mad_u16", VOP3_Profile<VOP_V2I16_V2I16_V2I16_V2I16>>;
+let FPDPRounding = 1 in {
+def V_PK_FMA_F16 : VOP3PInst<"v_pk_fma_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16_V2F16>, fma>;
def V_PK_ADD_F16 : VOP3PInst<"v_pk_add_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fadd>;
def V_PK_MUL_F16 : VOP3PInst<"v_pk_mul_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmul>;
-def V_PK_MAX_F16 : VOP3PInst<"v_pk_max_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmaxnum>;
-def V_PK_MIN_F16 : VOP3PInst<"v_pk_min_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fminnum>;
+} // End FPDPRounding = 1
+def V_PK_MAX_F16 : VOP3PInst<"v_pk_max_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fmaxnum_like>;
+def V_PK_MIN_F16 : VOP3PInst<"v_pk_min_f16", VOP3_Profile<VOP_V2F16_V2F16_V2F16>, fminnum_like>;
def V_PK_ADD_U16 : VOP3PInst<"v_pk_add_u16", VOP3_Profile<VOP_V2I16_V2I16_V2I16>, add>;
def V_PK_ADD_I16 : VOP3PInst<"v_pk_add_i16", VOP3_Profile<VOP_V2I16_V2I16_V2I16>>;
@@ -137,12 +139,14 @@ let SubtargetPredicate = HasMadMixInsts in {
let isCommutable = 1 in {
def V_MAD_MIX_F32 : VOP3_VOP3PInst<"v_mad_mix_f32", VOP3_Profile<VOP_F32_F16_F16_F16, VOP3_OPSEL>>;
+let FPDPRounding = 1 in {
// Clamp modifier is applied after conversion to f16.
def V_MAD_MIXLO_F16 : VOP3_VOP3PInst<"v_mad_mixlo_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>;
let ClampLo = 0, ClampHi = 1 in {
def V_MAD_MIXHI_F16 : VOP3_VOP3PInst<"v_mad_mixhi_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>;
}
+} // End FPDPRounding = 1
}
defm : MadFmaMixPats<fmad, V_MAD_MIX_F32, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>;
@@ -154,18 +158,99 @@ let SubtargetPredicate = HasFmaMixInsts in {
let isCommutable = 1 in {
def V_FMA_MIX_F32 : VOP3_VOP3PInst<"v_fma_mix_f32", VOP3_Profile<VOP_F32_F16_F16_F16, VOP3_OPSEL>>;
+let FPDPRounding = 1 in {
// Clamp modifier is applied after conversion to f16.
def V_FMA_MIXLO_F16 : VOP3_VOP3PInst<"v_fma_mixlo_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>;
let ClampLo = 0, ClampHi = 1 in {
def V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3_Profile<VOP_F16_F16_F16_F16, VOP3_OPSEL>, 1>;
}
+} // End FPDPRounding = 1
}
defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
}
-let SubtargetPredicate = HasDLInsts in {
+// Defines patterns that extract signed 4bit from each Idx[0].
+foreach Idx = [[0,28],[4,24],[8,20],[12,16],[16,12],[20,8],[24,4]] in
+ def ExtractSigned4bit_#Idx[0] : PatFrag<(ops node:$src),
+ (sra (shl node:$src, (i32 Idx[1])), (i32 28))>;
+
+// Defines code pattern that extracts U(unsigned/signed) 4/8bit from FromBitIndex.
+class Extract<int FromBitIndex, int BitMask, bit U>: PatFrag<
+ (ops node:$src),
+ !if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)), !eq (FromBitIndex, 28)), // last element
+ !if (U, (srl node:$src, (i32 FromBitIndex)), (sra node:$src, (i32 FromBitIndex))),
+ !if (!eq (FromBitIndex, 0), // first element
+ !if (U, (and node:$src, (i32 BitMask)),
+ !if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
+ (sext_inreg node:$src, i8))),
+ !if (U, (and (srl node:$src, (i32 FromBitIndex)), (i32 BitMask)),
+ !if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
+ (sext_inreg (srl node:$src, (i32 FromBitIndex)), i8)))))>;
+
+
+foreach Type = ["I", "U"] in
+ foreach Index = 0-3 in {
+ // Defines patterns that extract each Index'ed 8bit from an unsigned
+ // 32bit scalar value;
+ def #Type#Index#"_8bit" : Extract<!shl(Index, 3), 255, !if (!eq (Type, "U"), 1, 0)>;
+
+ // Defines multiplication patterns where the multiplication is happening on each
+ // Index'ed 8bit of a 32bit scalar value.
+
+ def Mul#Type#_Elt#Index : PatFrag<
+ (ops node:$src0, node:$src1),
+ (!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), AMDGPUmul_i24_oneuse, AMDGPUmul_u24_oneuse))
+ (!cast<Extract>(#Type#Index#"_8bit") node:$src0),
+ (!cast<Extract>(#Type#Index#"_8bit") node:$src1))>;
+ }
+
+// Different variants of dot8 patterns cause a huge increase in the compile time.
+// Define non-associative/commutative add/mul to prevent permutation in the dot8
+// pattern.
+def NonACAdd : SDNode<"ISD::ADD" , SDTIntBinOp>;
+def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>;
+
+def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>;
+def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>;
+
+def NonACAMDGPUmul_i24 : SDNode<"AMDGPUISD::MUL_I24" , SDTIntBinOp>;
+def NonACAMDGPUmul_i24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_i24>;
+
+foreach Type = ["I", "U"] in
+ foreach Index = 0-7 in {
+ // Defines patterns that extract each Index'ed 4bit from an unsigned
+ // 32bit scalar value;
+ def #Type#Index#"_4bit" : Extract<!shl(Index, 2), 15, !if (!eq (Type, "U"), 1, 0)>;
+
+ // Defines multiplication patterns where the multiplication is happening on each
+ // Index'ed 8bit of a 32bit scalar value.
+ def Mul#Type#Index#"_4bit" : PatFrag<
+ (ops node:$src0, node:$src1),
+ (!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), NonACAMDGPUmul_i24_oneuse, NonACAMDGPUmul_u24_oneuse))
+ (!cast<Extract>(#Type#Index#"_4bit") node:$src0),
+ (!cast<Extract>(#Type#Index#"_4bit") node:$src1))>;
+ }
+
+class UDot2Pat<Instruction Inst> : GCNPat <
+ (add (add_oneuse (AMDGPUmul_u24_oneuse (srl i32:$src0, (i32 16)),
+ (srl i32:$src1, (i32 16))), i32:$src2),
+ (AMDGPUmul_u24_oneuse (and i32:$src0, (i32 65535)),
+ (and i32:$src1, (i32 65535)))
+ ),
+ (Inst (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
+>;
+
+class SDot2Pat<Instruction Inst> : GCNPat <
+ (add (add_oneuse (AMDGPUmul_i24_oneuse (sra i32:$src0, (i32 16)),
+ (sra i32:$src1, (i32 16))), i32:$src2),
+ (AMDGPUmul_i24_oneuse (sext_inreg i32:$src0, i16),
+ (sext_inreg i32:$src1, i16))),
+ (Inst (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
+>;
+
+let SubtargetPredicate = HasDotInsts in {
def V_DOT2_F32_F16 : VOP3PInst<"v_dot2_f32_f16", VOP3_Profile<VOP_F32_V2F16_V2F16_F32>>;
def V_DOT2_I32_I16 : VOP3PInst<"v_dot2_i32_i16", VOP3_Profile<VOP_I32_V2I16_V2I16_I32>>;
@@ -192,7 +277,32 @@ defm : DotPats<int_amdgcn_udot4, V_DOT4_U32_U8>;
defm : DotPats<int_amdgcn_sdot8, V_DOT8_I32_I4>;
defm : DotPats<int_amdgcn_udot8, V_DOT8_U32_U4>;
-} // End SubtargetPredicate = HasDLInsts
+def : UDot2Pat<V_DOT2_U32_U16>;
+def : SDot2Pat<V_DOT2_I32_I16>;
+
+foreach Type = ["U", "I"] in
+ def : GCNPat <
+ !cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y,
+ (add_oneuse lhs, (!cast<PatFrag>("Mul"#Type#"_Elt"#y) i32:$src0, i32:$src1)))),
+ (!cast<VOP3PInst>("V_DOT4_"#Type#"32_"#Type#8) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
+
+foreach Type = ["U", "I"] in
+ def : GCNPat <
+ !cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)),
+ [1, 2, 3, 4, 5, 6, 7], lhs, y,
+ (NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))),
+ (!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
+
+// Different variants of dot8 code-gen dag patterns are not generated through table-gen due to a huge increase
+// in the compile time. Directly handle the pattern generated by the FE here.
+foreach Type = ["U", "I"] in
+ def : GCNPat <
+ !cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)),
+ [7, 1, 2, 3, 4, 5, 6], lhs, y,
+ (NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))),
+ (!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
+
+} // End SubtargetPredicate = HasDotInsts
multiclass VOP3P_Real_vi<bits<10> op> {
def _vi : VOP3P_Real<!cast<VOP3_Pseudo>(NAME), SIEncodingFamily.VI>,
@@ -242,7 +352,7 @@ defm V_FMA_MIXHI_F16 : VOP3P_Real_vi <0x3a2>;
}
-let SubtargetPredicate = HasDLInsts in {
+let SubtargetPredicate = HasDotInsts in {
defm V_DOT2_F32_F16 : VOP3P_Real_vi <0x3a3>;
defm V_DOT2_I32_I16 : VOP3P_Real_vi <0x3a6>;
@@ -252,4 +362,4 @@ defm V_DOT4_U32_U8 : VOP3P_Real_vi <0x3a9>;
defm V_DOT8_I32_I4 : VOP3P_Real_vi <0x3aa>;
defm V_DOT8_U32_U4 : VOP3P_Real_vi <0x3ab>;
-} // End SubtargetPredicate = HasDLInsts
+} // End SubtargetPredicate = HasDotInsts