diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h index b5a12426aff80..113765157946d 100644 --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -15,10 +15,16 @@ #ifndef MLIR_DIALECT_COMMONFOLDERS_H #define MLIR_DIALECT_COMMONFOLDERS_H +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Types.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" + +#include +#include #include namespace mlir { @@ -30,11 +36,13 @@ class PoisonAttr; /// Uses `resultType` for the type of the returned attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT, ElementValueT)>> + std::optional(ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, CalculationT &&calculate) { @@ -65,7 +73,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, if (!calRes) return {}; - return AttrElementT::get(resultType, *calRes); + return ResultAttrElementT::get(resultType, *calRes); } if (isa(operands[0]) && @@ -99,7 +107,7 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return {}; auto lhsIt = *maybeLhsIt; auto rhsIt = *maybeRhsIt; - SmallVector elementResults; + SmallVector elementResults; elementResults.reserve(lhs.getNumElements()); for (size_t i = 0, e = lhs.getNumElements(); i < e; ++i, ++lhsIt, ++rhsIt) { auto elementResult = calculate(*lhsIt, *rhsIt); @@ -119,11 +127,13 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, /// attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT, ElementValueT)>> + std::optional(ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); @@ -139,64 +149,73 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, return operands[1]; } - auto getResultType = [](Attribute attr) -> Type { + auto getAttrType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); return {}; }; - Type lhsType = getResultType(operands[0]); - Type rhsType = getResultType(operands[1]); + Type lhsType = getAttrType(operands[0]); + Type rhsType = getAttrType(operands[1]); if (!lhsType || !rhsType) return {}; if (lhsType != rhsType) return {}; return constFoldBinaryOpConditional( operands, lhsType, std::forward(calculate)); } template > + function_ref> Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, CalculationT &&calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, resultType, - [&](ElementValueT a, ElementValueT b) -> std::optional { - return calculate(a, b); - }); + [&](ElementValueT a, ElementValueT b) + -> std::optional { return calculate(a, b); }); } -template > + function_ref> Attribute constFoldBinaryOp(ArrayRef operands, CalculationT &&calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, - [&](ElementValueT a, ElementValueT b) -> std::optional { - return calculate(a, b); - }); + [&](ElementValueT a, ElementValueT b) + -> std::optional { return calculate(a, b); }); } /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. +/// Uses `resultType` for the type of the returned attribute. /// Optional PoisonAttr template argument allows to specify 'poison' attribute /// which will be directly propagated to result. -template (ElementValueT)>> + function_ref(ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, + Type resultType, CalculationT &&calculate) { - if (!llvm::getSingleElement(operands)) + if (!resultType || !llvm::getSingleElement(operands)) return {}; static_assert( @@ -214,7 +233,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, auto res = calculate(op.getValue()); if (!res) return {}; - return AttrElementT::get(op.getType(), *res); + return ResultAttrElementT::get(resultType, *res); } if (isa(operands[0])) { // Both operands are splats so we can avoid expanding the values out and @@ -224,7 +243,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, auto elementResult = calculate(op.getSplatValue()); if (!elementResult) return {}; - return DenseElementsAttr::get(op.getType(), *elementResult); + return DenseElementsAttr::get(cast(resultType), *elementResult); } else if (isa(operands[0])) { // Operands are ElementsAttr-derived; perform an element-wise fold by // expanding the values. @@ -234,7 +253,7 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, if (!maybeOpIt) return {}; auto opIt = *maybeOpIt; - SmallVector elementResults; + SmallVector elementResults; elementResults.reserve(op.getNumElements()); for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { auto elementResult = calculate(*opIt); @@ -242,19 +261,81 @@ Attribute constFoldUnaryOpConditional(ArrayRef operands, return {}; elementResults.push_back(*elementResult); } - return DenseElementsAttr::get(op.getShapedType(), elementResults); + return DenseElementsAttr::get(cast(resultType), elementResults); } return {}; } -template (ElementValueT)>> +Attribute constFoldUnaryOpConditional(ArrayRef operands, + CalculationT &&calculate) { + if (!llvm::getSingleElement(operands)) + return {}; + + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + + auto getAttrType = [](Attribute attr) -> Type { + if (auto typed = dyn_cast_or_null(attr)) + return typed.getType(); + return {}; + }; + + Type operandType = getAttrType(operands[0]); + if (!operandType) + return {}; + + return constFoldUnaryOpConditional( + operands, operandType, std::forward(calculate)); +} + +template > +Attribute constFoldUnaryOp(ArrayRef operands, Type resultType, + CalculationT &&calculate) { + return constFoldUnaryOpConditional( + operands, resultType, + [&](ElementValueT a) -> std::optional { + return calculate(a); + }); +} + +template > + class ResultAttrElementT = AttrElementT, + class ResultElementValueT = typename ResultAttrElementT::ValueType, + class CalculationT = function_ref> Attribute constFoldUnaryOp(ArrayRef operands, CalculationT &&calculate) { - return constFoldUnaryOpConditional( - operands, [&](ElementValueT a) -> std::optional { + return constFoldUnaryOpConditional( + operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); } diff --git a/mlir/test/Dialect/common_folders.mlir b/mlir/test/Dialect/common_folders.mlir new file mode 100644 index 0000000000000..92598b4937552 --- /dev/null +++ b/mlir/test/Dialect/common_folders.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s --test-fold-type-converting-op --split-input-file | FileCheck %s + +// CHECK-LABEL: @test_fold_unary_op_f32_to_si32( +func.func @test_fold_unary_op_f32_to_si32() -> tensor<4x2xsi32> { + // CHECK-NEXT: %[[POSITIVE_ONE:.*]] = arith.constant dense<1> : tensor<4x2xsi32> + // CHECK-NEXT: return %[[POSITIVE_ONE]] : tensor<4x2xsi32> + %operand = arith.constant dense<5.1> : tensor<4x2xf32> + %sign = test.sign %operand : (tensor<4x2xf32>) -> tensor<4x2xsi32> + return %sign : tensor<4x2xsi32> +} + +// ----- + +// CHECK-LABEL: @test_fold_binary_op_f32_to_i1( +func.func @test_fold_binary_op_f32_to_i1() -> tensor<8xi1> { + // CHECK-NEXT: %[[FALSE:.*]] = arith.constant dense : tensor<8xi1> + // CHECK-NEXT: return %[[FALSE]] : tensor<8xi1> + %lhs = arith.constant dense<5.1> : tensor<8xf32> + %rhs = arith.constant dense<4.2> : tensor<8xf32> + %less_than = test.less_than %lhs, %rhs : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xi1> + return %less_than : tensor<8xi1> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2eaad552a7a3a..f3554e8579d29 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1169,6 +1169,26 @@ def OpP : TEST_Op<"op_p"> { let results = (outs I32); } +// Test constant-folding a pattern that maps `(F32) -> SI32`. +def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> { + let arguments = (ins RankedTensorOf<[F32]>:$operand); + let results = (outs RankedTensorOf<[SI32]>:$result); + + let assemblyFormat = [{ + $operand attr-dict `:` functional-type(operands, results) + }]; +} + +// Test constant-folding a pattern that maps `(F32, F32) -> I1`. +def LessThanOp : TEST_Op<"less_than", [SameOperandsAndResultShape]> { + let arguments = (ins RankedTensorOf<[F32]>:$lhs, RankedTensorOf<[F32]>:$rhs); + let results = (outs RankedTensorOf<[I1]>:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` functional-type(operands, results) + }]; +} + // Test same operand name enforces equality condition check. def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index eda618f5b09c6..1ad1b6163374d 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -10,6 +10,7 @@ #include "TestOps.h" #include "TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -202,6 +203,66 @@ struct HoistEligibleOps : public OpRewritePattern { } }; +struct FoldSignOpF32ToSI32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(test::SignOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return failure(); + + TypedAttr operandAttr; + matchPattern(op->getOperand(0), m_Constant(&operandAttr)); + if (!operandAttr) + return failure(); + + TypedAttr res = cast_or_null( + constFoldUnaryOp( + operandAttr, op.getType(), [](APFloat operand) -> APSInt { + static const APFloat zero(0.0f); + int operandSign = 0; + if (operand != zero) + operandSign = (operand < zero) ? -1 : +1; + return APSInt(APInt(32, operandSign), false); + })); + if (!res) + return failure(); + + rewriter.replaceOpWithNewOp(op, res); + return success(); + } +}; + +struct FoldLessThanOpF32ToI1 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(test::LessThanOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 2 || op->getNumResults() != 1) + return failure(); + + TypedAttr lhsAttr; + TypedAttr rhsAttr; + matchPattern(op->getOperand(0), m_Constant(&lhsAttr)); + matchPattern(op->getOperand(1), m_Constant(&rhsAttr)); + + if (!lhsAttr || !rhsAttr) + return failure(); + + Attribute operandAttrs[2] = {lhsAttr, rhsAttr}; + TypedAttr res = cast_or_null( + constFoldBinaryOp( + operandAttrs, op.getType(), [](APFloat lhs, APFloat rhs) -> APInt { + return APInt(1, lhs < rhs); + })); + if (!res) + return failure(); + + rewriter.replaceOpWithNewOp(op, res); + return success(); + } +}; + /// This pattern moves "test.move_before_parent_op" before the parent op. struct MoveBeforeParentOp : public RewritePattern { MoveBeforeParentOp(MLIRContext *context) @@ -2181,6 +2242,24 @@ struct TestSelectiveReplacementPatternDriver (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; + +struct TestFoldTypeConvertingOp + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldTypeConvertingOp) + + StringRef getArgument() const final { return "test-fold-type-converting-op"; } + StringRef getDescription() const final { + return "Test helper functions for folding ops whose input and output types " + "differ, e.g. float comparisons of the form `(f32, f32) -> i1`."; + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; } // namespace //===----------------------------------------------------------------------===// @@ -2211,6 +2290,8 @@ void registerPatternsTestPass() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir