@@ -126,155 +126,6 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
126
126
127
127
namespace {
128
128
129
- class VectorBroadcastOpConversion : public ConvertToLLVMPattern {
130
- public:
131
- explicit VectorBroadcastOpConversion (MLIRContext *context,
132
- LLVMTypeConverter &typeConverter)
133
- : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context,
134
- typeConverter) {}
135
-
136
- LogicalResult
137
- matchAndRewrite (Operation *op, ArrayRef<Value> operands,
138
- ConversionPatternRewriter &rewriter) const override {
139
- auto broadcastOp = cast<vector::BroadcastOp>(op);
140
- VectorType dstVectorType = broadcastOp.getVectorType ();
141
- if (typeConverter.convertType (dstVectorType) == nullptr )
142
- return failure ();
143
- // Rewrite when the full vector type can be lowered (which
144
- // implies all 'reduced' types can be lowered too).
145
- auto adaptor = vector::BroadcastOpOperandAdaptor (operands);
146
- VectorType srcVectorType =
147
- broadcastOp.getSourceType ().dyn_cast <VectorType>();
148
- rewriter.replaceOp (
149
- op, expandRanks (adaptor.source (), // source value to be expanded
150
- op->getLoc (), // ___location of original broadcast
151
- srcVectorType, dstVectorType, rewriter));
152
- return success ();
153
- }
154
-
155
- private:
156
- // Expands the given source value over all the ranks, as defined
157
- // by the source and destination type (a null source type denotes
158
- // expansion from a scalar value into a vector).
159
- //
160
- // TODO(ajcbik): consider replacing this one-pattern lowering
161
- // with a two-pattern lowering using other vector
162
- // ops once all insert/extract/shuffle operations
163
- // are available with lowering implementation.
164
- //
165
- Value expandRanks (Value value, Location loc, VectorType srcVectorType,
166
- VectorType dstVectorType,
167
- ConversionPatternRewriter &rewriter) const {
168
- assert ((dstVectorType != nullptr ) && " invalid result type in broadcast" );
169
- // Determine rank of source and destination.
170
- int64_t srcRank = srcVectorType ? srcVectorType.getRank () : 0 ;
171
- int64_t dstRank = dstVectorType.getRank ();
172
- int64_t curDim = dstVectorType.getDimSize (0 );
173
- if (srcRank < dstRank)
174
- // Duplicate this rank.
175
- return duplicateOneRank (value, loc, srcVectorType, dstVectorType, dstRank,
176
- curDim, rewriter);
177
- // If all trailing dimensions are the same, the broadcast consists of
178
- // simply passing through the source value and we are done. Otherwise,
179
- // any non-matching dimension forces a stretch along this rank.
180
- assert ((srcVectorType != nullptr ) && (srcRank > 0 ) &&
181
- (srcRank == dstRank) && " invalid rank in broadcast" );
182
- for (int64_t r = 0 ; r < dstRank; r++) {
183
- if (srcVectorType.getDimSize (r) != dstVectorType.getDimSize (r)) {
184
- return stretchOneRank (value, loc, srcVectorType, dstVectorType, dstRank,
185
- curDim, rewriter);
186
- }
187
- }
188
- return value;
189
- }
190
-
191
- // Picks the best way to duplicate a single rank. For the 1-D case, a
192
- // single insert-elt/shuffle is the most efficient expansion. For higher
193
- // dimensions, however, we need dim x insert-values on a new broadcast
194
- // with one less leading dimension, which will be lowered "recursively"
195
- // to matching LLVM IR.
196
- // For example:
197
- // v = broadcast s : f32 to vector<4x2xf32>
198
- // becomes:
199
- // x = broadcast s : f32 to vector<2xf32>
200
- // v = [x,x,x,x]
201
- // becomes:
202
- // x = [s,s]
203
- // v = [x,x,x,x]
204
- Value duplicateOneRank (Value value, Location loc, VectorType srcVectorType,
205
- VectorType dstVectorType, int64_t rank, int64_t dim,
206
- ConversionPatternRewriter &rewriter) const {
207
- Type llvmType = typeConverter.convertType (dstVectorType);
208
- assert ((llvmType != nullptr ) && " unlowerable vector type" );
209
- if (rank == 1 ) {
210
- Value undef = rewriter.create <LLVM::UndefOp>(loc, llvmType);
211
- Value expand = insertOne (rewriter, typeConverter, loc, undef, value,
212
- llvmType, rank, 0 );
213
- SmallVector<int32_t , 4 > zeroValues (dim, 0 );
214
- return rewriter.create <LLVM::ShuffleVectorOp>(
215
- loc, expand, undef, rewriter.getI32ArrayAttr (zeroValues));
216
- }
217
- Value expand = expandRanks (value, loc, srcVectorType,
218
- reducedVectorTypeFront (dstVectorType), rewriter);
219
- Value result = rewriter.create <LLVM::UndefOp>(loc, llvmType);
220
- for (int64_t d = 0 ; d < dim; ++d) {
221
- result = insertOne (rewriter, typeConverter, loc, result, expand, llvmType,
222
- rank, d);
223
- }
224
- return result;
225
- }
226
-
227
- // Picks the best way to stretch a single rank. For the 1-D case, a
228
- // single insert-elt/shuffle is the most efficient expansion when at
229
- // a stretch. Otherwise, every dimension needs to be expanded
230
- // individually and individually inserted in the resulting vector.
231
- // For example:
232
- // v = broadcast w : vector<4x1x2xf32> to vector<4x2x2xf32>
233
- // becomes:
234
- // a = broadcast w[0] : vector<1x2xf32> to vector<2x2xf32>
235
- // b = broadcast w[1] : vector<1x2xf32> to vector<2x2xf32>
236
- // c = broadcast w[2] : vector<1x2xf32> to vector<2x2xf32>
237
- // d = broadcast w[3] : vector<1x2xf32> to vector<2x2xf32>
238
- // v = [a,b,c,d]
239
- // becomes:
240
- // x = broadcast w[0][0] : vector<2xf32> to vector <2x2xf32>
241
- // y = broadcast w[1][0] : vector<2xf32> to vector <2x2xf32>
242
- // a = [x, y]
243
- // etc.
244
- Value stretchOneRank (Value value, Location loc, VectorType srcVectorType,
245
- VectorType dstVectorType, int64_t rank, int64_t dim,
246
- ConversionPatternRewriter &rewriter) const {
247
- Type llvmType = typeConverter.convertType (dstVectorType);
248
- assert ((llvmType != nullptr ) && " unlowerable vector type" );
249
- Value result = rewriter.create <LLVM::UndefOp>(loc, llvmType);
250
- bool atStretch = dim != srcVectorType.getDimSize (0 );
251
- if (rank == 1 ) {
252
- assert (atStretch);
253
- Type redLlvmType =
254
- typeConverter.convertType (dstVectorType.getElementType ());
255
- Value one =
256
- extractOne (rewriter, typeConverter, loc, value, redLlvmType, rank, 0 );
257
- Value expand = insertOne (rewriter, typeConverter, loc, result, one,
258
- llvmType, rank, 0 );
259
- SmallVector<int32_t , 4 > zeroValues (dim, 0 );
260
- return rewriter.create <LLVM::ShuffleVectorOp>(
261
- loc, expand, result, rewriter.getI32ArrayAttr (zeroValues));
262
- }
263
- VectorType redSrcType = reducedVectorTypeFront (srcVectorType);
264
- VectorType redDstType = reducedVectorTypeFront (dstVectorType);
265
- Type redLlvmType = typeConverter.convertType (redSrcType);
266
- for (int64_t d = 0 ; d < dim; ++d) {
267
- int64_t pos = atStretch ? 0 : d;
268
- Value one = extractOne (rewriter, typeConverter, loc, value, redLlvmType,
269
- rank, pos);
270
- Value expand = expandRanks (one, loc, redSrcType, redDstType, rewriter);
271
- result = insertOne (rewriter, typeConverter, loc, result, expand, llvmType,
272
- rank, d);
273
- }
274
- return result;
275
- }
276
- };
277
-
278
129
// / Conversion pattern for a vector.matrix_multiply.
279
130
// / This is lowered directly to the proper llvm.intr.matrix.multiply.
280
131
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
@@ -1209,8 +1060,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
1209
1060
VectorInsertStridedSliceOpSameRankRewritePattern,
1210
1061
VectorStridedSliceOpConversion>(ctx);
1211
1062
patterns
1212
- .insert <VectorBroadcastOpConversion,
1213
- VectorReductionOpConversion,
1063
+ .insert <VectorReductionOpConversion,
1214
1064
VectorShuffleOpConversion,
1215
1065
VectorExtractElementOpConversion,
1216
1066
VectorExtractOpConversion,
0 commit comments