@@ -346,6 +346,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
346
346
return nullptr ;
347
347
}
348
348
349
+ // / Returns a type with the same shape but with any 8-bit float element type
350
+ // / converted to the same bit width integer type. This is a noop when the
351
+ // / element type is not the 8-bit float type.
352
+ static ShapedType
353
+ convertShaped8BitFloatType (ShapedType type,
354
+ const SPIRVConversionOptions &options) {
355
+ if (!options.emulateUnsupportedFloatTypes )
356
+ return nullptr ;
357
+ auto srcElementType = type.getElementType ();
358
+ Type convertedElementType = nullptr ;
359
+ // F8 types are converted to integer types with the same bit width.
360
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
361
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
362
+ Float8E8M0FNUType>(srcElementType))
363
+ convertedElementType = IntegerType::get (
364
+ type.getContext (), srcElementType.getIntOrFloatBitWidth ());
365
+
366
+ if (!convertedElementType)
367
+ return type;
368
+
369
+ return type.clone (convertedElementType);
370
+ }
371
+
349
372
// / Converts a sub-byte float ``type` to i32 regardless of target environment.
350
373
// / Returns a nullptr for unsupported float types, including non sub-byte
351
374
// / types.
@@ -411,22 +434,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
411
434
const SPIRVConversionOptions &options, VectorType type,
412
435
std::optional<spirv::StorageClass> storageClass = {}) {
413
436
type = cast<VectorType>(convertIndexElementType (type, options));
437
+ type = cast<VectorType>(convertShaped8BitFloatType (type, options));
414
438
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
415
439
if (!scalarType) {
416
- // If this is not a spec allowed scalar type, there are 2 scenarios,
417
- // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
418
-
419
- // Hnadle 8 bit float types.
420
- auto floatType = dyn_cast<FloatType>(type.getElementType ());
421
- if (floatType && floatType.getWidth () == 8 ) {
422
- // If this is an 8 bit float type, try to convert it to a supported
423
- // integer type.
424
- if (auto convertedType = convert8BitFloatType (options, floatType)) {
425
- return VectorType::get (type.getShape (), convertedType);
426
- }
427
- }
428
-
429
- // Handle sub-byte integer types.
440
+ // If this is not a spec allowed scalar type, try to handle sub-byte integer
441
+ // types.
430
442
auto intType = dyn_cast<IntegerType>(type.getElementType ());
431
443
if (!intType) {
432
444
LLVM_DEBUG (llvm::dbgs ()
@@ -519,6 +531,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
519
531
}
520
532
521
533
type = cast<TensorType>(convertIndexElementType (type, options));
534
+ type = cast<TensorType>(convertShaped8BitFloatType (type, options));
522
535
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
523
536
if (!scalarType) {
524
537
LLVM_DEBUG (llvm::dbgs ()
@@ -684,12 +697,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
684
697
arrayElemType = type.getElementType ();
685
698
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
686
699
// Hnadle 8 bit float types.
687
- if (options.emulateUnsupportedFloatTypes && floatType &&
688
- floatType.getWidth () == 8 ) {
689
- // If this is an 8 bit float type, try to convert it to a supported
690
- // integer type.
691
- arrayElemType = convert8BitFloatType (options, floatType);
692
- }
700
+ type = cast<MemRefType>(convertShaped8BitFloatType (type, options));
701
+ arrayElemType = type.getElementType ();
702
+ // if (options.emulateUnsupportedFloatTypes && floatType &&
703
+ // floatType.getWidth() == 8) {
704
+ // // If this is an 8 bit float type, try to convert it to a supported
705
+ // // integer type.
706
+ // arrayElemType = convert8BitFloatType(options, floatType);
707
+ // }
693
708
} else {
694
709
LLVM_DEBUG (
695
710
llvm::dbgs ()
0 commit comments