Skip to content

Commit 186709c

Browse files
committed
[mlir] [VectorOps] Progressive lowering of vector.broadcast
Summary: Rather than having a full, recursive, lowering of vector.broadcast to LLVM IR, it is much more elegant to have a progressive lowering of each vector.broadcast into a lower dimensional vector.broadcast, until only elementary vector operations remain. This results in more elegant, step-wise code, that is easier to understand. Also makes some optimizations in the generated code. Reviewers: nicolasvasilache, mehdi_amini, andydavis1, grosul1 Reviewed By: nicolasvasilache Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, grosul1, frgossen, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D78071
1 parent cc0ec3f commit 186709c

File tree

5 files changed

+437
-311
lines changed

5 files changed

+437
-311
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
5555
/// ContractionOpLowering,
5656
/// ShapeCastOp2DDownCastRewritePattern,
5757
/// ShapeCastOp2DUpCastRewritePattern
58+
/// BroadcastOpLowering,
5859
/// TransposeOpLowering
5960
/// OuterproductOpLowering
6061
/// These transformation express higher level vector ops in terms of more

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 1 addition & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -126,155 +126,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
126126

127127
namespace {
128128

129-
class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
130-
public:
131-
explicit VectorBroadcastOpConversion(MLIRContext *context,
132-
LLVMTypeConverter &typeConverter)
133-
: ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
134-
typeConverter) {}
135-
136-
LogicalResult
137-
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
138-
ConversionPatternRewriter &rewriter) const override {
139-
auto broadcastOp = cast<vector::BroadcastOp>(op);
140-
VectorType dstVectorType = broadcastOp.getVectorType();
141-
if (typeConverter.convertType(dstVectorType) == nullptr)
142-
return failure();
143-
// Rewrite when the full vector type can be lowered (which
144-
// implies all 'reduced' types can be lowered too).
145-
auto adaptor = vector::BroadcastOpOperandAdaptor(operands);
146-
VectorType srcVectorType =
147-
broadcastOp.getSourceType().dyn_cast<VectorType>();
148-
rewriter.replaceOp(
149-
op, expandRanks(adaptor.source(), // source value to be expanded
150-
op->getLoc(), // ___location of original broadcast
151-
srcVectorType, dstVectorType, rewriter));
152-
return success();
153-
}
154-
155-
private:
156-
// Expands the given source value over all the ranks, as defined
157-
// by the source and destination type (a null source type denotes
158-
// expansion from a scalar value into a vector).
159-
//
160-
// TODO(ajcbik): consider replacing this one-pattern lowering
161-
// with a two-pattern lowering using other vector
162-
// ops once all insert/extract/shuffle operations
163-
// are available with lowering implementation.
164-
//
165-
Value expandRanks(Value value, Location loc, VectorType srcVectorType,
166-
VectorType dstVectorType,
167-
ConversionPatternRewriter &rewriter) const {
168-
assert((dstVectorType != nullptr) && "invalid result type in broadcast");
169-
// Determine rank of source and destination.
170-
int64_t srcRank = srcVectorType ? srcVectorType.getRank() : 0;
171-
int64_t dstRank = dstVectorType.getRank();
172-
int64_t curDim = dstVectorType.getDimSize(0);
173-
if (srcRank < dstRank)
174-
// Duplicate this rank.
175-
return duplicateOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
176-
curDim, rewriter);
177-
// If all trailing dimensions are the same, the broadcast consists of
178-
// simply passing through the source value and we are done. Otherwise,
179-
// any non-matching dimension forces a stretch along this rank.
180-
assert((srcVectorType != nullptr) && (srcRank > 0) &&
181-
(srcRank == dstRank) && "invalid rank in broadcast");
182-
for (int64_t r = 0; r < dstRank; r++) {
183-
if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
184-
return stretchOneRank(value, loc, srcVectorType, dstVectorType, dstRank,
185-
curDim, rewriter);
186-
}
187-
}
188-
return value;
189-
}
190-
191-
// Picks the best way to duplicate a single rank. For the 1-D case, a
192-
// single insert-elt/shuffle is the most efficient expansion. For higher
193-
// dimensions, however, we need dim x insert-values on a new broadcast
194-
// with one less leading dimension, which will be lowered "recursively"
195-
// to matching LLVM IR.
196-
// For example:
197-
// v = broadcast s : f32 to vector<4x2xf32>
198-
// becomes:
199-
// x = broadcast s : f32 to vector<2xf32>
200-
// v = [x,x,x,x]
201-
// becomes:
202-
// x = [s,s]
203-
// v = [x,x,x,x]
204-
Value duplicateOneRank(Value value, Location loc, VectorType srcVectorType,
205-
VectorType dstVectorType, int64_t rank, int64_t dim,
206-
ConversionPatternRewriter &rewriter) const {
207-
Type llvmType = typeConverter.convertType(dstVectorType);
208-
assert((llvmType != nullptr) && "unlowerable vector type");
209-
if (rank == 1) {
210-
Value undef = rewriter.create<LLVM::UndefOp>(loc, llvmType);
211-
Value expand = insertOne(rewriter, typeConverter, loc, undef, value,
212-
llvmType, rank, 0);
213-
SmallVector<int32_t, 4> zeroValues(dim, 0);
214-
return rewriter.create<LLVM::ShuffleVectorOp>(
215-
loc, expand, undef, rewriter.getI32ArrayAttr(zeroValues));
216-
}
217-
Value expand = expandRanks(value, loc, srcVectorType,
218-
reducedVectorTypeFront(dstVectorType), rewriter);
219-
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
220-
for (int64_t d = 0; d < dim; ++d) {
221-
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
222-
rank, d);
223-
}
224-
return result;
225-
}
226-
227-
// Picks the best way to stretch a single rank. For the 1-D case, a
228-
// single insert-elt/shuffle is the most efficient expansion when at
229-
// a stretch. Otherwise, every dimension needs to be expanded
230-
// individually and individually inserted in the resulting vector.
231-
// For example:
232-
// v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
233-
// becomes:
234-
// a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
235-
// b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
236-
// c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
237-
// d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
238-
// v = [a,b,c,d]
239-
// becomes:
240-
// x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
241-
// y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
242-
// a = [x, y]
243-
// etc.
244-
Value stretchOneRank(Value value, Location loc, VectorType srcVectorType,
245-
VectorType dstVectorType, int64_t rank, int64_t dim,
246-
ConversionPatternRewriter &rewriter) const {
247-
Type llvmType = typeConverter.convertType(dstVectorType);
248-
assert((llvmType != nullptr) && "unlowerable vector type");
249-
Value result = rewriter.create<LLVM::UndefOp>(loc, llvmType);
250-
bool atStretch = dim != srcVectorType.getDimSize(0);
251-
if (rank == 1) {
252-
assert(atStretch);
253-
Type redLlvmType =
254-
typeConverter.convertType(dstVectorType.getElementType());
255-
Value one =
256-
extractOne(rewriter, typeConverter, loc, value, redLlvmType, rank, 0);
257-
Value expand = insertOne(rewriter, typeConverter, loc, result, one,
258-
llvmType, rank, 0);
259-
SmallVector<int32_t, 4> zeroValues(dim, 0);
260-
return rewriter.create<LLVM::ShuffleVectorOp>(
261-
loc, expand, result, rewriter.getI32ArrayAttr(zeroValues));
262-
}
263-
VectorType redSrcType = reducedVectorTypeFront(srcVectorType);
264-
VectorType redDstType = reducedVectorTypeFront(dstVectorType);
265-
Type redLlvmType = typeConverter.convertType(redSrcType);
266-
for (int64_t d = 0; d < dim; ++d) {
267-
int64_t pos = atStretch ? 0 : d;
268-
Value one = extractOne(rewriter, typeConverter, loc, value, redLlvmType,
269-
rank, pos);
270-
Value expand = expandRanks(one, loc, redSrcType, redDstType, rewriter);
271-
result = insertOne(rewriter, typeConverter, loc, result, expand, llvmType,
272-
rank, d);
273-
}
274-
return result;
275-
}
276-
};
277-
278129
/// Conversion pattern for a vector.matrix_multiply.
279130
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
280131
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
@@ -1209,8 +1060,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
12091060
VectorInsertStridedSliceOpSameRankRewritePattern,
12101061
VectorStridedSliceOpConversion>(ctx);
12111062
patterns
1212-
.insert<VectorBroadcastOpConversion,
1213-
VectorReductionOpConversion,
1063+
.insert<VectorReductionOpConversion,
12141064
VectorShuffleOpConversion,
12151065
VectorExtractElementOpConversion,
12161066
VectorExtractOpConversion,

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,114 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
979979
}
980980
};
981981

982-
/// Progressive lowering of OuterProductOp.
982+
/// Progressive lowering of BroadcastOp.
983+
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
984+
public:
985+
using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
986+
987+
LogicalResult matchAndRewrite(vector::BroadcastOp op,
988+
PatternRewriter &rewriter) const override {
989+
auto loc = op.getLoc();
990+
VectorType dstType = op.getVectorType();
991+
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
992+
Type eltType = dstType.getElementType();
993+
994+
// Determine rank of source and destination.
995+
int64_t srcRank = srcType ? srcType.getRank() : 0;
996+
int64_t dstRank = dstType.getRank();
997+
998+
// Duplicate this rank.
999+
// For example:
1000+
// %x = broadcast %y : k-D to n-D, k < n
1001+
// becomes:
1002+
// %b = broadcast %y : k-D to (n-1)-D
1003+
// %x = [%b,%b,%b,%b] : n-D
1004+
// becomes:
1005+
// %b = [%y,%y] : (n-1)-D
1006+
// %x = [%b,%b,%b,%b] : n-D
1007+
if (srcRank < dstRank) {
1008+
// Scalar to any vector can use splat.
1009+
if (srcRank == 0) {
1010+
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, op.source());
1011+
return success();
1012+
}
1013+
// Duplication.
1014+
VectorType resType =
1015+
VectorType::get(dstType.getShape().drop_front(), eltType);
1016+
Value bcst =
1017+
rewriter.create<vector::BroadcastOp>(loc, resType, op.source());
1018+
Value zero = rewriter.create<ConstantOp>(loc, eltType,
1019+
rewriter.getZeroAttr(eltType));
1020+
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
1021+
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1022+
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1023+
rewriter.replaceOp(op, result);
1024+
return success();
1025+
}
1026+
1027+
// Find non-matching dimension, if any.
1028+
assert(srcRank == dstRank);
1029+
int64_t m = -1;
1030+
for (int64_t r = 0; r < dstRank; r++)
1031+
if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
1032+
m = r;
1033+
break;
1034+
}
1035+
1036+
// All trailing dimensions are the same. Simply pass through.
1037+
if (m == -1) {
1038+
rewriter.replaceOp(op, op.source());
1039+
return success();
1040+
}
1041+
1042+
// Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
1043+
if (srcRank == 1) {
1044+
assert(m == 0);
1045+
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1046+
rewriter.replaceOpWithNewOp<SplatOp>(op, dstType, ext);
1047+
return success();
1048+
}
1049+
1050+
// Any non-matching dimension forces a stretch along this rank.
1051+
// For example:
1052+
// %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32>
1053+
// becomes:
1054+
// %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32>
1055+
// %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32>
1056+
// %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32>
1057+
// %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32>
1058+
// %x = [%a,%b,%c,%d]
1059+
// becomes:
1060+
// %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32>
1061+
// %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32>
1062+
// %a = [%u, %v]
1063+
// ..
1064+
// %x = [%a,%b,%c,%d]
1065+
VectorType resType =
1066+
VectorType::get(dstType.getShape().drop_front(), eltType);
1067+
Value zero = rewriter.create<ConstantOp>(loc, eltType,
1068+
rewriter.getZeroAttr(eltType));
1069+
Value result = rewriter.create<SplatOp>(loc, dstType, zero);
1070+
if (m == 0) {
1071+
// Stetch at start.
1072+
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), 0);
1073+
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1074+
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
1075+
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1076+
} else {
1077+
// Stetch not at start.
1078+
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
1079+
Value ext = rewriter.create<vector::ExtractOp>(loc, op.source(), d);
1080+
Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
1081+
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
1082+
}
1083+
}
1084+
rewriter.replaceOp(op, result);
1085+
return success();
1086+
}
1087+
};
1088+
1089+
/// Progressive lowering of TransposeOp.
9831090
/// One:
9841091
/// %x = vector.transpose %y, [1, 0]
9851092
/// is replaced by:
@@ -1518,7 +1625,7 @@ void mlir::vector::populateVectorContractLoweringPatterns(
15181625
OwningRewritePatternList &patterns, MLIRContext *context,
15191626
VectorTransformsOptions parameters) {
15201627
patterns.insert<ShapeCastOp2DDownCastRewritePattern,
1521-
ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering,
1522-
OuterProductOpLowering>(context);
1628+
ShapeCastOp2DUpCastRewritePattern, BroadcastOpLowering,
1629+
TransposeOpLowering, OuterProductOpLowering>(context);
15231630
patterns.insert<ContractionOpLowering>(parameters, context);
15241631
}

0 commit comments

Comments
 (0)