Skip to content

Commit c0fa432

Browse files
[SPIR-V] Add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion (#150090)
This PR introduces the support for the SPIR-V extension `SPV_INTEL_tensor_float32_conversion` and the corresponding OpenCL extension `cl_intel_tensor_float32_conversions`. This extension introduces a rounding instruction that converts standard 32-bit floating-point values to the TensorFloat32 (TF32) format. Reference Specification: https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc
1 parent 12eab1a commit c0fa432

File tree

9 files changed

+142
-6
lines changed

9 files changed

+142
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
148148
bool IsSaturated;
149149
bool IsRounded;
150150
bool IsBfloat16;
151+
bool IsTF32;
151152
FPRoundingMode::FPRoundingMode RoundingMode;
152153
};
153154

@@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
230231
// - "__spirv_SubgroupImageMediaBlockReadINTEL"
231232
// - "__spirv_SubgroupImageMediaBlockWriteINTEL"
232233
// - "__spirv_Convert"
234+
// - "__spirv_Round"
233235
// - "__spirv_UConvert"
234236
// - "__spirv_SConvert"
235237
// - "__spirv_FConvert"
@@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
242244
"SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
243245
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
244246
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
245-
"Convert|"
247+
"Convert|Round|"
246248
"UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
247249
std::smatch Match;
248250
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
@@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
697699
MachineIRBuilder &MIRBuilder,
698700
SPIRVGlobalRegistry *GR) {
699701
if (Call->isSpirvOp())
700-
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
702+
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
703+
Register(0));
701704

702705
Register ScopeRegister =
703706
buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
26772680
}
26782681
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
26792682
SPIRV::OpTypeFloat)) {
2680-
// Float -> Float
2681-
Opcode = SPIRV::OpFConvert;
2683+
if (Builtin->IsTF32) {
2684+
const auto *ST = static_cast<const SPIRVSubtarget *>(
2685+
&MIRBuilder.getMF().getSubtarget());
2686+
if (!ST->canUseExtension(
2687+
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
2688+
NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
2689+
IsRightComponentsNumber =
2690+
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2691+
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2692+
Opcode = SPIRV::OpRoundFToTF32INTEL;
2693+
} else {
2694+
// Float -> Float
2695+
Opcode = SPIRV::OpFConvert;
2696+
}
26822697
}
26832698
}
26842699

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
14611461
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
14621462
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
14631463
!not(!eq(!find(name, "bfloat16"), -1)));
1464+
bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
1465+
!not(!eq(!find(name, "tensor_float32"), -1)));
14641466
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
14651467
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
14661468
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1474,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
14721474
def ConvertBuiltins : GenericTable {
14731475
let FilterClass = "ConvertBuiltin";
14741476
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
1475-
"IsRounded", "IsBfloat16", "RoundingMode"];
1477+
"IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
14761478
string TypeOf_Set = "InstructionSet";
14771479
string TypeOf_RoundingMode = "FPRoundingMode";
14781480
}
@@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
15561558
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
15571559
}
15581560

1561+
// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
1562+
// Multiclass used to define at the same time both a demangled builtin record
1563+
// and a corresponding convert builtin record.
1564+
multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
1565+
// Create records for scalar and vector conversions.
1566+
foreach i = ["", "2", "3", "4", "8", "16"] in {
1567+
def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
1568+
def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
1569+
}
1570+
}
1571+
1572+
defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
1573+
defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;
1574+
1575+
foreach conv = ["FToTF32INTEL"] in {
1576+
def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
1577+
def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
1578+
}
1579+
15591580
//===----------------------------------------------------------------------===//
15601581
// Class defining a vector data load/store builtin record used for lowering
15611582
// into OpExtInst instruction.

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
102102
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103103
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
104104
{"SPV_KHR_float_controls2",
105-
SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
105+
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
106+
{"SPV_INTEL_tensor_float32_conversion",
107+
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
106108

107109
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
108110
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938
445445
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
446446
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
447447

448+
// SPV_INTEL_tensor_float32_conversion
449+
def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
450+
448451
// 3.42.12 Composite Instructions
449452

450453
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI,
15641564
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
15651565
}
15661566
break;
1567+
case SPIRV::OpRoundFToTF32INTEL:
1568+
if (ST.canUseExtension(
1569+
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1570+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1571+
Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
1572+
}
1573+
break;
15671574
case SPIRV::OpVariableLengthArrayINTEL:
15681575
case SPIRV::OpSaveMemoryINTEL:
15691576
case SPIRV::OpRestoreMemoryINTEL:

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
320320
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
321321
defm SPV_INTEL_int4 : ExtensionOperand<123>;
322322
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
323+
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;
323324

324325
//===----------------------------------------------------------------------===//
325326
// Multiclass used to define Capabilities enum values and at the same time
@@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
529530
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
530531
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
531532
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
533+
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
532534

533535
//===----------------------------------------------------------------------===//
534536
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
5+
; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion
6+
7+
; CHECK: OpCapability TensorFloat32RoundingINTEL
8+
; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion"
9+
10+
; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
11+
; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
12+
; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
13+
; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
14+
; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
15+
; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
16+
; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
17+
; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
18+
19+
; CHECK: OpFunction %[[VoidTy]]
20+
; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
21+
; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
22+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]]
23+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]]
24+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]]
25+
26+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]]
27+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]]
28+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]]
29+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]]
30+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]]
31+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]]
32+
33+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
34+
target triple = "spir64-unknown-unknown"
35+
36+
define spir_func void @test(float %a, <8 x float> %in) {
37+
%res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
38+
%res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
39+
%res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00)
40+
ret void
41+
}
42+
43+
declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)
44+
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
45+
46+
define dso_local spir_kernel void @test_ocl(float %a) {
47+
entry:
48+
%res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00)
49+
%res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer)
50+
%res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer)
51+
%res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer)
52+
%res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer)
53+
%res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer)
54+
ret void
55+
}
56+
57+
declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float)
58+
declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>)
59+
declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>)
60+
declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>)
61+
declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>)
62+
declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)

0 commit comments

Comments
 (0)