From 01d179fcab3d07b3a37eed94e94f2aaa32c4da6a Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 30 Jul 2025 06:46:00 -0700 Subject: [PATCH 1/2] add the mlir support for SPV_INTEL_tensor_float32_conversion extension --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 15 ++++- .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 58 +++++++++++++++++-- mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 42 -------------- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 40 ++++++++++++- mlir/test/Target/SPIRV/intel-ext-ops.mlir | 22 +++++++ 5 files changed, 127 insertions(+), 50 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 37ee85b04f1eb..bdfd728d1d0b3 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>; +def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr : SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls, + SPV_INTEL_tensor_float32_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage, @@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B ]; 
} +def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> { + list availability = [ + Extension<[SPV_INTEL_tensor_float32_conversion]> + ]; +} + def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> { list availability = [ Extension<[SPV_INTEL_cache_controls]> @@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr : SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR, - SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR + SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, + SPIRV_C_TensorFloat32RoundingINTEL ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -4587,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>; def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>; def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>; +def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>; def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4692,7 +4702,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL, - SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR + SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, + SPIRV_OC_OpRoundFToTF32INTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td index 82d26e365fb24..2a7fa534cc3dc 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td @@ -11,6 +11,7 @@ // at (https://github.com/intel/llvm) // Supported extensions // * SPV_INTEL_bfloat16_conversion +// * SPV_INTEL_tensor_float32_conversion //===----------------------------------------------------------------------===// @@ -19,7 +20,7 @@ // ----- -def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> { +def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> { let summary = "See extension SPV_INTEL_bfloat16_conversion"; let description = [{ @@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> { let results = (outs SPIRV_ScalarOrVectorOf:$result ); + let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) }]; - let hasVerifier = 1; + let hasVerifier = 0; } // ----- -def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { +def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> { let summary = "See extension SPV_INTEL_bfloat16_conversion"; let description = [{ @@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { let assemblyFormat = [{ $operand attr-dict `:` type($operand) `to` type($result) }]; - let hasVerifier = 1; + + let hasVerifier = 0; } +// ----- + +def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> { + let summary = "See extension SPV_INTEL_tensor_float32_conversion"; + + let description = [{ + Convert value numerically from a 32-bit floating point type to tensor float32, + with rounding to the nearest even. 
+ + Result Type must be a scalar or vector of 32-bit floating-point type. + The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value. + + Float Value must be a scalar or vector of floating-point type. + It must have the same number of components as Result Type. The component width must be 32 bits. + + Results are computed per component. + + #### Example: + + ```mlir + %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32 + %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32> + ``` + + }]; + + let availability = [ + MinVersion<SPIRV_V_1_0>, + MaxVersion<SPIRV_V_1_6>, + Extension<[SPV_INTEL_tensor_float32_conversion]>, + Capability<[SPIRV_C_TensorFloat32RoundingINTEL]> + ]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + + let hasVerifier = 0; +} // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index e27dc274673be..fcf4eb6fbcf60 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -269,48 +269,6 @@ LogicalResult ConvertUToFOp::verify() { /*skipBitWidthCheck=*/true); } -//===----------------------------------------------------------------------===// -// spirv.INTELConvertBF16ToFOp -//===----------------------------------------------------------------------===// - -LogicalResult INTELConvertBF16ToFOp::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. 
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast<VectorType>(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// spirv.INTELConvertFToBF16Op -//===----------------------------------------------------------------------===// - -LogicalResult INTELConvertFToBF16Op::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast<VectorType>(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - //===----------------------------------------------------------------------===// // spirv.FConvertOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 22352da07cf13..0fdbdb66857f6 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { // ----- spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" { - // expected-error @+1 {{operand and result must have same number of elements}} + // expected-error @+1 {{op requires the same shape for all operands and results}} %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16> spirv.Return } @@ -65,13 +65,49 @@ 
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" { // ----- spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { - // expected-error @+1 {{operand and result must have same number of elements}} + // expected-error @+1 {{op requires the same shape for all operands and results}} %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32> spirv.Return } // ----- +//===----------------------------------------------------------------------===// +// spirv.INTEL.RoundFToTF32 +//===----------------------------------------------------------------------===// + +spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { + // expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got 'f64'}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" { + // expected-error @+1 {{op requires the same shape for all operands and results}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32> + spirv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir index 6d2fd324363c6..53cf8bf8fbd62 100644 --- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir +++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir @@ 
-32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK-LABEL: @f32_to_tf32 + spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return + } + + // CHECK-LABEL: @f32_to_tf32_vec + spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// From cd465dc4ebbd3682a22df411a274462ea871b403 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Fri, 1 Aug 2025 11:55:01 -0700 Subject: [PATCH 2/2] update intel-ext-ops.mlir --- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 0fdbdb66857f6..2e2fb1a9df328 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { // ----- spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got 'f64'}} + // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}} %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 spirv.Return }