Skip to content

Commit 009bc91

Browse files
committed
ability to use poison as padding value
1 parent 70af09e commit 009bc91

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1212
#include "mlir/Dialect/Complex/IR/Complex.h"
1313
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1516
#include "mlir/IR/AffineExpr.h"
1617
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
230231
Value paddingValue;
231232
if (auto complexTy =
232233
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
233-
auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
234-
paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
235-
complexTy, complexAttr);
236-
} else {
237-
paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
238-
cast<TypedAttr>(paddingValueAttr));
234+
if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
235+
paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
236+
complexTy, complexAttr);
237+
}
238+
} else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
239+
paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
240+
getElementTypeOrSelf(v.getType()));
241+
} else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
242+
paddingValue =
243+
arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
239244
}
245+
assert(paddingValue && "failed to create value from padding attribute");
240246

241247
// Pad the operand to the bounding box defined by `paddedShape`.
242248
SmallVector<int64_t> tensorShape;

mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
55
func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
66
{
7+
// %goo = ub.poison : f32
78
%0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
89
func.return %0 : tensor<24x25xf32>
910
}
@@ -18,7 +19,8 @@ module attributes {transform.with_named_sequence} {
1819
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
1920

2021
%fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
21-
padding_values=[0.0 : f32, 0.0 : f32]
22+
// padding_values= [poison, 0.0 : f32]
23+
padding_values= [0.0 : f32, 0.0 : f32]
2224
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
2325

2426
transform.yield

0 commit comments

Comments
 (0)