Skip to content

Commit 9ce85ef

Browse files
committed
Address review comments.
1 parent 357c290 commit 9ce85ef

File tree

4 files changed

+27
-16
lines changed

4 files changed

+27
-16
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
197197
"Emulate narrower scalar types with 32-bit ones if not supported by "
198198
"the target">,
199199
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
200-
"bool", /*default=*/"true",
201-
"Emulate unsupported float types by emulating them with integer types of same bit width">
200+
"bool", /*default=*/"true",
201+
"Emulate unsupported float types by representing them with integer "
202+
"types of same bit width">
202203
];
203204
}
204205

@@ -421,8 +422,9 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
421422
"Emulate narrower scalar types with 32-bit ones if not supported by"
422423
" the target">,
423424
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
424-
"bool", /*default=*/"true",
425-
"Emulate unsupported float types by emulating them with integer types of same bit width">
425+
"bool", /*default=*/"true",
426+
"Emulate unsupported float types by representing them with integer "
427+
"types of same bit width">
426428
];
427429
}
428430

@@ -508,8 +510,9 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
508510
"Emulate narrower scalar types with 32-bit ones if not supported by"
509511
" the target">,
510512
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
511-
"bool", /*default=*/"true",
512-
"Emulate unsupported float types by emulating them with integer types of same bit width">
513+
"bool", /*default=*/"true",
514+
"Emulate unsupported float types by representing them with integer "
515+
"types of same bit width">
513516
];
514517
}
515518

@@ -1178,8 +1181,9 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
11781181
"Emulate narrower scalar types with 32-bit ones if not supported by"
11791182
" the target">,
11801183
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
1181-
"bool", /*default=*/"true",
1182-
"Emulate unsupported float types by emulating them with integer types of same bit width">
1184+
"bool", /*default=*/"true",
1185+
"Emulate unsupported float types by representing them with integer "
1186+
"types of same bit width">
11831187
];
11841188
}
11851189

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,12 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
9999
return builder.getF32FloatAttr(dstVal.convertToFloat());
100100
}
101101

102-
// Get IntegerAttr from FloatAttr.
103-
IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
104-
ConversionPatternRewriter &rewriter) {
102+
// Get in IntegerAttr from FloatAttr while preserving the bits.
103+
// Useful for converting float constants to integer constants while preserving
104+
// the bits.
105+
static IntegerAttr
106+
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
107+
ConversionPatternRewriter &rewriter) {
105108
APFloat floatVal = floatAttr.getValue();
106109
APInt intVal = floatVal.bitcastToAPInt();
107110
return rewriter.getIntegerAttr(dstType, intVal);

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
169169
// SPIR-V dialect. Keeping it local till the use case arises.
170170
static std::optional<int64_t>
171171
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172-
173172
if (isa<spirv::ScalarType>(type)) {
174173
auto bitWidth = type.getIntOrFloatBitWidth();
175174
// According to the SPIR-V spec:
@@ -188,8 +187,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
188187
auto bitWidth = type.getIntOrFloatBitWidth();
189188
if (bitWidth == 8)
190189
return bitWidth / 8;
191-
else
192-
return std::nullopt;
190+
return std::nullopt;
193191
}
194192

195193
if (auto complexType = dyn_cast<ComplexType>(type)) {
@@ -339,7 +337,7 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
339337
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
340338
Float8E8M0FNUType>(type))
341339
return IntegerType::get(type.getContext(), type.getWidth());
342-
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
340+
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
343341
return nullptr;
344342
}
345343

@@ -351,7 +349,7 @@ convertShaped8BitFloatType(ShapedType type,
351349
const SPIRVConversionOptions &options) {
352350
if (!options.emulateUnsupportedFloatTypes)
353351
return type;
354-
auto srcElementType = type.getElementType();
352+
Type srcElementType = type.getElementType();
355353
Type convertedElementType = nullptr;
356354
// F8 types are converted to integer types with the same bit width.
357355
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,10 +563,16 @@ func.func @constant() {
563563
func.func @constant_8bit_float() {
564564
// CHECK: spirv.Constant 56 : i8
565565
%cst = arith.constant 1.0 : f8E4M3
566+
// CHECK: spirv.Constant 56 : i8
567+
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
566568
// CHECK: spirv.Constant dense<56> : vector<4xi8>
567569
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
570+
// CHECK: spirv.Constant dense<56> : vector<4xi8>
571+
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
568572
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
569573
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
574+
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
575+
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
570576
return
571577
}
572578

0 commit comments

Comments
 (0)