Skip to content

Commit 48ea3e4

Browse files
authored
[mlir][spirv] Enable (de)serialization of TensorARM to/from OpConstan… (#151485)
…tNull This patch enables (de)serialization to/from OpConstantNull for null TensorARM --------- Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
1 parent a361cde commit 48ea3e4

File tree

3 files changed

+69
-7
lines changed

3 files changed

+69
-7
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,7 @@ LogicalResult
17791779
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
17801780
if (operands.size() != 2) {
17811781
return emitError(unknownLoc,
1782-
"OpConstantNull must have type <id> and result <id>");
1782+
"OpConstantNull must only have type <id> and result <id>");
17831783
}
17841784

17851785
Type resultType = getType(operands[0]);
@@ -1789,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
17891789
}
17901790

17911791
auto resultID = operands[1];
1792+
Attribute attr;
17921793
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
1793-
auto attr = opBuilder.getZeroAttr(resultType);
1794+
attr = opBuilder.getZeroAttr(resultType);
1795+
} else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1796+
if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
1797+
attr = DenseElementsAttr::get(tensorType, element);
1798+
}
1799+
1800+
if (attr) {
17941801
// For normal constants, we just record the attribute (and its type) for
17951802
// later materialization at use sites.
17961803
constantMap.try_emplace(resultID, attr, resultType);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) {
6969
return block;
7070
}
7171

72+
static bool isZeroValue(Attribute attr) {
73+
if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
74+
return floatAttr.getValue().isZero();
75+
}
76+
if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
77+
return !boolAttr.getValue();
78+
}
79+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
80+
return intAttr.getValue().isZero();
81+
}
82+
if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
83+
return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
84+
}
85+
if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
86+
return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
87+
}
88+
return false;
89+
}
90+
7291
namespace mlir {
7392
namespace spirv {
7493

@@ -959,6 +978,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
959978
return 0;
960979
}
961980
} else if (isa<spirv::TensorArmType>(constType)) {
981+
if (isZeroValue(valueAttr)) {
982+
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
983+
{typeID, resultID});
984+
return resultID;
985+
}
962986
numberOfConstituents = shapedType.getNumElements();
963987
operands.reserve(numberOfConstituents + 2);
964988
for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1202,11 +1226,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
12021226
}
12031227

12041228
uint32_t resultID = getNextID();
1205-
uint32_t operands[] = {typeID, resultID, constandID};
1206-
1207-
encodeInstructionInto(typesGlobalValues,
1208-
spirv::Opcode::OpConstantCompositeReplicateEXT,
1209-
operands);
1229+
if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1230+
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1231+
{typeID, resultID});
1232+
} else {
1233+
encodeInstructionInto(typesGlobalValues,
1234+
spirv::Opcode::OpConstantCompositeReplicateEXT,
1235+
{typeID, resultID, constandID});
1236+
}
12101237

12111238
constCompositeReplicateIDMap[valueTypePair] = resultID;
12121239
return resultID;

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader
335335
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
336336
}
337337

338+
// CHECK-LABEL: @null_arm_tensor_of_i32
339+
spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
340+
// CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
341+
%0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
342+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
343+
}
344+
345+
// CHECK-LABEL: @null_arm_tensor_of_f32
346+
spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
347+
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
348+
%0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32>
349+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
350+
}
351+
338352
spirv.EntryPoint "GLCompute" @bool_const
339353
}
340354

@@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
391405
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
392406
}
393407

408+
// CHECK-LABEL: @null_cc_arm_tensor_of_i32
409+
spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
410+
// CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32>
411+
%0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32>
412+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
413+
}
414+
394415
// CHECK-LABEL: @splat_vector_f32
395416
spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" {
396417
// CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32>
@@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, ReplicatedCompos
439460
%0 = spirv.EXT.ConstantCompositeReplicate [2.0 : f32] : !spirv.arm.tensor<2x3xf32>
440461
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
441462
}
463+
464+
// CHECK-LABEL: @null_cc_arm_tensor_of_f32
465+
spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
466+
// CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32>
467+
%0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32>
468+
spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
469+
}
442470
}

0 commit comments

Comments
 (0)