Skip to content

Commit 882ba48

Browse files
[mlir][Linalg] Create a tool to generate named Linalg ops from a Tensor Comprehensions-like specification.
Summary: This revision adds a tool that generates the ODS and C++ implementation for "named" Linalg ops according to the [RFC discussion](https://llvm.discourse.group/t/rfc-declarative-named-ops-in-the-linalg-dialect/745). While the mechanisms and language aspects are by no means set in stone, this revision allows connecting the pieces end-to-end from a mathematical-like specification. Some implementation details and short-term decisions taken for the purpose of bootstrapping and that are not set in stone include: 1. using a "[Tensor Comprehension](https://arxiv.org/abs/1802.04730)-inspired" syntax 2. implicit and eager discovery of dims and symbols when parsing 3. using EDSC ops to specify the computation (e.g. std_addf, std_mul_f, ...) A followup revision will connect this tool to tablegen mechanisms and allow the emission of named Linalg ops that automatically lower to various loop forms and run end to end. For the following "Tensor Comprehension-inspired" string: ``` def batch_matmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n))); } ``` With -gen-ods-decl=1, this emits (modulo formatting): ``` def batch_matmulOp : LinalgNamedStructured_Op<"batch_matmul", [ NInputs<2>, NOutputs<1>, NamedStructuredOpTraits]> { let arguments = (ins Variadic<LinalgOperand>:$views); let results = (outs Variadic<AnyRankedTensor>:$output_tensors); let extraClassDeclaration = [{ llvm::Optional<SmallVector<StringRef, 8>> referenceIterators(); llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps(); void regionBuilder(ArrayRef<BlockArgument> args); }]; let hasFolder = 1; } ``` With -gen-ods-impl, this emits (modulo formatting): ``` llvm::Optional<SmallVector<StringRef, 8>> batch_matmul::referenceIterators() { return SmallVector<StringRef, 8>{ getParallelIteratorTypeName(), getParallelIteratorTypeName(), getParallelIteratorTypeName(), getReductionIteratorTypeName() }; } llvm::Optional<SmallVector<AffineMap, 8>> batch_matmul::referenceIndexingMaps() { MLIRContext *context = getContext(); AffineExpr d0, d1, d2, d3; bindDims(context, d0, d1, d2, d3); return SmallVector<AffineMap, 8>{ AffineMap::get(4, 0, {d0, d1, d3}), AffineMap::get(4, 0, {d3, d2}), AffineMap::get(4, 0, {d0, d1, d2}) }; } void batch_matmul::regionBuilder(ArrayRef<BlockArgument> args) { using namespace edsc; using namespace intrinsics; ValueHandle _0(args[0]), _1(args[1]), _2(args[2]); ValueHandle _4 = std_mulf(_0, _1); ValueHandle _5 = std_addf(_2, _4); (linalg_yield(ValueRange{ _5 })); } ``` Differential Revision: https://reviews.llvm.org/D77067
1 parent a04ab2e commit 882ba48

File tree

10 files changed

+1851
-4
lines changed

10 files changed

+1851
-4
lines changed

mlir/docs/Dialects/Linalg.md

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,93 @@ from a description in terms of only the generic op interface.
451451
This is the main reason there are only a small number of ops today: we expect
452452
them to be auto-generated from Tablegen soon.
453453

454+
### Named Payload Ops Specification
455+
456+
Linalg provides a declarative specification and a generation tool
457+
(`mlir-linalg-ods-gen`) to automatically produce named ops from a notation that
458+
is inspired by Einstein notation.
459+
460+
The syntax and semantics used in `mlir-linalg-ods-gen` are very much in flight
461+
and borrow from Tensor Comprehensions (TC) but differ in a few dimensions, to
462+
better adapt to Linalg:
463+
464+
1. The input and output tensor parameters are specified as `id :
465+
type(symbolic-affine-expression-list)` (e.g. `A : f32(M, N + M)`) and each
466+
new symbol is discovered eagerly. TC on the other hand does not allow
467+
general symbolic affine expressions.
468+
1. The output shapes are specified explicitly, in TC they are always derived
469+
from the input shapes.
470+
1. The operations used to specify computations use EDSC intrinsics so that they
471+
can easily be parsed and emitted into a simple region builder without
472+
resorting to more general MLIR parsing.
473+
1. Reduction dimensions are specified with angle bracket notation on the
474+
operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction
475+
dimension). In TC, a reduction is specified with `op=` operator and the
476+
reduction dimensions are inferred.
477+
1. The parallel and reduction dimension are ordered by the textual program
478+
order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`,
479+
`i` (resp. `j`) is a parallel iterator encoded by affine dimension of
480+
position `0` (resp. `1`); `k` (resp. `l`) is a reduction iterator encoded by
481+
an affine dimension of position `2` (resp. `3`).
482+
483+
These decisions and syntax are subject to evolution and change. In particular,
484+
op-specific attributes, dynamic ranks, some form of templating, shape
485+
calculation function specification, etc. may be added in the future.
486+
487+
At this time, the following restrictions are imposed on the syntax and
488+
semantics:
489+
490+
1. Each def may only contain a single comprehension but each comprehension may
491+
perform multiple updates.
492+
2. Each tensor may only be used with a single indexing expression.
493+
494+
The following specification may be used to define a named `batchmatmul` op:
495+
496+
```
497+
def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
498+
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
499+
}
500+
```
501+
502+
When `mlir-linalg-ods-gen -gen-ods-decl=1` is called, the following ODS is
503+
produced:
504+
505+
```
506+
def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
507+
NInputs<2>,
508+
NOutputs<1>,
509+
NamedStructuredOpTraits]> { ... }
510+
```
511+
512+
When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced:
513+
514+
```
515+
llvm::Optional<SmallVector<StringRef, 8>> batchmatmul::referenceIterators() {
516+
return SmallVector<StringRef, 8>{
517+
getParallelIteratorTypeName(),
518+
getParallelIteratorTypeName(),
519+
getParallelIteratorTypeName(),
520+
getReductionIteratorTypeName() };
521+
}
522+
llvm::Optional<SmallVector<AffineMap, 8>> batchmatmul::referenceIndexingMaps() {
523+
MLIRContext *context = getContext();
524+
AffineExpr d0, d1, d2, d3;
525+
bindDims(context, d0, d1, d2, d3);
526+
return SmallVector<AffineMap, 8>{
527+
AffineMap::get(4, 0, {d0, d1, d3}),
528+
AffineMap::get(4, 0, {d3, d2}),
529+
AffineMap::get(4, 0, {d0, d1, d2}) };
530+
}
531+
void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
532+
using namespace edsc;
533+
using namespace intrinsics;
534+
ValueHandle _0(args[0]), _1(args[1]), _2(args[2]);
535+
ValueHandle _4 = std_mulf(_0, _1);
536+
ValueHandle _5 = std_addf(_2, _4);
537+
(linalg_yield(ValueRange{ _5 }));
538+
}
539+
```
540+
454541
## Open Issues and Design Alternatives<a name="open_issues"></a>
455542
Multiple open issues and design alternatives are in flight and it is time to
456543
lay them out for the community to discuss and pick apart:

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
256256
/// OptionalAttr<I64ArrayAttr>:$strides
257257
/// OptionalAttr<I64ArrayAttr>:$dilations
258258
/// OptionalAttr<I64ElementsAttr>:$padding
259-
/// `strides` denotes the step of each window along the dimension.
259+
/// `stirdes` denotes the step of each window along the dimension.
260260
class PoolingBase_Op<string mnemonic, list<OpTrait> props>
261261
: LinalgStructured_Op<mnemonic, props> {
262262
let description = [{
@@ -821,4 +821,18 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
821821
let hasFolder = 1;
822822
}
823823

824+
//===----------------------------------------------------------------------===//
825+
// Named Linalg ops, implemented as a declarative configurations of generic ops.
826+
//===----------------------------------------------------------------------===//
827+
828+
def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
829+
830+
class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
831+
: Op<Linalg_Dialect, mnemonic,
832+
!listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> {
833+
string spec = ?;
834+
let assemblyFormat = "`(` operands `)` attr-dict `:` "
835+
"functional-type(operands, results)";
836+
}
837+
824838
#endif // LINALG_STRUCTURED_OPS

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
219219
ArrayRef<AffineExpr> localExprs,
220220
MLIRContext *context);
221221

222-
raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);
222+
raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
223223

224224
template <typename U> bool AffineExpr::isa() const {
225225
if (std::is_same<U, AffineBinaryOpExpr>::value)

mlir/lib/IR/AffineExpr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ AffineExpr AffineExpr::compose(AffineMap map) const {
613613
map.getResults().end());
614614
return replaceDimsAndSymbols(dimReplacements, {});
615615
}
616-
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr &expr) {
616+
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
617617
expr.print(os);
618618
return os;
619619
}

mlir/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ set(MLIR_TEST_DEPENDS
3535
MLIRUnitTests
3636
mlir-cpu-runner
3737
mlir-edsc-builder-api-test
38+
mlir-linalg-ods-gen
3839
mlir-opt
3940
mlir-sdbm-api-test
4041
mlir-tblgen

mlir/test/lit.cfg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
2222

2323
# suffixes: A list of file extensions to treat as test files.
24-
config.suffixes = ['.td', '.mlir', '.toy', '.ll']
24+
config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc']
2525

2626
# test_source_root: The root path where tests are located.
2727
config.test_source_root = os.path.dirname(__file__)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
2+
// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
3+
4+
// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \
5+
// RUN: | mlir-tblgen -gen-op-decls -I %S/../../include
6+
7+
// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [
8+
// ODS-NEXT: NInputs<2>,
9+
// ODS-NEXT: NOutputs<1>,
10+
// ODS-NEXT: NamedStructuredOpTraits]>
11+
//
12+
// IMPL-LABEL: matvec::referenceIterators() {
13+
// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
14+
//
15+
// IMPL: matvec::referenceIndexingMaps() {
16+
// IMPL: AffineMap::get(2, 0, {d0, d1}),
17+
// IMPL-NEXT: AffineMap::get(2, 0, {d1}),
18+
// IMPL-NEXT: AffineMap::get(2, 0, {d0}) };
19+
//
20+
// IMPL: matvec::regionBuilder(ArrayRef<BlockArgument> args) {
21+
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
22+
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
23+
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
24+
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
25+
//
26+
def matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
27+
C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
28+
}
29+
30+
// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [
31+
// ODS-NEXT: NInputs<2>,
32+
// ODS-NEXT: NOutputs<1>,
33+
// ODS-NEXT: NamedStructuredOpTraits]>
34+
//
35+
// IMPL-LABEL: matmul::referenceIterators() {
36+
// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
37+
//
38+
// IMPL: matmul::referenceIndexingMaps() {
39+
// IMPL: AffineMap::get(3, 0, {d0, d2}),
40+
// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}),
41+
// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) };
42+
//
43+
// IMPL: matmul::regionBuilder(ArrayRef<BlockArgument> args) {
44+
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
45+
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
46+
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
47+
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
48+
//
49+
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
50+
C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
51+
}
52+
53+
// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [
54+
// ODS-NEXT: NInputs<2>,
55+
// ODS-NEXT: NOutputs<1>,
56+
// ODS-NEXT: NamedStructuredOpTraits]>
57+
//
58+
// IMPL-LABEL: batchmatmul::referenceIterators() {
59+
// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} }
60+
//
61+
// IMPL: batchmatmul::referenceIndexingMaps() {
62+
// IMPL: AffineMap::get(4, 0, {d0, d1, d3}),
63+
// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}),
64+
// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) };
65+
//
66+
// IMPL: batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) {
67+
// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
68+
// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]);
69+
// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]);
70+
// IMPL: (linalg_yield(ValueRange{ [[e]] }));
71+
//
72+
// TBLGEN: batchmatmulOp
73+
def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
74+
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
75+
}

mlir/tools/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_subdirectory(mlir-cuda-runner)
22
add_subdirectory(mlir-cpu-runner)
3+
add_subdirectory(mlir-linalg-ods-gen)
34
add_subdirectory(mlir-opt)
45
add_subdirectory(mlir-translate)
56
add_subdirectory(mlir-vulkan-runner)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
add_llvm_tool(mlir-linalg-ods-gen
2+
mlir-linalg-ods-gen.cpp
3+
)
4+
llvm_update_compile_flags(mlir-linalg-ods-gen)
5+
target_link_libraries(mlir-linalg-ods-gen PRIVATE
6+
MLIRParser
7+
MLIRSupport
8+
LLVMCore
9+
LLVMSupport
10+
)

0 commit comments

Comments
 (0)