Skip to content

Commit 0145a26

Browse files
committed
[MLIR] Add explicit initial values for loop.parallel op.
Differential Revision: https://reviews.llvm.org/D75206
1 parent 63b2ff0 commit 0145a26

File tree

4 files changed

+143
-85
lines changed

4 files changed

+143
-85
lines changed

mlir/include/mlir/Dialect/LoopOps/LoopOps.td

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def ForOp : Loop_Op<"for",
5252

5353
The body region must contain exactly one block that terminates with
5454
"loop.yield". Calling ForOp::build will create such a region and insert
55-
the terminator implicitly if none is defined, so will the parsing even
55+
the terminator implicitly if none is defined, so will the parsing even
5656
in cases when it is absent from the custom format. For example:
5757

5858
```mlir
@@ -62,17 +62,17 @@ def ForOp : Loop_Op<"for",
6262
```
6363

6464
"loop.for" can also operate on loop-carried variables and returns the final values
65-
after loop termination. The initial values of the variables are passed as additional SSA
65+
after loop termination. The initial values of the variables are passed as additional SSA
6666
operands to the "loop.for" following the 3 loop control SSA values mentioned above
67-
(lower bound, upper bound and step). The operation region has equivalent arguments
67+
(lower bound, upper bound and step). The operation region has equivalent arguments
6868
for each variable representing the value of the variable at the current iteration.
69-
70-
The region must terminate with a "loop.yield" that passes all the current iteration
69+
70+
The region must terminate with a "loop.yield" that passes all the current iteration
7171
variables to the next iteration, or to the "loop.for" result, if at the last iteration.
72-
"loop.for" results hold the final values after the last iteration.
73-
72+
"loop.for" results hold the final values after the last iteration.
73+
7474
For example, to sum-reduce a memref:
75-
75+
7676
```mlir
7777
func @reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %step: index) -> (f32) {
7878
// Initial sum set to 0.
@@ -85,14 +85,14 @@ def ForOp : Loop_Op<"for",
8585
loop.yield %sum_next : f32
8686
}
8787
return %sum : f32
88-
}
88+
}
8989
```
9090

91-
If the "loop.for" defines any values, a yield must be explicitly present.
92-
The number and types of the "loop.for" results must match the initial values
91+
If the "loop.for" defines any values, a yield must be explicitly present.
92+
The number and types of the "loop.for" results must match the initial values
9393
in the "iter_args" binding and the yield operands.
94-
95-
Another example with a nested "loop.if" (see "loop.if" for details)
94+
95+
Another example with a nested "loop.if" (see "loop.if" for details)
9696
to perform conditional reduction:
9797

9898
```mlir
@@ -118,7 +118,7 @@ def ForOp : Loop_Op<"for",
118118
Index:$upperBound,
119119
Index:$step,
120120
Variadic<AnyType>:$initArgs);
121-
let results = (outs Variadic<AnyType>:$results);
121+
let results = (outs Variadic<AnyType>:$results);
122122
let regions = (region SizedRegion<1>:$region);
123123

124124
let skipDefaultBuilders = 1;
@@ -143,15 +143,15 @@ def ForOp : Loop_Op<"for",
143143
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
144144
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
145145
void setStep(Value step) { getOperation()->setOperand(2, step); }
146-
146+
147147
/// Number of region arguments for loop-carried values
148148
unsigned getNumRegionIterArgs() {
149-
return getBody()->getNumArguments() - 1;
149+
return getBody()->getNumArguments() - 1;
150150
}
151151
/// Number of operands controlling the loop: lb, ub, step
152152
unsigned getNumControlOperands() { return 3; }
153153
/// Does the operation hold operands for loop-carried values
154-
bool hasIterOperands() {
154+
bool hasIterOperands() {
155155
return getOperation()->getNumOperands() > getNumControlOperands();
156156
}
157157
/// Get Number of loop-carried values
@@ -178,16 +178,16 @@ def IfOp : Loop_Op<"if",
178178
```
179179

180180
"loop.if" may also return results that are defined in its regions. The values
181-
defined are determined by which execution path is taken.
181+
defined are determined by which execution path is taken.
182182
For example:
183183
```mlir
184184
%x, %y = loop.if %b -> (f32, f32) {
185185
%x_true = ...
186186
%y_true = ...
187187
loop.yield %x_true, %y_true : f32, f32
188188
} else {
189-
%x_false = ...
190-
%y_false = ...
189+
%x_false = ...
190+
%y_false = ...
191191
loop.yield %x_false, %y_false : f32, f32
192192
}
193193
```
@@ -196,7 +196,7 @@ def IfOp : Loop_Op<"if",
196196
defines no values, the "loop.yield" can be left out, and will be
197197
inserted implicitly. Otherwise, it must be explicit.
198198
Also, if "loop.if" defines one or more values, the 'else' block cannot
199-
be omitted.
199+
be omitted.
200200

201201
For example:
202202
```mlir
@@ -230,18 +230,20 @@ def IfOp : Loop_Op<"if",
230230
}
231231

232232
def ParallelOp : Loop_Op<"parallel",
233-
[SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">]> {
233+
[AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> {
234234
let summary = "parallel for operation";
235235
let description = [{
236-
The "loop.parallel" operation represents a loop nest taking 3 groups of SSA
237-
values as operands that represent the lower bounds, upper bounds and steps,
238-
respectively. The operation defines a variadic number of SSA values for its
239-
induction variables. It has one region capturing the loop body. The
240-
induction variables are represented as an argument of this region. These SSA
241-
values always have type index, which is the size of the machine word. The
242-
steps are values of type index, required to be positive.
243-
The lower and upper bounds specify a half-open range: the range includes the
244-
lower bound but does not include the upper bound.
236+
The "loop.parallel" operation represents a loop nest taking 4 groups of SSA
237+
values as operands that represent the lower bounds, upper bounds, steps and
238+
initial values, respectively. The operation defines a variadic number of
239+
SSA values for its induction variables. It has one region capturing the
240+
loop body. The induction variables are represented as arguments of this
241+
region. These SSA values always have type index, which is the size of the
242+
machine word. The steps are values of type index, required to be positive.
243+
The lower and upper bounds specify a half-open range: the range includes
244+
the lower bound but does not include the upper bound. The initial values
245+
have the same types as results of "loop.parallel". If there are no results,
246+
the keyword `init` can be omitted.
245247

246248
Semantically we require that the iteration space can be iterated in any
247249
order, and the loop body can be executed in parallel. If there are data
@@ -250,10 +252,11 @@ def ParallelOp : Loop_Op<"parallel",
250252
The parallel loop operation supports reduction of values produced by
251253
individual iterations into a single result. This is modeled using the
252254
loop.reduce operation (see loop.reduce for details). Each result of a
253-
loop.parallel operation is associated with a reduce operation that is an
254-
immediate child. Reduces are matched to result values in order of their
255-
appearance in the body. Consequently, we require that the body region has
256-
the same number of results as it has reduce operations.
255+
loop.parallel operation is associated with an initial value operand and a
256+
reduce operation that is an immediate child. Reductions are matched to result
257+
and initial values in order of their appearance in the body. Consequently,
258+
we require that the body region has the same number of results and initial
259+
values as it has reduce operations.
257260

258261
The body region must contain exactly one block that terminates with
259262
"loop.yield" without operands. Parsing ParallelOp will create such a region
@@ -273,18 +276,16 @@ def ParallelOp : Loop_Op<"parallel",
273276

274277
let arguments = (ins Variadic<Index>:$lowerBound,
275278
Variadic<Index>:$upperBound,
276-
Variadic<Index>:$step);
279+
Variadic<Index>:$step,
280+
Variadic<AnyType>:$initVals);
277281
let results = (outs Variadic<AnyType>:$results);
278282
let regions = (region SizedRegion<1>:$region);
279283

280284
let skipDefaultBuilders = 1;
281285
let builders = [
282286
OpBuilder<"Builder *builder, OperationState &result, "
283287
"ValueRange lowerBounds, ValueRange upperBounds, "
284-
"ValueRange steps">,
285-
OpBuilder<"Builder *builder, OperationState &result, "
286-
"ValueRange lowerBounds, ValueRange upperBounds, "
287-
"ValueRange steps, ArrayRef<Type> resultTypes">
288+
"ValueRange steps, ValueRange initVals = {}">,
288289
];
289290

290291
let extraClassDeclaration = [{
@@ -293,9 +294,10 @@ def ParallelOp : Loop_Op<"parallel",
293294
return getBody()->getNumArguments();
294295
}
295296
iterator_range<Block::args_iterator> getInductionVars() {
296-
return {getBody()->args_begin(), getBody()->args_end()};
297+
return getBody()->getArguments();
297298
}
298299
unsigned getNumLoops() { return step().size(); }
300+
unsigned getNumReductions() { return initVals().size(); }
299301
}];
300302
}
301303

@@ -369,13 +371,13 @@ def YieldOp : Loop_Op<"yield", [Terminator]> {
369371
let description = [{
370372
"loop.yield" yields an SSA value from a loop dialect op region and
371373
terminates the regions. The semantics of how the values are yielded
372-
is defined by the parent operation.
374+
is defined by the parent operation.
373375
If "loop.yield" has any operands, the operands must match the parent
374-
operation's results.
375-
If the parent operation defines no values, then the "loop.yield" may be
376+
operation's results.
377+
If the parent operation defines no values, then the "loop.yield" may be
376378
left out in the custom syntax and the builders will insert one implicitly.
377379
Otherwise, it has to be present in the syntax to indicate which values
378-
are yielded.
380+
are yielded.
379381
}];
380382

381383
let arguments = (ins Variadic<AnyType>:$results);

mlir/lib/Dialect/LoopOps/LoopOps.cpp

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -304,21 +304,23 @@ static void print(OpAsmPrinter &p, IfOp op) {
304304
//===----------------------------------------------------------------------===//
305305

306306
void ParallelOp::build(Builder *builder, OperationState &result, ValueRange lbs,
307-
ValueRange ubs, ValueRange steps) {
307+
ValueRange ubs, ValueRange steps, ValueRange initVals) {
308308
result.addOperands(lbs);
309309
result.addOperands(ubs);
310310
result.addOperands(steps);
311+
result.addOperands(initVals);
312+
result.addAttribute(
313+
ParallelOp::getOperandSegmentSizeAttr(),
314+
builder->getI32VectorAttr({static_cast<int32_t>(lbs.size()),
315+
static_cast<int32_t>(ubs.size()),
316+
static_cast<int32_t>(steps.size()),
317+
static_cast<int32_t>(initVals.size())}));
311318
Region *bodyRegion = result.addRegion();
312319
ParallelOp::ensureTerminator(*bodyRegion, *builder, result.___location);
313320
for (size_t i = 0, e = steps.size(); i < e; ++i)
314321
bodyRegion->front().addArgument(builder->getIndexType());
315-
}
316-
317-
void ParallelOp::build(Builder *builder, OperationState &result, ValueRange lbs,
318-
ValueRange ubs, ValueRange steps,
319-
ArrayRef<Type> resultTypes) {
320-
result.addTypes(resultTypes);
321-
build(builder, result, lbs, ubs, steps);
322+
for (Value init : initVals)
323+
result.addTypes(init.getType());
322324
}
323325

324326
static LogicalResult verify(ParallelOp op) {
@@ -340,19 +342,28 @@ static LogicalResult verify(ParallelOp op) {
340342
// number of tuple elements in step.
341343
Block *body = op.getBody();
342344
if (body->getNumArguments() != stepValues.size())
343-
return op.emitOpError(
344-
"expects the same number of induction variables as bound and step "
345-
"values");
345+
return op.emitOpError()
346+
<< "expects the same number of induction variables: "
347+
<< body->getNumArguments()
348+
<< " as bound and step values: " << stepValues.size();
346349
for (auto arg : body->getArguments())
347350
if (!arg.getType().isIndex())
348351
return op.emitOpError(
349352
"expects arguments for the induction variable to be of index type");
350353

351354
// Check that the number of results is the same as the number of ReduceOps.
352355
SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
353-
if (op.results().size() != reductions.size())
354-
return op.emitOpError(
355-
"expects number of results to be the same as number of reductions");
356+
auto resultsSize = op.results().size();
357+
auto reductionsSize = reductions.size();
358+
auto initValsSize = op.initVals().size();
359+
if (resultsSize != reductionsSize)
360+
return op.emitOpError()
361+
<< "expects number of results: " << resultsSize
362+
<< " to be the same as number of reductions: " << reductionsSize;
363+
if (resultsSize != initValsSize)
364+
return op.emitOpError()
365+
<< "expects number of results: " << resultsSize
366+
<< " to be the same as number of initial values: " << initValsSize;
356367

357368
// Check that the types of the results and reductions are the same.
358369
for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
@@ -361,8 +372,8 @@ static LogicalResult verify(ParallelOp op) {
361372
auto reduceType = reduceOp.operand().getType();
362373
if (resultType != reduceType)
363374
return reduceOp.emitOpError()
364-
<< "expects type of reduce to be the same as result type: "
365-
<< resultType;
375+
<< "expects type of reduce: " << reduceType
376+
<< " to be the same as result type: " << resultType;
366377
}
367378
return success();
368379
}
@@ -399,17 +410,35 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
399410
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
400411
return failure();
401412

413+
// Parse the optional `init` operand list (initial values for reductions).
414+
SmallVector<OpAsmParser::OperandType, 4> initVals;
415+
if (succeeded(parser.parseOptionalKeyword("init"))) {
416+
if (parser.parseOperandList(initVals, -1, OpAsmParser::Delimiter::Paren))
417+
return failure();
418+
}
419+
402420
// Now parse the body.
403421
Region *body = result.addRegion();
404422
SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
405423
if (parser.parseRegion(*body, ivs, types))
406424
return failure();
407425

426+
// Set `operand_segment_sizes` attribute.
427+
result.addAttribute(
428+
ParallelOp::getOperandSegmentSizeAttr(),
429+
builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
430+
static_cast<int32_t>(upper.size()),
431+
static_cast<int32_t>(steps.size()),
432+
static_cast<int32_t>(initVals.size())}));
433+
408434
// Parse attributes and optional results (in case there is a reduce).
409435
if (parser.parseOptionalAttrDict(result.attributes) ||
410436
parser.parseOptionalColonTypeList(result.types))
411437
return failure();
412438

439+
if (!initVals.empty())
440+
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
441+
result.operands);
413442
// Add a terminator if none was parsed.
414443
ForOp::ensureTerminator(*body, builder, result.___location);
415444

@@ -420,8 +449,11 @@ static void print(OpAsmPrinter &p, ParallelOp op) {
420449
p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
421450
<< op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
422451
<< ")";
452+
if (!op.initVals().empty())
453+
p << " init (" << op.initVals() << ")";
423454
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
424-
p.printOptionalAttrDict(op.getAttrs());
455+
p.printOptionalAttrDict(
456+
op.getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
425457
if (!op.results().empty())
426458
p << " : " << op.getResultTypes();
427459
}

0 commit comments

Comments
 (0)