Skip to content

Commit 302bd54

Browse files
committed
Handle all Shaped Type 8-bit floats in a similar way.
This approach minimizes the code modification.
1 parent dda7834 commit 302bd54

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
346346
return nullptr;
347347
}
348348

349+
/// Returns a type with the same shape but with any 8-bit float element type
350+
/// converted to the same bit width integer type. This is a noop when the
351+
/// element type is not the 8-bit float type.
352+
static ShapedType
353+
convertShaped8BitFloatType(ShapedType type,
354+
const SPIRVConversionOptions &options) {
355+
if (!options.emulateUnsupportedFloatTypes)
356+
return nullptr;
357+
auto srcElementType = type.getElementType();
358+
Type convertedElementType = nullptr;
359+
// F8 types are converted to integer types with the same bit width.
360+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
361+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
362+
Float8E8M0FNUType>(srcElementType))
363+
convertedElementType = IntegerType::get(
364+
type.getContext(), srcElementType.getIntOrFloatBitWidth());
365+
366+
if (!convertedElementType)
367+
return type;
368+
369+
return type.clone(convertedElementType);
370+
}
371+
349372
/// Converts a sub-byte float ``type` to i32 regardless of target environment.
350373
/// Returns a nullptr for unsupported float types, including non sub-byte
351374
/// types.
@@ -411,22 +434,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
411434
const SPIRVConversionOptions &options, VectorType type,
412435
std::optional<spirv::StorageClass> storageClass = {}) {
413436
type = cast<VectorType>(convertIndexElementType(type, options));
437+
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
414438
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
415439
if (!scalarType) {
416-
// If this is not a spec allowed scalar type, there are 2 scenarios,
417-
// 8 bit floats or sub-byte integer types. try to handle them accrodingly.
418-
419-
// Hnadle 8 bit float types.
420-
auto floatType = dyn_cast<FloatType>(type.getElementType());
421-
if (floatType && floatType.getWidth() == 8) {
422-
// If this is an 8 bit float type, try to convert it to a supported
423-
// integer type.
424-
if (auto convertedType = convert8BitFloatType(options, floatType)) {
425-
return VectorType::get(type.getShape(), convertedType);
426-
}
427-
}
428-
429-
// Handle sub-byte integer types.
440+
// If this is not a spec allowed scalar type, try to handle sub-byte integer
441+
// types.
430442
auto intType = dyn_cast<IntegerType>(type.getElementType());
431443
if (!intType) {
432444
LLVM_DEBUG(llvm::dbgs()
@@ -519,6 +531,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
519531
}
520532

521533
type = cast<TensorType>(convertIndexElementType(type, options));
534+
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
522535
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
523536
if (!scalarType) {
524537
LLVM_DEBUG(llvm::dbgs()
@@ -684,12 +697,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
684697
arrayElemType = type.getElementType();
685698
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
686699
// Hnadle 8 bit float types.
687-
if (options.emulateUnsupportedFloatTypes && floatType &&
688-
floatType.getWidth() == 8) {
689-
// If this is an 8 bit float type, try to convert it to a supported
690-
// integer type.
691-
arrayElemType = convert8BitFloatType(options, floatType);
692-
}
700+
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
701+
arrayElemType = type.getElementType();
702+
// if (options.emulateUnsupportedFloatTypes && floatType &&
703+
// floatType.getWidth() == 8) {
704+
// // If this is an 8 bit float type, try to convert it to a supported
705+
// // integer type.
706+
// arrayElemType = convert8BitFloatType(options, floatType);
707+
// }
693708
} else {
694709
LLVM_DEBUG(
695710
llvm::dbgs()

0 commit comments

Comments
 (0)