From 009bc910a2851b2c02281d3fe9d16994ecdb4ec2 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 4 Aug 2025 09:25:18 -0700 Subject: [PATCH] ability to use poison as padding value --- .../Linalg/Transforms/PadTilingInterface.cpp | 18 ++++++++++++------ .../transform-op-pad-tiling-interface.mlir | 4 +++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2e6252336dfeb..3d12bc397813b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" @@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, Value paddingValue; if (auto complexTy = dyn_cast(getElementTypeOrSelf(v.getType()))) { - auto complexAttr = cast(paddingValueAttr); - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), - complexTy, complexAttr); - } else { - paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), - cast(paddingValueAttr)); + if (auto complexAttr = dyn_cast(paddingValueAttr)) { + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + complexTy, complexAttr); + } + } else if (isa(paddingValueAttr)) { + paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), + getElementTypeOrSelf(v.getType())); + } else if (auto typedAttr = dyn_cast(paddingValueAttr)) { + paddingValue = + arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); } + assert(paddingValue && "failed to create value from padding attribute"); // Pad the operand to the bounding box defined by `paddedShape`. SmallVector tensorShape; diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir index f7418769f79ca..2857b53103779 100644 --- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir @@ -4,6 +4,7 @@ // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32> func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32> { + // %goo = ub.poison : f32 %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32> func.return %0 : tensor<24x25xf32> } @@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] { - padding_values=[0.0 : f32, 0.0 : f32] + // padding_values= [poison, 0.0 : f32] + padding_values= [0.0 : f32, 0.0 : f32] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield