|
11 | 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
12 | 12 | #include "mlir/Dialect/Complex/IR/Complex.h"
|
13 | 13 | #include "mlir/Dialect/Tensor/IR/Tensor.h"
|
| 14 | +#include "mlir/Dialect/UB/IR/UBOps.h" |
14 | 15 | #include "mlir/Dialect/Utils/StaticValueUtils.h"
|
15 | 16 | #include "mlir/IR/AffineExpr.h"
|
16 | 17 | #include "mlir/IR/BuiltinAttributes.h"
|
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
|
230 | 231 | Value paddingValue;
|
231 | 232 | if (auto complexTy =
|
232 | 233 | dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
|
233 |
| - auto complexAttr = cast<ArrayAttr>(paddingValueAttr); |
234 |
| - paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), |
235 |
| - complexTy, complexAttr); |
236 |
| - } else { |
237 |
| - paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), |
238 |
| - cast<TypedAttr>(paddingValueAttr)); |
| 234 | + if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) { |
| 235 | + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), |
| 236 | + complexTy, complexAttr); |
| 237 | + } |
| 238 | + } else if (isa<ub::PoisonAttr>(paddingValueAttr)) { |
| 239 | + paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(), |
| 240 | + getElementTypeOrSelf(v.getType())); |
| 241 | + } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) { |
| 242 | + paddingValue = |
| 243 | + arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr); |
239 | 244 | }
|
| 245 | + assert(paddingValue && "failed to create value from padding attribute"); |
240 | 246 |
|
241 | 247 | // Pad the operand to the bounding box defined by `paddedShape`.
|
242 | 248 | SmallVector<int64_t> tensorShape;
|
|
0 commit comments