diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 88931b53a6889..d0ae5132252ff 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1779,7 +1779,7 @@ LogicalResult spirv::Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, - "OpConstantNull must have type and result "); + "OpConstantNull must only have type and result "); } Type resultType = getType(operands[0]); @@ -1789,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef operands) { } auto resultID = operands[1]; + Attribute attr; if (resultType.isIntOrFloat() || isa(resultType)) { - auto attr = opBuilder.getZeroAttr(resultType); + attr = opBuilder.getZeroAttr(resultType); + } else if (auto tensorType = dyn_cast(resultType)) { + if (auto element = opBuilder.getZeroAttr(tensorType.getElementType())) + attr = DenseElementsAttr::get(tensorType, element); + } + + if (attr) { // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. constantMap.try_emplace(resultID, attr, resultType); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 737f29662f64b..59665ec1add54 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) { return block; } +static bool isZeroValue(Attribute attr) { + if (auto floatAttr = dyn_cast(attr)) { + return floatAttr.getValue().isZero(); + } + if (auto boolAttr = dyn_cast(attr)) { + return !boolAttr.getValue(); + } + if (auto intAttr = dyn_cast(attr)) { + return intAttr.getValue().isZero(); + } + if (auto splatElemAttr = dyn_cast(attr)) { + return isZeroValue(splatElemAttr.getSplatValue()); + } + if (auto denseElemAttr = dyn_cast(attr)) { + return all_of(denseElemAttr.getValues(), isZeroValue); + } + return false; +} + namespace mlir { namespace spirv { @@ -959,6 +978,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, return 0; } } else if (isa(constType)) { + if (isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + return resultID; + } numberOfConstituents = shapedType.getNumElements(); operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { @@ -1202,11 +1226,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc, } uint32_t resultID = getNextID(); - uint32_t operands[] = {typeID, resultID, constandID}; - - encodeInstructionInto(typesGlobalValues, - spirv::Opcode::OpConstantCompositeReplicateEXT, - operands); + if (dyn_cast(resultType) && isZeroValue(valueAttr)) { + encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull, + {typeID, resultID}); + } else { + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpConstantCompositeReplicateEXT, + {typeID, resultID, constandID}); + } constCompositeReplicateIDMap[valueTypePair] = resultID; return resultID; diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index 1695d2a6a2eb4..3be49eefcaebf 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -335,6 +335,20 @@ spirv.module Logical Vulkan requires #spirv.vce } + // CHECK-LABEL: @null_arm_tensor_of_i32 + spirv.func @null_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + + // CHECK-LABEL: @null_arm_tensor_of_f32 + spirv.func @null_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.Constant dense<0.0> : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } + spirv.EntryPoint "GLCompute" @bool_const } @@ -391,6 +405,13 @@ spirv.module Logical GLSL450 requires #spirv.vce } + // CHECK-LABEL: @null_cc_arm_tensor_of_i32 + spirv.func @null_cc_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" { + // CHECK: spirv.Constant dense<0> : !spirv.arm.tensor<2x3xi32> + %0 = spirv.EXT.ConstantCompositeReplicate [0 : i32] : !spirv.arm.tensor<2x3xi32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32> + } + // CHECK-LABEL: @splat_vector_f32 spirv.func @splat_vector_f32() -> (vector<3xf32>) "None" { // CHECK: spirv.EXT.ConstantCompositeReplicate [1.000000e+00 : f32] : vector<3xf32> @@ -439,4 +460,11 @@ spirv.module Logical GLSL450 requires #spirv.vce spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> } + + // CHECK-LABEL: @null_cc_arm_tensor_of_f32 + spirv.func @null_cc_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" { + // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.arm.tensor<2x3xf32> + %0 = spirv.EXT.ConstantCompositeReplicate [0.0 : f32] : !spirv.arm.tensor<2x3xf32> + spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32> + } }