
[mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" #151337


Merged

Conversation

YixingZhang007
Contributor

This PR adds support for the TensorFloat32RoundingINTEL capability and the OpRoundFToTF32INTEL instruction, as specified by the SPV_INTEL_tensor_float32_conversion extension.
This extension introduces a rounding instruction that converts standard 32-bit floating-point values to the TensorFloat32 (TF32) format.

Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc
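
For readers unfamiliar with the format: TF32 keeps f32's sign bit and 8 exponent bits but only 10 explicit mantissa bits, so the conversion rounds away the low 13 mantissa bits while the value stays stored in a 32-bit container. Below is a minimal C++20 sketch of that round-to-nearest-even behavior; the function name is hypothetical, NaN/infinity handling is omitted, and it illustrates the extension's semantics rather than the MLIR implementation, which only models the instruction.

```cpp
#include <bit>
#include <cstdint>

// Hypothetical illustration of OpRoundFToTF32INTEL's numerics: round an
// IEEE-754 binary32 value to TF32 precision (10 explicit mantissa bits)
// with round-to-nearest-even, returning the result in a 32-bit float.
// NaN/infinity handling is omitted for brevity.
float roundFToTF32(float value) {
  uint32_t bits = std::bit_cast<uint32_t>(value);
  constexpr uint32_t kDroppedBits = 13; // f32 has 23 mantissa bits; TF32 keeps 10
  uint32_t lsb = (bits >> kDroppedBits) & 1u; // last mantissa bit that is kept
  bits += 0x0FFFu + lsb;                      // nearest-even rounding bias
  bits &= ~((1u << kDroppedBits) - 1u);       // clear the dropped bits
  return std::bit_cast<float>(bits);          // a carry into the exponent is fine
}
```

This is the same bias-and-truncate trick commonly used for f32-to-bf16 rounding: a tie (dropped bits exactly 0x1000) rounds up only when the kept least-significant bit is 1, which is what makes the rounding nearest-even.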


@llvmbot
Member

llvmbot commented Jul 30, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: None (YixingZhang007)

Changes

This PR adds support for the TensorFloat32RoundingINTEL capability and the OpRoundFToTF32INTEL instruction, as specified by the SPV_INTEL_tensor_float32_conversion extension.
This extension introduces a rounding instruction that converts standard 32-bit floating-point values to the TensorFloat32 (TF32) format.

Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc


Full diff: https://github.com/llvm/llvm-project/pull/151337.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+14-3)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td (+54)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CastOps.cpp (+21)
  • (modified) mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir (+36)
  • (modified) mlir/test/Target/SPIRV/intel-ext-ops.mlir (+22)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c9eefd054fa6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_me
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
       SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
       SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
       SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
-      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
+      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
+      SPV_INTEL_tensor_float32_conversion
     ]>;
 
 //===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"B
   ];
 }
 
+def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_tensor_float32_conversion]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
-      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+      SPIRV_C_TensorFloat32RoundingINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL      : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
-      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+      SPIRV_OC_OpRoundFToTF32INTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..b692c07122683 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
 // at (https://github.com/intel/llvm)
 // Supported extensions
 // * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
 //===----------------------------------------------------------------------===//
 
 
@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
   let hasVerifier = 1;
 }
 
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+  let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+  let description = [{
+    Convert value numerically from a 32-bit floating point type to tensor float32,
+    with rounding to the nearest even.
+
+    Result Type must be a scalar or vector of 32-bit floating-point type.
+    The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+    Float Value must be a scalar or vector of floating-point type.
+    It must have the same number of components as Result Type. The component width must be 32 bits.
+
+    Results are computed per component.
+  
+
+    ```
+    convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
+                          `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+    %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+    ```
+
+  }];
+
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_tensor_float32_conversion]>,
+    Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
 
 // -----
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fc3e7308356bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTELRoundFToTF32Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELRoundFToTF32Op::verify() {
+  auto operandType = getOperand().getType();
+  auto resultType = getResult().getType();
+  // ODS checks that vector result type and vector operand type have the same
+  // shape.
+  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
+    unsigned operandNumElements = vectorType.getNumElements();
+    unsigned resultNumElements =
+        llvm::cast<VectorType>(resultType).getNumElements();
+    if (operandNumElements != resultNumElements) {
+      return emitOpError(
+          "operand and result must have same number of elements");
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d018a6c44..aa5bee5796cfa 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+  // CHECK-LABEL: @f32_to_tf32
+  spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_tf32_vec
+  spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//

Contributor

@Hardcode84 Hardcode84 left a comment

LGTM, thanks

Contributor

@IgWod-IMG IgWod-IMG left a comment

It'd be nice to submit the change to SPIRV_VectorOf as a separate PR, since there are quite a few test changes that are unrelated to this extension. But as long as @kuhar approves it as is, so do I. The rest of the change LGTM.

@kuhar
Member

kuhar commented Aug 1, 2025

It'd be nice to submit the change to SPIRV_VectorOf as a separate PR

This is true and landing it separately would be safer / easier to revert. We could land a PR like this very quickly if @YixingZhang007 or any of the reviewers want to prepare one -- the test changes are mechanical.

@YixingZhang007
Contributor Author

It'd be nice to submit the change to SPIRV_VectorOf as a separate PR

This is true and landing it separately would be safer / easier to revert. We could land a PR like this very quickly if @YixingZhang007 or any of the reviewers want to prepare one -- the test changes are mechanical.

This is a good point! I’ve reverted the changes to SPIRV_VectorOf and its associated tests. Currently, vector scalability isn’t enforced for any extensions, including SPV_INTEL_tensor_float32_conversion. I will create a PR for enforcing fixed-size vectors once this PR is merged.

@kuhar
Member

kuhar commented Aug 1, 2025

Can we reverse the order and first switch to fixed vectors and then land this?

@kuhar kuhar changed the title [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion " [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" Aug 1, 2025
@YixingZhang007
Contributor Author

Can we reverse the order and first switch to fixed vectors and then land this?

For sure! I have created a PR for enforcing fixed-size vectors in SPIR-V, found at #151738. I will update this PR after the other one is merged.

@YixingZhang007 YixingZhang007 force-pushed the add_SPV_INTEL_tensor_float32_conversion_mlir branch from b12d21e to 01d179f Compare August 1, 2025 18:43
@kuhar kuhar merged commit b63a9b7 into llvm:main Aug 1, 2025
9 checks passed
