Skip to content

Commit 12dcb89

Browse files
poechselftynse
authored andcommitted
[mlir] [linalg] Only promote selected buffers.
The promotion transformation is promoting all input and output buffers of the transformed op. The user might want to only promote some of these buffers. Differential Revision: https://reviews.llvm.org/D78498
1 parent d881626 commit 12dcb89

File tree

5 files changed

+97
-3
lines changed

5 files changed

+97
-3
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,9 @@ def PreconditionPromoteSubviewsLinalgOp : CPred<
114114
"succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
115115
def PromoteSubviewsLinalgOp : NativeCodeCall<
116116
"promoteSubviewsLinalgOp($_builder, op)">;
117+
118+
class PromoteSelectedSubviewsLinalgOp<list<int> operands, string marker=""> :
119+
NativeCodeCall<"promoteSelectedSubviewsLinalgOpAndSetMarker($_builder, op, {" #
120+
StrJoinInt<operands>.result # "}, \"" # marker # "\")">;
121+
117122
#endif // LINALG_TRANSFORMS

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
121121
SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
122122
Operation *op);
123123

124+
/// Similar to `promoteSubviewsLinalgOp` but only tries to promote
125+
/// the views corresponding to the operands specified in
126+
/// `operandIndicesToPromote`.
127+
/// If linalgMarker is specified and the transformation is successfull
128+
/// sets the attribute `kLinalgTransformMarker` to `linalgMarker`.
129+
SmallVector<Value, 0> promoteSelectedSubviewsLinalgOpAndSetMarker(
130+
PatternRewriter &rewriter, Operation *op,
131+
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker = "");
124132
} // namespace linalg
125133
} // namespace mlir
126134

mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,24 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
338338
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
339339
"DRR failure case must be a precondition");
340340

341+
LinalgOp linOp = cast<LinalgOp>(op);
342+
SmallVector<int64_t, 4> toPromote;
343+
int64_t nBuffers = linOp.getNumInputsAndOutputBuffers();
344+
toPromote.reserve(nBuffers);
345+
for (int64_t i = 0; i < nBuffers; ++i)
346+
toPromote.push_back(i);
347+
return promoteSelectedSubviewsLinalgOpAndSetMarker(rewriter, op, toPromote);
348+
}
349+
350+
SmallVector<Value, 0> mlir::linalg::promoteSelectedSubviewsLinalgOpAndSetMarker(
351+
PatternRewriter &rewriter, Operation *op,
352+
ArrayRef<int64_t> operandIndicesToPromote, StringRef linalgMarker) {
353+
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
354+
<< *op << ":\n");
355+
356+
assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
357+
"DRR failure case must be a precondition");
358+
341359
if (auto convOp = dyn_cast<linalg::ConvOp>(op)) {
342360
// TODO(ntv): add a level of indirection to linalg.generic.
343361
if (convOp.padding())
@@ -348,11 +366,16 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
348366
assert(linOp.hasBufferSemantics() &&
349367
"expected linalg op with buffer semantics");
350368
SetVector<Value> subViews;
351-
for (auto it : linOp.getInputsAndOutputBuffers())
352-
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
369+
for (int64_t index : operandIndicesToPromote)
370+
if (auto sv =
371+
dyn_cast_or_null<SubViewOp>(linOp.getBuffer(index).getDefiningOp()))
353372
subViews.insert(sv);
373+
354374
if (!subViews.empty()) {
355-
promoteSubViewOperands(rewriter, linOp, subViews);
375+
auto newOp = promoteSubViewOperands(rewriter, linOp, subViews);
376+
if (!linalgMarker.empty())
377+
newOp.setAttr(LinalgTransforms::kLinalgTransformMarker,
378+
rewriter.getStringAttr(linalgMarker));
356379
return {};
357380
}
358381
llvm_unreachable("DRR failure case must be a precondition");

mlir/test/Dialect/Linalg/transform-patterns.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,53 @@ func @promote_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
395395
// CHECK : linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
396396
// CHECK : linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
397397
// CHECK : linalg.matmul(%[[v0]], %[[v1]], %[[v2]]) : memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>, memref<?x?xf32, #[[STRIDED_2D]]>
398+
399+
func @promote_first_subview_matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
400+
%arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
401+
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
402+
%c2000 = constant 2000 : index
403+
%c3000 = constant 3000 : index
404+
%c4000 = constant 4000 : index
405+
%c0 = constant 0 : index
406+
%c1 = constant 1 : index
407+
%0 = dim %arg0, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
408+
%1 = dim %arg0, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
409+
%2 = dim %arg1, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
410+
loop.for %arg3 = %c0 to %0 step %c2000 {
411+
loop.for %arg4 = %c0 to %2 step %c3000 {
412+
loop.for %arg5 = %c0 to %1 step %c4000 {
413+
%3 = std.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] :
414+
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
415+
%4 = std.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] :
416+
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
417+
%5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
418+
memref<?x?xf32, offset: ?, strides: [?, 1]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
419+
linalg.matmul(%3, %4, %5) {__internal_linalg_transform__ = "_promote_first_view_"} :
420+
memref<?x?xf32, offset: ?, strides: [?, ?]>,
421+
memref<?x?xf32, offset: ?, strides: [?, ?]>,
422+
memref<?x?xf32, offset: ?, strides: [?, ?]>
423+
}
424+
}
425+
}
426+
return
427+
}
428+
// CHECK-LABEL: func @promote_first_subview_matmul
429+
// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c2000 {
430+
// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c3000 {
431+
// CHECK: loop.for {{.*}} = %c0 to {{.*}} step %c4000 {
432+
// CHECK: %[[s0:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
433+
// CHECK: %[[s1:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
434+
// CHECK: %[[s2:.*]] = subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
435+
// CHECK: %[[a0:.*]] = alloc({{%.*}}) : memref<?xi8>
436+
// CHECK: %[[v0:.*]] = std.view %[[a0]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
437+
// CHECK: %[[l0:.*]] = subview %[[v0]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map:.*]]>
438+
// CHECK-NOT: %[[a1:.*]] = alloc({{%.*}}) : memref<?xi8>
439+
// CHECK-NOT: %[[v1:.*]] = std.view %[[a1]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
440+
// CHECK-NOT: %[[l0:.*]] = subview %[[v1]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
441+
// CHECK-NOT: %[[a2:.*]] = alloc({{%.*}}) : memref<?xi8>
442+
// CHECK-NOT: %[[v2:.*]] = std.view %[[a2]][][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
443+
// CHECK-NOT: %[[l0:.*]] = subview %[[v2]][{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32> to memref<?x?xf32, #[[map]]>
444+
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
445+
// CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>
446+
// CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref<?x?xf32, #map{{.*}}>, memref<?x?xf32, #map{{.*}}>^
447+
// CHECK: linalg.matmul(%[[v0]], %[[s1]], %[[s2]]) : memref<?x?xf32>, memref<?x?xf32, #[[map]]>, memref<?x?xf32, #[[map]]>

mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,12 @@ def : Pat<(MatmulOp:$op $_, $_, $_),
149149
HasLinalgTransformMarker<"_promote_views_">]>>
150150
)]>;
151151

152+
def : Pat<(MatmulOp:$op $_, $_, $_),
153+
(PromoteSelectedSubviewsLinalgOp<[0], "first_view_promotion">),
154+
[(Constraint<And<[
155+
PreconditionPromoteSubviewsLinalgOp,
156+
HasOperandsOfType<"SubViewOp">,
157+
HasLinalgTransformMarker<"_promote_first_view_">]>>
158+
)]>;
159+
152160
#endif // TEST_LINALG_TRANSFORMS_PATTERNS

0 commit comments

Comments
 (0)