Commit 1555d7f

Tobias Gysi authored and ftynse committed
[mlir] subview op lowering for target memrefs with const offset
The current Standard-to-LLVM conversion pass lowers subview ops only if dynamic offsets are provided. This commit adds a code path that, when no dynamic offsets are provided, uses the constant offset of the target memref type for the lowering instead (see Example 3 of the subview op definition).

Differential Revision: https://reviews.llvm.org/D74280

1 parent ed3527c · commit 1555d7f
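
For context, the kind of subview this new code path targets looks roughly like the sketch below, modeled on the test added by this commit (the function name @subview_static_offset_example is illustrative, not part of the patch): no dynamic offset, size, or stride operands are passed, and the constant offset (here 8) is encoded in the affine layout of the result memref type.

func @subview_static_offset_example(%arg0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
  // No dynamic operands; the constant offset 8 lives in the result type's layout map.
  %1 = subview %arg0[][][] :
    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
  return
}

With this change, the lowering reads that static offset from the target type and writes it into the result memref descriptor via setConstantOffset, rather than failing for lack of dynamic offset operands.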

2 files changed: 43 additions & 8 deletions


mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 18 additions & 8 deletions
@@ -2304,7 +2304,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     // Currently, only rank > 0 and full or no operands are supported. Fail to
     // convert otherwise.
     unsigned rank = sourceMemRefType.getRank();
-    if (viewMemRefType.getRank() == 0 || (rank != dynamicOffsets.size()) ||
+    if (viewMemRefType.getRank() == 0 ||
+        (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) ||
         (!dynamicSizes.empty() && rank != dynamicSizes.size()) ||
         (!dynamicStrides.empty() && rank != dynamicStrides.size()))
       return matchFailure();
@@ -2315,6 +2316,11 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     if (failed(successStrides))
       return matchFailure();
 
+    // Fail to convert if neither a dynamic nor static offset is available.
+    if (dynamicOffsets.empty() &&
+        offset == MemRefType::getDynamicStrideOrOffset())
+      return matchFailure();
+
     // Create the descriptor.
     MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -2348,14 +2354,18 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
     }
 
     // Offset.
-    Value baseOffset = sourceMemRef.offset(rewriter, loc);
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
-      Value min = dynamicOffsets[i];
-      baseOffset = rewriter.create<LLVM::AddOp>(
-          loc, baseOffset,
-          rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+    if (dynamicOffsets.empty()) {
+      targetMemRef.setConstantOffset(rewriter, loc, offset);
+    } else {
+      Value baseOffset = sourceMemRef.offset(rewriter, loc);
+      for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+        Value min = dynamicOffsets[i];
+        baseOffset = rewriter.create<LLVM::AddOp>(
+            loc, baseOffset,
+            rewriter.create<LLVM::MulOp>(loc, min, strideValues[i]));
+      }
+      targetMemRef.setOffset(rewriter, loc, baseOffset);
     }
-    targetMemRef.setOffset(rewriter, loc, baseOffset);
 
     // Update sizes and strides.
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 25 additions & 0 deletions
@@ -815,6 +815,31 @@ func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4
   return
 }
 
+// CHECK-LABEL: func @subview_const_stride_and_offset(
+func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
+  // The last "insertvalue" that populates the memref descriptor from the function arguments.
+  // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+
+  // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST62:.*]] = llvm.mlir.constant(62 : i64)
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i64)
+  // CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : index)
+  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i64)
+  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  %1 = subview %0[][][] :
+    memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
+  return
+}
+
 // -----
 
 module {
