Skip to content

Commit 01d179f

Browse files
add the mlir support for SPV_INTEL_tensor_float32_conversion extension
1 parent afce932 commit 01d179f

File tree

5 files changed

+127
-50
lines changed

5 files changed

+127
-50
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
405405
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
406406
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
407407
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
408+
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
408409

409410
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
410411
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr :
468469
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
469470
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
470471
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
472+
SPV_INTEL_tensor_float32_conversion,
471473
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
472474
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
473475
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
14651467
];
14661468
}
14671469

1470+
def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
1471+
list<Availability> availability = [
1472+
Extension<[SPV_INTEL_tensor_float32_conversion]>
1473+
];
1474+
}
1475+
14681476
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
14691477
list<Availability> availability = [
14701478
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
15671575
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
15681576
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
15691577
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
1570-
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
1578+
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
1579+
SPIRV_C_TensorFloat32RoundingINTEL
15711580
]>;
15721581

15731582
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4587,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
45874596
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
45884597
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
45894598
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
4599+
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
45904600

45914601
def SPIRV_OpcodeAttr :
45924602
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4692,7 +4702,8 @@ def SPIRV_OpcodeAttr :
46924702
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
46934703
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
46944704
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
4695-
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
4705+
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
4706+
SPIRV_OC_OpRoundFToTF32INTEL
46964707
]>;
46974708

46984709
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// at (https://github.com/intel/llvm)
1212
// Supported extensions
1313
// * SPV_INTEL_bfloat16_conversion
14+
// * SPV_INTEL_tensor_float32_conversion
1415
//===----------------------------------------------------------------------===//
1516

1617

@@ -19,7 +20,7 @@
1920

2021
// -----
2122

22-
def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
23+
def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> {
2324
let summary = "See extension SPV_INTEL_bfloat16_conversion";
2425

2526
let description = [{
@@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
5859
let results = (outs
5960
SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
6061
);
62+
6163
let assemblyFormat = [{
6264
$operand attr-dict `:` type($operand) `to` type($result)
6365
}];
6466

65-
let hasVerifier = 1;
67+
let hasVerifier = 0;
6668
}
6769

6870
// -----
6971

70-
def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
72+
def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> {
7173
let summary = "See extension SPV_INTEL_bfloat16_conversion";
7274

7375
let description = [{
@@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
107109
let assemblyFormat = [{
108110
$operand attr-dict `:` type($operand) `to` type($result)
109111
}];
110-
let hasVerifier = 1;
112+
113+
let hasVerifier = 0;
111114
}
112115

116+
// -----
117+
118+
def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> {
119+
let summary = "See extension SPV_INTEL_tensor_float32_conversion";
120+
121+
let description = [{
122+
Convert value numerically from a 32-bit floating point type to tensor float32,
123+
with rounding to the nearest even.
124+
125+
Result Type must be a scalar or vector of 32-bit floating-point type.
126+
The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
127+
128+
Float Value must be a scalar or vector of floating-point type.
129+
It must have the same number of components as Result Type. The component width must be 32 bits.
130+
131+
Results are computed per component.
132+
133+
#### Example:
134+
135+
```mlir
136+
%1 = spirv.RoundFToTF32 %0 : f32 to f32
137+
%3 = spirv.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
138+
```
139+
140+
}];
141+
142+
let availability = [
143+
MinVersion<SPIRV_V_1_0>,
144+
MaxVersion<SPIRV_V_1_6>,
145+
Extension<[SPV_INTEL_tensor_float32_conversion]>,
146+
Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
147+
];
148+
149+
let arguments = (ins
150+
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
151+
);
152+
153+
let results = (outs
154+
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
155+
);
156+
157+
let assemblyFormat = [{
158+
$operand attr-dict `:` type($operand) `to` type($result)
159+
}];
160+
161+
let hasVerifier = 0;
162+
}
113163

114164
// -----
115165

mlir/lib/Dialect/SPIRV/IR/CastOps.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -269,48 +269,6 @@ LogicalResult ConvertUToFOp::verify() {
269269
/*skipBitWidthCheck=*/true);
270270
}
271271

272-
//===----------------------------------------------------------------------===//
273-
// spirv.INTELConvertBF16ToFOp
274-
//===----------------------------------------------------------------------===//
275-
276-
LogicalResult INTELConvertBF16ToFOp::verify() {
277-
auto operandType = getOperand().getType();
278-
auto resultType = getResult().getType();
279-
// ODS checks that vector result type and vector operand type have the same
280-
// shape.
281-
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
282-
unsigned operandNumElements = vectorType.getNumElements();
283-
unsigned resultNumElements =
284-
llvm::cast<VectorType>(resultType).getNumElements();
285-
if (operandNumElements != resultNumElements) {
286-
return emitOpError(
287-
"operand and result must have same number of elements");
288-
}
289-
}
290-
return success();
291-
}
292-
293-
//===----------------------------------------------------------------------===//
294-
// spirv.INTELConvertFToBF16Op
295-
//===----------------------------------------------------------------------===//
296-
297-
LogicalResult INTELConvertFToBF16Op::verify() {
298-
auto operandType = getOperand().getType();
299-
auto resultType = getResult().getType();
300-
// ODS checks that vector result type and vector operand type have the same
301-
// shape.
302-
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
303-
unsigned operandNumElements = vectorType.getNumElements();
304-
unsigned resultNumElements =
305-
llvm::cast<VectorType>(resultType).getNumElements();
306-
if (operandNumElements != resultNumElements) {
307-
return emitOpError(
308-
"operand and result must have same number of elements");
309-
}
310-
}
311-
return success();
312-
}
313-
314272
//===----------------------------------------------------------------------===//
315273
// spirv.FConvertOp
316274
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
2929
// -----
3030

3131
spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
32-
// expected-error @+1 {{operand and result must have same number of elements}}
32+
// expected-error @+1 {{op requires the same shape for all operands and results}}
3333
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
3434
spirv.Return
3535
}
@@ -65,13 +65,49 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
6565
// -----
6666

6767
spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
68-
// expected-error @+1 {{operand and result must have same number of elements}}
68+
// expected-error @+1 {{op requires the same shape for all operands and results}}
6969
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
7070
spirv.Return
7171
}
7272

7373
// -----
7474

75+
//===----------------------------------------------------------------------===//
76+
// spirv.INTEL.RoundFToTF32
77+
//===----------------------------------------------------------------------===//
78+
79+
spirv.func @f32_to_tf32(%arg0 : f32) "None" {
80+
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
81+
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
82+
spirv.Return
83+
}
84+
85+
// -----
86+
87+
spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
88+
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
89+
%0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
90+
spirv.Return
91+
}
92+
93+
// -----
94+
95+
spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
96+
// expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
97+
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
98+
spirv.Return
99+
}
100+
101+
// -----
102+
103+
spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
104+
// expected-error @+1 {{op requires the same shape for all operands and results}}
105+
%0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
106+
spirv.Return
107+
}
108+
109+
// -----
110+
75111
//===----------------------------------------------------------------------===//
76112
// spirv.INTEL.SplitBarrier
77113
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/intel-ext-ops.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
3232

3333
// -----
3434

35+
//===----------------------------------------------------------------------===//
36+
// spirv.INTEL.RoundFToTF32
37+
//===----------------------------------------------------------------------===//
38+
39+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
40+
// CHECK-LABEL: @f32_to_tf32
41+
spirv.func @f32_to_tf32(%arg0 : f32) "None" {
42+
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
43+
%1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
44+
spirv.Return
45+
}
46+
47+
// CHECK-LABEL: @f32_to_tf32_vec
48+
spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
49+
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
50+
%1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
51+
spirv.Return
52+
}
53+
}
54+
55+
// -----
56+
3557
//===----------------------------------------------------------------------===//
3658
// spirv.INTEL.SplitBarrier
3759
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)