Skip to content

Commit b9a627e

Browse files
authored
[mlir][spirv] Add 8-bit float type emulation (#148811)
8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.
1 parent c8b6ddf commit b9a627e

File tree

9 files changed

+185
-7
lines changed

9 files changed

+185
-7
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
196196
"bool", /*default=*/"true",
197197
"Emulate narrower scalar types with 32-bit ones if not supported by "
198198
"the target">,
199+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
200+
"bool", /*default=*/"true",
201+
"Emulate unsupported float types by representing them with integer "
202+
"types of same bit width">
199203
];
200204
}
201205

@@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
416420
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
417421
"bool", /*default=*/"true",
418422
"Emulate narrower scalar types with 32-bit ones if not supported by"
419-
" the target">
423+
" the target">,
424+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
425+
"bool", /*default=*/"true",
426+
"Emulate unsupported float types by representing them with integer "
427+
"types of same bit width">
420428
];
421429
}
422430

@@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
500508
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
501509
"bool", /*default=*/"true",
502510
"Emulate narrower scalar types with 32-bit ones if not supported by"
503-
" the target">
511+
" the target">,
512+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
513+
"bool", /*default=*/"true",
514+
"Emulate unsupported float types by representing them with integer "
515+
"types of same bit width">
504516
];
505517
}
506518

@@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
11671179
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
11681180
"bool", /*default=*/"true",
11691181
"Emulate narrower scalar types with 32-bit ones if not supported by"
1170-
" the target">
1182+
" the target">,
1183+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
1184+
"bool", /*default=*/"true",
1185+
"Emulate unsupported float types by representing them with integer "
1186+
"types of same bit width">
11711187
];
11721188
}
11731189

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
3939
/// The number of bits to store a boolean value.
4040
unsigned boolNumBits{8};
4141

42+
/// Whether to emulate unsupported floats with integer types of same bit
43+
/// width.
44+
bool emulateUnsupportedFloatTypes{true};
45+
4246
/// How sub-byte values are storaged in memory.
4347
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
4448

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
9999
return builder.getF32FloatAttr(dstVal.convertToFloat());
100100
}
101101

102+
// Get in IntegerAttr from FloatAttr while preserving the bits.
103+
// Useful for converting float constants to integer constants while preserving
104+
// the bits.
105+
static IntegerAttr
106+
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107+
ConversionPatternRewriter &rewriter) {
108+
APFloat floatVal = floatAttr.getValue();
109+
APInt intVal = floatVal.bitcastToAPInt();
110+
return rewriter.getIntegerAttr(dstType, intVal);
111+
}
112+
102113
/// Returns true if the given `type` is a boolean scalar or vector type.
103114
static bool isBoolScalarOrVector(Type type) {
104115
assert(type && "Not a valid type");
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
296307
SmallVector<Attribute, 8> elements;
297308
if (isa<FloatType>(srcElemType)) {
298309
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
299-
FloatAttr dstAttr =
300-
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
310+
Attribute dstAttr = nullptr;
311+
// Handle 8-bit float conversion to 8-bit integer.
312+
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
313+
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
314+
srcElemType.getIntOrFloatBitWidth() == 8 &&
315+
isa<IntegerType>(dstElemType)) {
316+
dstAttr =
317+
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
318+
} else {
319+
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
320+
rewriter);
321+
}
301322
if (!dstAttr)
302323
return failure();
303324
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
361382
// Floating-point types.
362383
if (isa<FloatType>(srcType)) {
363384
auto srcAttr = cast<FloatAttr>(cstAttr);
364-
auto dstAttr = srcAttr;
385+
Attribute dstAttr = srcAttr;
365386

366387
// Floating-point types not supported in the target environment are all
367388
// converted to float type.
368-
if (srcType != dstType) {
389+
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
390+
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
391+
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
392+
dstType.getIntOrFloatBitWidth() == 8) {
393+
// If the source is an 8-bit float, convert it to a 8-bit integer.
394+
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
395+
if (!dstAttr)
396+
return failure();
397+
} else if (srcType != dstType) {
369398
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
370399
if (!dstAttr)
371400
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
13521381

13531382
SPIRVConversionOptions options;
13541383
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1384+
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
13551385
SPIRVTypeConverter typeConverter(targetAttr, options);
13561386

13571387
// Use UnrealizedConversionCast as the bridge so that we don't need to pull

mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
4343

4444
SPIRVConversionOptions options;
4545
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
46+
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
4647
SPIRVTypeConverter typeConverter(targetAttr, options);
4748

4849
// TODO: We should also take care of block argument type conversion.

mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
4242

4343
SPIRVConversionOptions options;
4444
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
45+
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
4546
SPIRVTypeConverter typeConverter(targetAttr, options);
4647

4748
RewritePatternSet patterns(context);

mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
4141

4242
SPIRVConversionOptions options;
4343
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
44+
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
4445
SPIRVTypeConverter typeConverter(targetAttr, options);
4546

4647
RewritePatternSet patterns(context);

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
182182
return bitWidth / 8;
183183
}
184184

185+
// Handle 8-bit floats.
186+
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187+
auto bitWidth = type.getIntOrFloatBitWidth();
188+
if (bitWidth == 8)
189+
return bitWidth / 8;
190+
return std::nullopt;
191+
}
192+
185193
if (auto complexType = dyn_cast<ComplexType>(type)) {
186194
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
187195
if (!elementSize)
@@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
318326
type.getSignedness());
319327
}
320328

329+
/// Converts 8-bit float types to integer types with the same bit width.
330+
/// Returns a nullptr for unsupported 8-bit float types.
331+
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332+
FloatType type) {
333+
if (!options.emulateUnsupportedFloatTypes)
334+
return nullptr;
335+
// F8 types are converted to integer types with the same bit width.
336+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338+
Float8E8M0FNUType>(type))
339+
return IntegerType::get(type.getContext(), type.getWidth());
340+
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341+
return nullptr;
342+
}
343+
344+
/// Returns a type with the same shape but with any 8-bit float element type
345+
/// converted to the same bit width integer type. This is a noop when the
346+
/// element type is not the 8-bit float type or emulation flag is set to false.
347+
static ShapedType
348+
convertShaped8BitFloatType(ShapedType type,
349+
const SPIRVConversionOptions &options) {
350+
if (!options.emulateUnsupportedFloatTypes)
351+
return type;
352+
Type srcElementType = type.getElementType();
353+
Type convertedElementType = nullptr;
354+
// F8 types are converted to integer types with the same bit width.
355+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357+
Float8E8M0FNUType>(srcElementType))
358+
convertedElementType = IntegerType::get(
359+
type.getContext(), srcElementType.getIntOrFloatBitWidth());
360+
361+
if (!convertedElementType)
362+
return type;
363+
364+
return type.clone(convertedElementType);
365+
}
366+
321367
/// Returns a type with the same shape but with any index element type converted
322368
/// to the matching integer type. This is a noop when the element type is not
323369
/// the index type.
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
337383
const SPIRVConversionOptions &options, VectorType type,
338384
std::optional<spirv::StorageClass> storageClass = {}) {
339385
type = cast<VectorType>(convertIndexElementType(type, options));
386+
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
340387
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
341388
if (!scalarType) {
342389
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
433480
}
434481

435482
type = cast<TensorType>(convertIndexElementType(type, options));
483+
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
436484
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
437485
if (!scalarType) {
438486
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
596644
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
597645
type = cast<MemRefType>(convertIndexElementType(type, options));
598646
arrayElemType = type.getElementType();
647+
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
648+
// Hnadle 8 bit float types.
649+
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
650+
arrayElemType = type.getElementType();
599651
} else {
600652
LLVM_DEBUG(
601653
llvm::dbgs()
@@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14441496
addConversion([this](FloatType floatType) -> std::optional<Type> {
14451497
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
14461498
return convertScalarType(this->targetEnv, this->options, scalarType);
1499+
if (floatType.getWidth() == 8)
1500+
return convert8BitFloatType(this->options, floatType);
14471501
return Type();
14481502
});
14491503

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,23 @@ func.func @constant() {
559559
return
560560
}
561561

562+
// CHECK-LABEL: @constant_8bit_float
563+
func.func @constant_8bit_float() {
564+
// CHECK: spirv.Constant 56 : i8
565+
%cst = arith.constant 1.0 : f8E4M3
566+
// CHECK: spirv.Constant 56 : i8
567+
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
568+
// CHECK: spirv.Constant dense<56> : vector<4xi8>
569+
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
570+
// CHECK: spirv.Constant dense<56> : vector<4xi8>
571+
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
572+
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
573+
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
574+
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
575+
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
576+
return
577+
}
578+
562579
// CHECK-LABEL: @constant_16bit
563580
func.func @constant_16bit() {
564581
// CHECK: spirv.Constant 4 : i16

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
22
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
33
// RUN: FileCheck %s --check-prefix=NOEMU
4+
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
5+
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
46

57
//===----------------------------------------------------------------------===//
68
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
944946
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
945947

946948
} // end module
949+
950+
951+
// -----
952+
953+
// Check that 8-bit float types are emulated as i8.
954+
module attributes {
955+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
956+
} {
957+
958+
// CHECK: spirv.func @float8_to_integer8
959+
// CHECK-SAME: (%arg0: i8
960+
// CHECK-SAME: %arg1: i8
961+
// CHECK-SAME: %arg2: i8
962+
// CHECK-SAME: %arg3: i8
963+
// CHECK-SAME: %arg4: i8
964+
// CHECK-SAME: %arg5: i8
965+
// CHECK-SAME: %arg6: i8
966+
// CHECK-SAME: %arg7: i8
967+
// CHECK-SAME: %arg8: vector<4xi8>
968+
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
969+
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
970+
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
971+
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
972+
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
973+
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
974+
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
975+
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
976+
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
977+
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
978+
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
979+
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
980+
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
981+
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
982+
// UNSUPPORTED_FLOAT-SAME: ) {
983+
984+
func.func @float8_to_integer8(
985+
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
986+
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
987+
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
988+
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
989+
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
990+
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
991+
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
992+
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
993+
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
994+
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
995+
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
996+
) {
997+
// CHECK: spirv.Return
998+
return
999+
}
1000+
}

0 commit comments

Comments
 (0)