Skip to content

Commit d1054e8

Browse files
authored
[mlir][NFC] Use range-based overload of llvm::sort (#150934)
Replace explicit begin/end iterator pairs with the range-based overload of `llvm::sort`, which simplifies the code and improves readability.
1 parent d64240b commit d1054e8

File tree

2 files changed

+32
-33
lines changed

2 files changed

+32
-33
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
476476
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
477477
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
478478
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
479-
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
480-
llvm::sort(dimensions.m.begin(), dimensions.m.end());
481-
llvm::sort(dimensions.n.begin(), dimensions.n.end());
482-
llvm::sort(dimensions.k.begin(), dimensions.k.end());
479+
llvm::sort(dimensions.batch);
480+
llvm::sort(dimensions.m);
481+
llvm::sort(dimensions.n);
482+
llvm::sort(dimensions.k);
483483
return dimensions;
484484
}
485485

@@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
797797
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
798798
/*strides=*/SmallVector<int64_t, 2>{},
799799
/*dilations=*/SmallVector<int64_t, 2>{}};
800-
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
801-
llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
802-
llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
803-
llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
804-
llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
805-
llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
800+
llvm::sort(dimensions.batch);
801+
llvm::sort(dimensions.outputImage);
802+
llvm::sort(dimensions.outputChannel);
803+
llvm::sort(dimensions.filterLoop);
804+
llvm::sort(dimensions.inputChannel);
805+
llvm::sort(dimensions.depth);
806806

807807
// Use the op carried strides/dilations attribute if present.
808808
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
38773877
llvm::SmallVector<size_t> indices(indexAttr.size());
38783878
std::iota(indices.begin(), indices.end(), 0);
38793879

3880-
llvm::sort(indices.begin(), indices.end(),
3881-
[&](const size_t a, const size_t b) {
3882-
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3883-
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3884-
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3885-
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3886-
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3887-
3888-
if (aIndex == bIndex)
3889-
continue;
3890-
3891-
if (aIndex < bIndex)
3892-
return first;
3893-
3894-
if (aIndex > bIndex)
3895-
return !first;
3896-
}
3897-
3898-
// Iterated the up until the end of the smallest member and
3899-
// they were found to be equal up to that point, so select
3900-
// the member with the lowest index count, so the "parent"
3901-
return memberIndicesA.size() < memberIndicesB.size();
3902-
});
3880+
llvm::sort(indices, [&](const size_t a, const size_t b) {
3881+
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3882+
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3883+
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3884+
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3885+
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3886+
3887+
if (aIndex == bIndex)
3888+
continue;
3889+
3890+
if (aIndex < bIndex)
3891+
return first;
3892+
3893+
if (aIndex > bIndex)
3894+
return !first;
3895+
}
3896+
3897+
// Iterated the up until the end of the smallest member and
3898+
// they were found to be equal up to that point, so select
3899+
// the member with the lowest index count, so the "parent"
3900+
return memberIndicesA.size() < memberIndicesB.size();
3901+
});
39033902

39043903
return llvm::cast<omp::MapInfoOp>(
39053904
mapInfo.getMembers()[indices.front()].getDefiningOp());

0 commit comments

Comments
 (0)