Skip to content

Commit d314b7d

Browse files
committed
[MLIR] ShapedType accessor minor fixes + add isDynamicDim accessor
Minor fixes and cleanup for ShapedType accessors, use ShapedType::kDynamicSize, add ShapedType::isDynamicDim. Differential Revision: https://reviews.llvm.org/D77710
1 parent e7db1ae commit d314b7d

File tree

6 files changed

+29
-21
lines changed

6 files changed

+29
-21
lines changed

mlir/include/mlir/IR/StandardTypes.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,11 @@ class ShapedType : public Type {
252252

253253
/// If this is ranked type, return the size of the specified dimension.
254254
/// Otherwise, abort.
255-
int64_t getDimSize(int64_t i) const;
255+
int64_t getDimSize(unsigned idx) const;
256+
257+
/// Returns true if this dimension has a dynamic size (for ranked types);
258+
/// aborts for unranked types.
259+
bool isDynamicDim(unsigned idx) const;
256260

257261
/// Returns the position of the dynamic dimension relative to just the dynamic
258262
/// dimensions, given its `index` within the shape.
@@ -276,7 +280,9 @@ class ShapedType : public Type {
276280
}
277281

278282
/// Whether the given dimension size indicates a dynamic dimension.
279-
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
283+
static constexpr bool isDynamic(int64_t dSize) {
284+
return dSize == kDynamicSize;
285+
}
280286
static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
281287
return dStrideOrOffset == kDynamicStrideOrOffset;
282288
}

mlir/lib/Analysis/Utils.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,10 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
330330
if (addMemRefDimBounds) {
331331
auto memRefType = memref.getType().cast<MemRefType>();
332332
for (unsigned r = 0; r < rank; r++) {
333-
cst.addConstantLowerBound(r, 0);
334-
int64_t dimSize = memRefType.getDimSize(r);
335-
if (ShapedType::isDynamic(dimSize))
333+
cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
334+
if (memRefType.isDynamicDim(r))
336335
continue;
337-
cst.addConstantUpperBound(r, dimSize - 1);
336+
cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
338337
}
339338
}
340339
cst.removeTrivialRedundancy();

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,16 +1888,15 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
18881888
OperandAdaptor<DimOp> transformed(operands);
18891889
MemRefType type = dimOp.getOperand().getType().cast<MemRefType>();
18901890

1891-
auto shape = type.getShape();
18921891
int64_t index = dimOp.getIndex();
18931892
// Extract dynamic size from the memref descriptor.
1894-
if (ShapedType::isDynamic(shape[index]))
1893+
if (type.isDynamicDim(index))
18951894
rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor())
18961895
.size(rewriter, op->getLoc(), index)});
18971896
else
18981897
// Use constant for static size.
1899-
rewriter.replaceOp(
1900-
op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
1898+
rewriter.replaceOp(op, createIndexConstant(rewriter, op->getLoc(),
1899+
type.getDimSize(index)));
19011900
return success();
19021901
}
19031902
};

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,
133133
unsigned index) {
134134
auto memRefType = memrefDefOp.getType();
135135
// Statically shaped.
136-
if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
136+
if (!memRefType.isDynamicDim(index))
137137
return true;
138138
// Get the position of the dimension among dynamic dimensions;
139139
unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,14 +1068,14 @@ static LogicalResult verify(DimOp op) {
10681068
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
10691069
// Constant fold dim when the size along the index referred to is a constant.
10701070
auto opType = memrefOrTensor().getType();
1071-
int64_t indexSize = ShapedType::kDynamicSize;
1071+
int64_t dimSize = ShapedType::kDynamicSize;
10721072
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
1073-
indexSize = tensorType.getShape()[getIndex()];
1073+
dimSize = tensorType.getShape()[getIndex()];
10741074
else if (auto memrefType = opType.dyn_cast<MemRefType>())
1075-
indexSize = memrefType.getShape()[getIndex()];
1075+
dimSize = memrefType.getShape()[getIndex()];
10761076

1077-
if (!ShapedType::isDynamic(indexSize))
1078-
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
1077+
if (!ShapedType::isDynamic(dimSize))
1078+
return IntegerAttr::get(IndexType::get(getContext()), dimSize);
10791079

10801080
// Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
10811081
auto memrefType = opType.dyn_cast<MemRefType>();
@@ -2310,13 +2310,12 @@ Value ViewOp::getDynamicOffset() {
23102310

23112311
static LogicalResult verifyDynamicStrides(MemRefType memrefType,
23122312
ArrayRef<int64_t> strides) {
2313-
ArrayRef<int64_t> shape = memrefType.getShape();
23142313
unsigned rank = memrefType.getRank();
23152314
assert(rank == strides.size());
23162315
bool dynamicStrides = false;
23172316
for (int i = rank - 2; i >= 0; --i) {
23182317
// If size at dim 'i + 1' is dynamic, set the 'dynamicStrides' flag.
2319-
if (ShapedType::isDynamic(shape[i + 1]))
2318+
if (memrefType.isDynamicDim(i + 1))
23202319
dynamicStrides = true;
23212320
// If stride at dim 'i' is not dynamic, return error.
23222321
if (dynamicStrides && strides[i] != MemRefType::getDynamicStrideOrOffset())

mlir/lib/IR/StandardTypes.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,14 @@ int64_t ShapedType::getRank() const { return getShape().size(); }
184184

185185
bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
186186

187-
int64_t ShapedType::getDimSize(int64_t i) const {
188-
assert(i >= 0 && i < getRank() && "invalid index for shaped type");
189-
return getShape()[i];
187+
int64_t ShapedType::getDimSize(unsigned idx) const {
188+
assert(idx < getRank() && "invalid index for shaped type");
189+
return getShape()[idx];
190+
}
191+
192+
bool ShapedType::isDynamicDim(unsigned idx) const {
193+
assert(idx < getRank() && "invalid index for shaped type");
194+
return isDynamic(getShape()[idx]);
190195
}
191196

192197
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {

0 commit comments

Comments
 (0)