Commit 2140a97

[mlir][Linalg] Extend generic ops to allow tensors
Summary:
This diff adds support to allow `linalg.generic` and `linalg.indexed_generic` to take tensor input and output arguments. The subset of output tensor operand types must appear verbatim in the result types after an arrow. The parser, printer and verifier are extended to accommodate this behavior. The Linalg operations now support variadic ranked tensor return values.

This extension exposed issues with the current handling of NativeCall in RewriterGen.cpp. As a consequence, an explicit cast to `SmallVector<Value, 4>` is added in the proper place to support the new behavior (better suggestions are welcome).

Relevant cleanups and name uniformization are applied.

Relevant invalid and roundtrip tests are added.

Reviewers: mehdi_amini, rriddle, jpienaar, antiagainst, ftynse

Subscribers: burmako, shauheen, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72022
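As a sketch of the rule stated above (output tensor operand types appearing verbatim after an arrow), the following mirrors the doc example added by this commit; `#trait_attribute`, `{other-attributes}`, `stride_specification` and the operand names are placeholders, not a runnable test case:

```mlir
// %A and %C are ranked tensor operands, %B is a memref operand; the single
// output tensor type, tensor<?x?xf32>, reappears verbatim after the arrow.
%1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
  tensor<?x?xf32>,
  memref<?x?xf32, stride_specification>,
  tensor<?x?xf32>
  -> (tensor<?x?xf32>)
```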
1 parent 9d49e5c commit 2140a97

File tree

16 files changed: +408 −135 lines


mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 10 additions & 8 deletions
@@ -19,9 +19,10 @@ def Linalg_Dialect : Dialect {
   let name = "linalg";
   let description = [{
     The `linalg` dialect groups together a set of types, operations and
-    transformations that are useful to implement a structured abstraction where
-    ops can lower to scalar load/store and operations or to more general library
-    calls.
+    transformations that are useful to implement a structured abstraction on
+    buffers and tensors. These abstractions are useful for transformations and
+    can lower to scalar load/store and other operations or to more general
+    library calls.
 
     The `linalg` dialect manipulates the following types and operations:
 
@@ -67,12 +68,13 @@ def Linalg_Dialect : Dialect {
     A set of payload carrying operations that implement the [structured ops](
     https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p
     )
-    abstraction on buffers. `linalg` has `2` generic operations `linalg.generic`
-    and `linalg.indexed_generic` for expressing custom operations. This is
-    subject to further evolution as transformations and analyses continue to be
-    developed.
+    abstraction on tensors and buffers. `linalg` has `2` generic operations
+    `linalg.generic` and `linalg.indexed_generic` for expressing custom
+    operations.
+    This is subject to further evolution as transformations and analyses
+    continue to be developed.
 
-    Additionally, `linalg` provides some common named operations:
+    Additionally, `linalg` provides some commonly named operations:
 
     * `linalg.copy`,
     * `linalg.fill`,

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 9 additions & 8 deletions
@@ -59,7 +59,8 @@ def Linalg_RangeOp :
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
+    Arguments<(ins AnyStridedMemRef:$view,
+               Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "Produce a rank-reduced `subview` of a base `view`.";
   let description = [{
@@ -108,11 +109,11 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
 
   let extraClassDeclaration = [{
     enum { FirstIndexingOperand = 1 };
-    unsigned getRank() { return getViewType().getRank(); }
-    Type getElementType() { return getViewType().getElementType(); }
-    MemRefType getViewType() { return getType().cast<MemRefType>(); }
+    unsigned getRank() { return getShapedType().getRank(); }
+    Type getElementType() { return getShapedType().getElementType(); }
+    ShapedType getShapedType() { return getType().cast<ShapedType>(); }
     unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
-    MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }
+    ShapedType getBaseViewType() { return view()->getType().cast<ShapedType>();}
 
     // Get the underlying indexing at a given rank.
     Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
@@ -131,7 +132,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
 def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
-  let summary = "transpose operation produces a new strided memref (metadata-only)";
+  let summary = "`transpose` produces a new strided memref (metadata-only)";
   let description = [{
     The `linalg.transpose` op produces a strided memref whose sizes and strides
     are a permutation of the original `view`. This is a pure metadata
@@ -151,14 +152,14 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
   let verifier = [{
     if (!permutation().isPermutation())
       return emitOpError("expected a permutation map");
-    if (permutation().getNumDims() != getViewType().getRank())
+    if (permutation().getNumDims() != getShapedType().getRank())
       return emitOpError("expected a permutation map of same rank as the view");
     return success();
   }];
 
   let extraClassDeclaration = [{
     static StringRef getPermutationAttrName() { return "permutation"; }
-    MemRefType getViewType() { return view()->getType().cast<MemRefType>(); }
+    ShapedType getShapedType() { return view()->getType().cast<ShapedType>(); }
   }];
 }
 

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 77 additions & 11 deletions
@@ -89,23 +89,32 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "Value ", "getOutput", (ins "unsigned":$i)
     >,
     InterfaceMethod<[{
-      Query the index of the given input value, or `None` if the value is not
-      an input.
+      Return the index of the given input value `v`, or `None` if the value is
+      not an input.
     }],
-      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
+      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
     >,
     InterfaceMethod<[{
       Query the index of the given view value, or `None` if the value is not
-      an view.
+      a view.
     }],
       "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
     >,
     InterfaceMethod<[{
-      Query the type of the input view at the given index.
-    }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+      Query the type of the input shape at the given index.
+    }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
       Query the type of the output view at the given index.
-    }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
+    }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+      Query whether the op has only MemRef input and outputs.
+    }], "bool", "hasBufferSemantics">,
+    InterfaceMethod<[{
+      Query the subset of input operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+    InterfaceMethod<[{
+      Query the subset of output operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,
 
     StaticInterfaceMethod<[{
       Create an operation of the current type with the given ___location,
@@ -340,7 +349,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
     ArrayAttr iterator_types() {
       // Outer parallel loops are always the number of output dimensions; i.e.
      // [ b, xs, q] in the TF notation above.
-      unsigned nPar = getOutputViewType(0).getRank();
+      unsigned nPar = getOutputShapedType(0).getRank();
      unsigned nRed = getNumInputFeatureDimensions();
      // Window loops are a special kind of reduction that is never tiled or
      // parallelized across; i.e. [zs] in the TF notation above whose number
@@ -374,15 +383,25 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
+def LinalgOperand: Type<
+  Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;
+
+class LinalgOperandOfRank<int rank>: Type<
+  And<[
+    LinalgOperand.predicate,
+    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
+  >>;
+
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
-  let arguments = (ins Variadic<AnyStridedMemRef>:$views,
+  let arguments = (ins Variadic<LinalgOperand>:$views,
                    I64Attr:$args_in,
                    I64Attr:$args_out,
                    AffineMapArrayAttr:$indexing_maps,
                    ArrayAttr:$iterator_types,
                    OptionalAttr<StrAttr>:$doc,
                    OptionalAttr<FlatSymbolRefAttr>:$fun,
                    OptionalAttr<StrAttr>:$library_call);
+  let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = [{
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
@@ -511,6 +530,28 @@ def GenericOp : GenericOpBase<"generic"> {
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a tensor values) to
+    the buffer world (a.k.a memref values), a `linalg.generic` op accepts
+    mixing input and output ranked tensor values with input and output memrefs.
+
+    ```mlir
+    %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+      tensor<?x?xf32>,
+      memref<?x?xf32, stride_specification>,
+      tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics is that the `linalg.generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around linalg.generic operations are not expected to mix with
+    tensors because SSA values do not escape naturally. Still, transformations
+    and rewrites that take advantage of tensor SSA values are expected to be
+    useful and will be added in the near future.
   }];
   let verifier = [{ return ::verify(*this); }];
 }
@@ -555,9 +596,11 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-    func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
+    func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
+              %a: f32, %b: f32, %c: f32)
       -> f32
     {
+      "some_optional_condition"(%offset_m, %offset_n, %offset_k)
      %d = mulf %a, %b: f32
      %e = addf %c, %d: f32
      return %e: f32
@@ -587,7 +630,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
 
     This may lower to either:
     ```mlir
-    call @linalg_matmul(%A, %B, %C) :
+    call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
       (memref<?x?xf32, stride_specification>,
        memref<?x?xf32, stride_specification>,
        memref<?x?xf32, stride_specification>)
@@ -609,6 +652,29 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a tensor values) to
+    the buffer world (a.k.a memref values), a `linalg.indexed_generic` op
+    accepts mixing input and output ranked tensor values with input and output
+    memrefs.
+
+    ```mlir
+    %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
+      : tensor<?x?xf32>,
+        memref<?x?xf32, stride_specification>,
+        tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics is that the `linalg.indexed_generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around linalg.generic operations are not expected to mix with
+    tensors because SSA values do not escape naturally. Still, transformations
+    and rewrites that take advantage of tensor SSA values are expected to be
+    useful and will be added in the near future.
   }];
   let verifier = [{ return ::verify(*this); }];
 }
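The summary mentions that invalid and roundtrip tests were added for this behavior. A hypothetical sketch of the kind of IR the extended verifier should reject (the trait and operand names are placeholders, not taken from the actual tests): an op whose output tensor operand type is not repeated in the result list.

```mlir
// Hypothetical invalid sketch: the trait declares one output tensor operand,
// but the op lists no ranked tensor result type after an arrow, so the
// subset-must-appear-verbatim rule is violated and verification should fail.
%0 = linalg.generic #trait_attribute %A, %B {other-attributes} :
  tensor<?x?xf32>,
  tensor<?x?xf32>
```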
