@@ -89,23 +89,32 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       "Value ", "getOutput", (ins "unsigned":$i)
     >,
     InterfaceMethod<[{
-      Query the index of the given input value, or `None` if the value is not
-      an input.
+      Return the index of the given input value `v`, or `None` if the value is
+      not an input.
     }],
-      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
+      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
     >,
     InterfaceMethod<[{
       Query the index of the given view value, or `None` if the value is not
-      an view.
+      a view.
     }],
       "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
     >,
     InterfaceMethod<[{
-      Query the type of the input view at the given index.
-    }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+      Query the type of the input shape at the given index.
+    }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
       Query the type of the output view at the given index.
-    }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
+    }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+      Query whether the op has only MemRef inputs and outputs.
+    }], "bool", "hasBufferSemantics">,
+    InterfaceMethod<[{
+      Query the subset of input operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+    InterfaceMethod<[{
+      Query the subset of output operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,

     StaticInterfaceMethod<[{
       Create an operation of the current type with the given ___location,
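
As a rough sketch (not part of this patch) of how the new interface methods might be used from C++, assuming the generated `LinalgOp` interface exposes them exactly as declared above; the helper `tileOnBuffers` is hypothetical:

```c++
#include <cassert>

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Hypothetical transformation entry point: a buffer-only rewrite should
// bail out while the op still has tensor operands or results.
static LogicalResult tileOnBuffers(linalg::LinalgOp op) {
  if (!op.hasBufferSemantics())
    return failure(); // a buffer allocation pass must run first
  // With buffer semantics, the tensor subsets are empty by construction.
  assert(op.getInputTensorTypes().empty() &&
         op.getOutputTensorTypes().empty());
  // ... memref-based tiling would go here ...
  return success();
}
```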
@@ -340,7 +349,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
     ArrayAttr iterator_types() {
       // Outer parallel loops are always the number of output dimensions; i.e.
       // [ b, xs, q] in the TF notation above.
-      unsigned nPar = getOutputViewType(0).getRank();
+      unsigned nPar = getOutputShapedType(0).getRank();
       unsigned nRed = getNumInputFeatureDimensions();
       // Window loops are a special kind of reduction that is never tiled or
       // parallelized across; i.e. [zs] in the TF notation above whose number
@@ -374,15 +383,25 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
   let verifier = [{ return ::verify(*this); }];
 }

+def LinalgOperand: Type<
+  Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;
+
+class LinalgOperandOfRank<int rank>: Type<
+  And<[
+    LinalgOperand.predicate,
+    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
+  >>;
+
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
-  let arguments = (ins Variadic<AnyStridedMemRef>:$views,
+  let arguments = (ins Variadic<LinalgOperand>:$views,
                        I64Attr:$args_in,
                        I64Attr:$args_out,
                        AffineMapArrayAttr:$indexing_maps,
                        ArrayAttr:$iterator_types,
                        OptionalAttr<StrAttr>:$doc,
                        OptionalAttr<FlatSymbolRefAttr>:$fun,
                        OptionalAttr<StrAttr>:$library_call);
+  let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = [{
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
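
For readers less familiar with ODS predicates, here is a rough C++ rendering (a sketch, not generated code) of what the `LinalgOperand` and `LinalgOperandOfRank` constraints accept. The call to `isStrided` stands in for the strided-layout check behind `AnyStridedMemRef`:

```c++
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

// LinalgOperand: either a ranked tensor or a strided memref.
static bool isLinalgOperand(Type type) {
  if (type.isa<RankedTensorType>())
    return true;
  auto memrefType = type.dyn_cast<MemRefType>();
  return memrefType && isStrided(memrefType);
}

// LinalgOperandOfRank<rank>: additionally constrains the rank, mirroring
// the CPred on $_self in the TableGen definition above.
static bool isLinalgOperandOfRank(Type type, int64_t rank) {
  return isLinalgOperand(type) &&
         type.cast<ShapedType>().getRank() == rank;
}
```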
@@ -511,6 +530,28 @@ def GenericOp : GenericOpBase<"generic"> {
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a. tensor values)
+    to the buffer world (a.k.a. memref values), a `linalg.generic` op accepts
+    mixing input and output ranked tensor values with input and output memrefs.
+
+    ```mlir
+    %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+      tensor<?x?xf32>,
+      memref<?x?xf32, stride_specification>,
+      tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics are that the `linalg.generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around `linalg.generic` operations are not expected to mix
+    with tensors because SSA values do not escape naturally. Still,
+    transformations and rewrites that take advantage of tensor SSA values are
+    expected to be useful and will be added in the near future.
   }];
   let verifier = [{ return ::verify(*this); }];
 }
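
The result-count invariant described above suggests a verifier fragment along the following lines (a sketch only; the patch's actual check lives behind `::verify(*this)` and is not shown in this diff, and the helper name is hypothetical):

```c++
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Sketch: linalg.generic must return exactly one tensor per output
// tensor operand, as stated in the op description above.
static LogicalResult verifyTensorResults(linalg::GenericOp op) {
  if (op.getNumResults() != op.getOutputTensorTypes().size())
    return op.emitOpError("expected one result per output tensor operand");
  return success();
}
```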
@@ -555,9 +596,11 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     Example:
     Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-    func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
+    func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
+              %a: f32, %b: f32, %c: f32)
       -> f32
     {
+      "some_optional_condition"(%offset_m, %offset_n, %offset_k)
       %d = mulf %a, %b: f32
       %e = addf %c, %d: f32
       return %e: f32
@@ -587,7 +630,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {

     This may lower to either:
     ```mlir
-    call @linalg_matmul(%A, %B, %C) :
+    call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
       (memref<?x?xf32, stride_specification>,
        memref<?x?xf32, stride_specification>,
        memref<?x?xf32, stride_specification>)
@@ -609,6 +652,29 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a. tensor values)
+    to the buffer world (a.k.a. memref values), a `linalg.indexed_generic` op
+    accepts mixing input and output ranked tensor values with input and output
+    memrefs.
+
+    ```mlir
+    %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
+      : tensor<?x?xf32>,
+        memref<?x?xf32, stride_specification>,
+        tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics are that the `linalg.indexed_generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around `linalg.indexed_generic` operations are not expected to
+    mix with tensors because SSA values do not escape naturally. Still,
+    transformations and rewrites that take advantage of tensor SSA values are
+    expected to be useful and will be added in the near future.
   }];
   let verifier = [{ return ::verify(*this); }];
 }