Skip to content

[mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" #151337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
// Extension enumerant for SPV_INTEL_tensor_float32_conversion; the value
// continues the sequential Intel-extension numbering used in this list.
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;

def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
Expand Down Expand Up @@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
SPV_INTEL_tensor_float32_conversion,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
Expand Down Expand Up @@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}

// Capability enumerant (6425) defined by the
// SPV_INTEL_tensor_float32_conversion extension; it gates the
// OpRoundFToTF32INTEL instruction (opcode 6426 below).
def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
list<Availability> availability = [
Extension<[SPV_INTEL_tensor_float32_conversion]>
];
}

def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
Expand Down Expand Up @@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
SPIRV_C_TensorFloat32RoundingINTEL
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
Expand Down Expand Up @@ -4587,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;

def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
Expand Down Expand Up @@ -4692,7 +4702,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
SPIRV_OC_OpRoundFToTF32INTEL
]>;

// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
Expand Down
58 changes: 54 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// at (https://github.com/intel/llvm)
// Supported extensions
// * SPV_INTEL_bfloat16_conversion
// * SPV_INTEL_tensor_float32_conversion
//===----------------------------------------------------------------------===//


Expand All @@ -19,7 +20,7 @@

// -----

def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";

let description = [{
Expand Down Expand Up @@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
let results = (outs
SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
);

let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];

let hasVerifier = 1;
let hasVerifier = 0;
}

// -----

def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";

let description = [{
Expand Down Expand Up @@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasVerifier = 1;

let hasVerifier = 0;
}

// -----

// ODS wrapper for OpRoundFToTF32INTEL from the
// SPV_INTEL_tensor_float32_conversion extension. The
// SameOperandsAndResultShape trait makes ODS reject mismatched vector
// shapes, so no hand-written C++ verifier is required (hasVerifier = 0).
def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_tensor_float32_conversion";

let description = [{
Convert value numerically from a 32-bit floating point type to tensor float32,
with rounding to the nearest even.

Result Type must be a scalar or vector of 32-bit floating-point type.
The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.

Float Value must be a scalar or vector of floating-point type.
It must have the same number of components as Result Type. The component width must be 32 bits.

Results are computed per component.

#### Example:

```mlir
%1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
%3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
```

}];

// Independent of the core SPIR-V version: gated purely by the Intel
// extension and its capability.
let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[SPV_INTEL_tensor_float32_conversion]>,
Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
];

// Operand and result are both f32 scalars or fixed-length f32 vectors;
// the TF32 value is represented by the result's f32 bit pattern.
let arguments = (ins
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
);

let results = (outs
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
);

let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];

// All structural checks are covered by the type constraints and the
// SameOperandsAndResultShape trait.
let hasVerifier = 0;
}

// -----

Expand Down
42 changes: 0 additions & 42 deletions mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,48 +269,6 @@ LogicalResult ConvertUToFOp::verify() {
/*skipBitWidthCheck=*/true);
}

//===----------------------------------------------------------------------===//
// spirv.INTELConvertBF16ToFOp
//===----------------------------------------------------------------------===//

LogicalResult INTELConvertBF16ToFOp::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
unsigned resultNumElements =
llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
}
}
return success();
}

//===----------------------------------------------------------------------===//
// spirv.INTELConvertFToBF16Op
//===----------------------------------------------------------------------===//

LogicalResult INTELConvertFToBF16Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type have the same
// shape.
if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
unsigned operandNumElements = vectorType.getNumElements();
unsigned resultNumElements =
llvm::cast<VectorType>(resultType).getNumElements();
if (operandNumElements != resultNumElements) {
return emitOpError(
"operand and result must have same number of elements");
}
}
return success();
}

//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
Expand Down
40 changes: 38 additions & 2 deletions mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
// -----

spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
// expected-error @+1 {{operand and result must have same number of elements}}
// expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
spirv.Return
}
Expand Down Expand Up @@ -65,13 +65,49 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
// -----

spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
// expected-error @+1 {{operand and result must have same number of elements}}
// expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INTEL.RoundFToTF32
//===----------------------------------------------------------------------===//

// Round-trip parse/print of the scalar (f32) form.
spirv.func @f32_to_tf32(%arg0 : f32) "None" {
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
spirv.Return
}

// -----

// Round-trip parse/print of the vector form (matching operand/result shapes).
spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
%0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
spirv.Return
}

// -----

// Negative test: the operand type constraint only admits f32 (or f32
// vectors), so an f64 operand must be diagnosed by ODS type checking.
spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
// expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
spirv.Return
}

// -----

// Negative test: the SameOperandsAndResultShape trait rejects a
// vector<2xf32> operand paired with a vector<4xf32> result.
spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
// expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
spirv.Return
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions mlir/test/Target/SPIRV/intel-ext-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]

// -----

//===----------------------------------------------------------------------===//
// spirv.INTEL.RoundFToTF32
//===----------------------------------------------------------------------===//

// Target/SPIRV test for RoundFToTF32 — presumably round-tripped through the
// SPIR-V binary format (TODO confirm against the lit RUN line, which is not
// visible here). The module declares the TensorFloat32RoundingINTEL
// capability and the SPV_INTEL_tensor_float32_conversion extension that
// the op's availability requires.
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
// CHECK-LABEL: @f32_to_tf32
spirv.func @f32_to_tf32(%arg0 : f32) "None" {
// Scalar form survives the round trip unchanged.
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
%1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
spirv.Return
}

// CHECK-LABEL: @f32_to_tf32_vec
spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
// Vector form survives the round trip unchanged.
// CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
%1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
spirv.Return
}
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
Expand Down