From 4b6d70da23743e55ee32aee57f659e4465367b31 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 08:38:45 +0000 Subject: [PATCH 1/8] Add 8-bit float emulation for SPIR-V conversion. SPIR-V does not support any 8-bit floats. Threfore, 8-bit floats are emulated as 8-bit integers. --- mlir/include/mlir/Conversion/Passes.td | 18 +++- .../SPIRV/Transforms/SPIRVConversion.h | 4 + .../ControlFlowToSPIRVPass.cpp | 2 + .../FuncToSPIRV/FuncToSPIRVPass.cpp | 2 + .../TensorToSPIRV/TensorToSPIRVPass.cpp | 2 + .../SPIRV/Transforms/SPIRVConversion.cpp | 97 ++++++++++++++++++- 6 files changed, 120 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index cf7596cc8a928..dab29e68c17ec 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by emulating them with integer types of same bit width"> ]; } @@ -416,7 +419,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by emulating them with integer types of same bit width"> ]; } @@ -500,7 +506,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by emulating them with integer types of same bit width"> ]; } @@ -1167,7 +1176,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by emulating them with integer types of same bit width"> ]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 3d22ec918f4c5..03ae54a8ae30a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -39,6 +39,10 @@ struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// Whether to emulate unsupported floats with integer types of same bit + /// width. + bool emulateUnsupportedFloatTypes{true}; + /// How sub-byte values are storaged in memory. SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4df4912..01657cced2281 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = + this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f659afb10..ca67079ce9bb1 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,8 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = + this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386ea80124..309ed8b054628 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,8 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = + this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec0190b5a61..37dd75b586002 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx, // SPIR-V dialect. Keeping it local till the use case arises. static std::optional getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { + if (isa(type)) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: @@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + else + return std::nullopt; + } + if (auto complexType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +328,67 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n"); + return nullptr; +} + +/// Converts a sub-byte float ``type` to i32 regardless of target environment. +/// Returns a nullptr for unsupported float types, including non sub-byte +/// types. +/// +/// We are treating 8 bit floats as sub-byte types here due to it's similar +/// nature of being used as a packed format. + +/// Note that we don't recognize +/// sub-byte types in `spirv::ScalarType` and use the above given that these +/// sub-byte types are not supported at all in SPIR-V; there are no +/// compute/storage capability for them like other supported integer types. + +// static Type convertPackedFLoatType(const SPIRVConversionOptions &options, +// FloatType type) { + +// // F4, F6, F8 types are converted to integer types with the same bit width. + +// if (isa(type)) +// auto emulatedType = IntegerType::get(type.getContext(), type.getWidth()); + +// if (type.getWidth() > 8) { +// LLVM_DEBUG(llvm::dbgs() << "not a packed type\n"); +// return nullptr; +// } +// if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) { +// LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); +// return nullptr; +// } + +// // if (!llvm::isPowerOf2_32(type.getWidth())) { +// // LLVM_DEBUG(llvm::dbgs() +// // << "unsupported non-power-of-two bitwidth in sub-byte" << +// type +// // << "\n"); +// // return nullptr; +// // } + +// LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); +// return IntegerType::get(type.getContext(), /*width=*/32, +// type.getSignedness()); +// } + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -339,8 +410,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv, type = cast(convertIndexElementType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { - // If this is not a spec allowed scalar type, try to handle sub-byte integer - // types. + // If this is not a spec allowed scalar type, there are 2 scenarios, + // 8 bit floats or sub-byte integer types. try to handle them accrodingly. + + // Hnadle 8 bit float types. + auto floatType = dyn_cast(type.getElementType()); + if (floatType && floatType.getWidth() == 8) { + // If this is an 8 bit float type, try to convert it to a supported + // integer type. + if (auto convertedType = convert8BitFloatType(options, floatType)) { + return VectorType::get(type.getShape(), convertedType); + } + } + + // Handle sub-byte integer types. auto intType = dyn_cast(type.getElementType()); if (!intType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +679,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast(elementType)) { type = cast(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast(elementType)) { + // Hnadle 8 bit float types. + if (options.emulateUnsupportedFloatTypes && floatType && + floatType.getWidth() == 8) { + // If this is an 8 bit float type, try to convert it to a supported + // integer type. + arrayElemType = convert8BitFloatType(options, floatType); + } } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1535,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional { if (auto scalarType = dyn_cast(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); From 4bdd204643589f3807d2e5b6fa5125412143984f Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 08:40:31 +0000 Subject: [PATCH 2/8] Add arith.constant support. Handles scalar and vector. --- .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 30 ++++++++++++++++--- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e6816641cb..f066671efd754 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get IntegerAttr from FloatAttr. +IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +304,16 @@ struct ConstantCompositeOpPattern final SmallVector elements; if (isa(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + if (srcElemType.getIntOrFloatBitWidth() == 8 && + isa(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +377,17 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa(srcType)) { auto srcAttr = cast(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + if (srcType.getIntOrFloatBitWidth() == 8 && isa(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); From 6a3d341c2fe8a15b45c3c89425d7d39001e9e3e7 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 09:26:06 +0000 Subject: [PATCH 3/8] Handle all Shaped Type 8-bit floats in a similar way. This approach minimizes the code modification. --- .../SPIRV/Transforms/SPIRVConversion.cpp | 55 ++++++++++++------- 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 37dd75b586002..1ddefb53aa94d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -343,6 +343,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options, return nullptr; } +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not the 8-bit float type. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + auto srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Converts a sub-byte float ``type` to i32 regardless of target environment. /// Returns a nullptr for unsupported float types, including non sub-byte /// types. @@ -408,22 +431,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional storageClass = {}) { type = cast(convertIndexElementType(type, options)); + type = cast(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { - // If this is not a spec allowed scalar type, there are 2 scenarios, - // 8 bit floats or sub-byte integer types. try to handle them accrodingly. - - // Hnadle 8 bit float types. - auto floatType = dyn_cast(type.getElementType()); - if (floatType && floatType.getWidth() == 8) { - // If this is an 8 bit float type, try to convert it to a supported - // integer type. - if (auto convertedType = convert8BitFloatType(options, floatType)) { - return VectorType::get(type.getShape(), convertedType); - } - } - - // Handle sub-byte integer types. + // If this is not a spec allowed scalar type, try to handle sub-byte integer + // types. auto intType = dyn_cast(type.getElementType()); if (!intType) { LLVM_DEBUG(llvm::dbgs() @@ -516,6 +528,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast(convertIndexElementType(type, options)); + type = cast(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -681,12 +694,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, arrayElemType = type.getElementType(); } else if (auto floatType = dyn_cast(elementType)) { // Hnadle 8 bit float types. - if (options.emulateUnsupportedFloatTypes && floatType && - floatType.getWidth() == 8) { - // If this is an 8 bit float type, try to convert it to a supported - // integer type. - arrayElemType = convert8BitFloatType(options, floatType); - } + type = cast(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); + // if (options.emulateUnsupportedFloatTypes && floatType && + // floatType.getWidth() == 8) { + // // If this is an 8 bit float type, try to convert it to a supported + // // integer type. + // arrayElemType = convert8BitFloatType(options, floatType); + // } } else { LLVM_DEBUG( llvm::dbgs() From a65613798fd7d9c0156bddf4c4de2a8cd13a74f6 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 09:29:29 +0000 Subject: [PATCH 4/8] Remove commented out code. --- .../SPIRV/Transforms/SPIRVConversion.cpp | 46 ------------------- 1 file changed, 46 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 1ddefb53aa94d..e00ebfd272bf7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -366,52 +366,6 @@ convertShaped8BitFloatType(ShapedType type, return type.clone(convertedElementType); } -/// Converts a sub-byte float ``type` to i32 regardless of target environment. -/// Returns a nullptr for unsupported float types, including non sub-byte -/// types. -/// -/// We are treating 8 bit floats as sub-byte types here due to it's similar -/// nature of being used as a packed format. - -/// Note that we don't recognize -/// sub-byte types in `spirv::ScalarType` and use the above given that these -/// sub-byte types are not supported at all in SPIR-V; there are no -/// compute/storage capability for them like other supported integer types. - -// static Type convertPackedFLoatType(const SPIRVConversionOptions &options, -// FloatType type) { - -// // F4, F6, F8 types are converted to integer types with the same bit width. - -// if (isa(type)) -// auto emulatedType = IntegerType::get(type.getContext(), type.getWidth()); - -// if (type.getWidth() > 8) { -// LLVM_DEBUG(llvm::dbgs() << "not a packed type\n"); -// return nullptr; -// } -// if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) { -// LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); -// return nullptr; -// } - -// // if (!llvm::isPowerOf2_32(type.getWidth())) { -// // LLVM_DEBUG(llvm::dbgs() -// // << "unsupported non-power-of-two bitwidth in sub-byte" << -// type -// // << "\n"); -// // return nullptr; -// // } - -// LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); -// return IntegerType::get(type.getContext(), /*width=*/32, -// type.getSignedness()); -// } - /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. From 730bace9993f49dbc3b4313408edfa4f8c2c7bde Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 09:37:11 +0000 Subject: [PATCH 5/8] Remove unnecessary commented out code. --- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index e00ebfd272bf7..3580f7a61ae7e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -650,12 +650,6 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, // Hnadle 8 bit float types. type = cast(convertShaped8BitFloatType(type, options)); arrayElemType = type.getElementType(); - // if (options.emulateUnsupportedFloatTypes && floatType && - // floatType.getWidth() == 8) { - // // If this is an 8 bit float type, try to convert it to a supported - // // integer type. - // arrayElemType = convert8BitFloatType(options, floatType); - // } } else { LLVM_DEBUG( llvm::dbgs() From 4167a4a85a621e70307b68c07d130ca557f1dad1 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 10:05:25 +0000 Subject: [PATCH 6/8] Fix clang-format issue. --- .../Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 3 +-- mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 3 +-- mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 01657cced2281..56b6181018153 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,8 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; - options.emulateUnsupportedFloatTypes = - this->emulateUnsupportedFloatTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index ca67079ce9bb1..c0439a4033eac 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,8 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; - options.emulateUnsupportedFloatTypes = - this->emulateUnsupportedFloatTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index 309ed8b054628..8cd650e649008 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,8 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; - options.emulateUnsupportedFloatTypes = - this->emulateUnsupportedFloatTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); From 357c290eefa69a532d1e8fc4baea6289d69f8261 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Tue, 15 Jul 2025 20:00:17 +0000 Subject: [PATCH 7/8] Add test case & make arith-to-spirv use emulation flag. --- .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 +++- .../SPIRV/Transforms/SPIRVConversion.cpp | 4 +- .../ArithToSPIRV/arith-to-spirv.mlir | 11 ++++ .../FuncToSPIRV/types-to-spirv.mlir | 54 +++++++++++++++++++ 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index f066671efd754..a9257ceba8f58 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -306,7 +306,9 @@ struct ConstantCompositeOpPattern final for (FloatAttr srcAttr : dstElementsAttr.getValues()) { Attribute dstAttr = nullptr; // Handle 8-bit float conversion to 8-bit integer. - if (srcElemType.getIntOrFloatBitWidth() == 8 && + auto *typeConverter = getTypeConverter(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && isa(dstElemType)) { dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); @@ -381,7 +383,9 @@ struct ConstantScalarOpPattern final // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType.getIntOrFloatBitWidth() == 8 && isa(dstType) && + auto *typeConverter = getTypeConverter(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa(dstType) && dstType.getIntOrFloatBitWidth() == 8) { // If the source is an 8-bit float, convert it to a 8-bit integer. dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); @@ -1374,6 +1378,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 3580f7a61ae7e..4a0ec19b86690 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -345,12 +345,12 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options, /// Returns a type with the same shape but with any 8-bit float element type /// converted to the same bit width integer type. This is a noop when the -/// element type is not the 8-bit float type. +/// element type is not the 8-bit float type or emulation flag is set to false. static ShapedType convertShaped8BitFloatType(ShapedType type, const SPIRVConversionOptions &options) { if (!options.emulateUnsupportedFloatTypes) - return nullptr; + return type; auto srcElementType = type.getElementType(); Type convertedElementType = nullptr; // F8 types are converted to integer types with the same bit width. diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 1abe0fd2ec468..751e727534efe 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -559,6 +559,17 @@ func.func @constant() { return } +// CHECK-LABEL: @constant_8bit_float +func.func @constant_8bit_float() { + // CHECK: spirv.Constant 56 : i8 + %cst = arith.constant 1.0 : f8E4M3 + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + return +} + // CHECK-LABEL: @constant_16bit func.func @constant_16bit() { // CHECK: spirv.Constant 4 : i16 diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 1737f4a906bf8..0c77c88334572 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \ +// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT //===----------------------------------------------------------------------===// // Integer types @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } } // end module + + +// ----- + +// Check that 8-bit float types are emulated as i8. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + + // CHECK: spirv.func @float8_to_integer8 + // CHECK-SAME: (%arg0: i8 + // CHECK-SAME: %arg1: i8 + // CHECK-SAME: %arg2: i8 + // CHECK-SAME: %arg3: i8 + // CHECK-SAME: %arg4: i8 + // CHECK-SAME: %arg5: i8 + // CHECK-SAME: %arg6: i8 + // CHECK-SAME: %arg7: i8 + // CHECK-SAME: %arg8: vector<4xi8> + // CHECK-SAME: %arg9: !spirv.ptr [0])>, StorageBuffer> + // CHECK-SAME: %arg10: !spirv.array<4 x i8> + // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8 + // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2 + // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3 + // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN + // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4 + // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU + // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ> + // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class> + // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2> + // UNSUPPORTED_FLOAT-SAME: ) { + + func.func @float8_to_integer8( + %arg0: f8E5M2, // CHECK-NOT: f8E5M2 + %arg1: f8E4M3, // CHECK-NOT: f8E4M3 + %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN + %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ + %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ + %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ + %arg6: f8E3M4, // CHECK-NOT: f8E3M4 + %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU + %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ> + %arg9: memref<8xf8E4M3, #spirv.storage_class>, // CHECK-NOT: memref + %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor + ) { + // CHECK: spirv.Return + return + } +} From 9ce85efede06cd50f93e4f7beae9eee4debcfb98 Mon Sep 17 00:00:00 2001 From: "Shahneous Bari, Md Abdullah" Date: Mon, 21 Jul 2025 19:35:37 +0000 Subject: [PATCH 8/8] Address review comments. --- mlir/include/mlir/Conversion/Passes.td | 20 +++++++++++-------- .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 9 ++++++--- .../SPIRV/Transforms/SPIRVConversion.cpp | 8 +++----- .../ArithToSPIRV/arith-to-spirv.mlir | 6 ++++++ 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index dab29e68c17ec..6e1baaf23fcf7 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -197,8 +197,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", - "bool", /*default=*/"true", - "Emulate unsupported float types by emulating them with integer types of same bit width"> + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -421,8 +422,9 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { "Emulate narrower scalar types with 32-bit ones if not supported by" " the target">, Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", - "bool", /*default=*/"true", - "Emulate unsupported float types by emulating them with integer types of same bit width"> + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -508,8 +510,9 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { "Emulate narrower scalar types with 32-bit ones if not supported by" " the target">, Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", - "bool", /*default=*/"true", - "Emulate unsupported float types by emulating them with integer types of same bit width"> + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -1178,8 +1181,9 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { "Emulate narrower scalar types with 32-bit ones if not supported by" " the target">, Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", - "bool", /*default=*/"true", - "Emulate unsupported float types by emulating them with integer types of same bit width"> + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index a9257ceba8f58..265293b83f84c 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,9 +99,12 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } -// Get IntegerAttr from FloatAttr. -IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, - ConversionPatternRewriter &rewriter) { +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { APFloat floatVal = floatAttr.getValue(); APInt intVal = floatVal.bitcastToAPInt(); return rewriter.getIntegerAttr(dstType, intVal); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 4a0ec19b86690..8f4c4cc027798 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -169,7 +169,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx, // SPIR-V dialect. Keeping it local till the use case arises. static std::optional getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { - if (isa(type)) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: @@ -188,8 +187,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { auto bitWidth = type.getIntOrFloatBitWidth(); if (bitWidth == 8) return bitWidth / 8; - else - return std::nullopt; + return std::nullopt; } if (auto complexType = dyn_cast(type)) { @@ -339,7 +337,7 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options, Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType>(type)) return IntegerType::get(type.getContext(), type.getWidth()); - LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n"); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); return nullptr; } @@ -351,7 +349,7 @@ convertShaped8BitFloatType(ShapedType type, const SPIRVConversionOptions &options) { if (!options.emulateUnsupportedFloatTypes) return type; - auto srcElementType = type.getElementType(); + Type srcElementType = type.getElementType(); Type convertedElementType = nullptr; // F8 types are converted to integer types with the same bit width. if (isa : vector<4xi8> %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8> // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8> return }