Commit 77363fb

[mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack (#151503)
This patch introduces a new helper, `getCollapsedVecType`, and updates
`vectorizeAsTensorUnpackOp` to use it. The motivation stems from improving how
`vector.shape_cast` operations are generated when vectorizing `linalg.unpack`.

Previously, the vectorizer relied on `tensor::CollapseShapeOp::inferCollapsedType`
to compute the collapsed vector type. This approach is suboptimal because:

* `inferCollapsedType` lacks awareness of scalable vector flags.
* Linalg vectorization should not depend on Tensor dialect utilities.

Instead of relocating `inferCollapsedType`, we introduce `getCollapsedVecType`,
a lightweight, specialized hook that:

* Assumes no dynamic sizes.
* Handles scalable flags alongside shape dimensions.

This change also reduces temporary variables in `vectorizeAsTensorUnpackOp` and
paves the way for a cleaner update in #149293.
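The collapsing arithmetic itself is simple: within each reassociation group, dim
sizes are multiplied and the scalable flags are OR-ed. Below is a minimal,
standalone C++ sketch of that arithmetic (illustrative only, not the MLIR API:
`collapseShape`, `CollapsedShape`, and the group-size representation are
stand-ins for `getCollapsedVecType`, `VectorType`, and the reassociation maps).
It reproduces the example documented in the patch, where vector<4x16x[8]x16xf32>
collapsed over groups (d0, d1) and (d2, d3) yields vector<64x[128]xf32>.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Collapsed result: one size and one scalability flag per reassociation group.
struct CollapsedShape {
  std::vector<int64_t> shape;
  std::vector<bool> scalableFlags;
};

// Collapse `shape`/`scalable` over contiguous groups of dims, e.g. {2, 2}
// groups dims (0, 1) and (2, 3). Sizes within a group are multiplied; a
// collapsed dim is scalable if any of its member dims was scalable.
CollapsedShape collapseShape(const std::vector<int64_t> &shape,
                             const std::vector<bool> &scalable,
                             const std::vector<unsigned> &groupSizes) {
  CollapsedShape result;
  unsigned currentDim = 0;
  for (unsigned groupSize : groupSizes) {
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < groupSize; ++d) {
      size *= shape[currentDim + d];
      flag = flag || scalable[currentDim + d];
    }
    result.shape.push_back(size);
    result.scalableFlags.push_back(flag);
    currentDim += groupSize;
  }
  assert(currentDim == shape.size() && "groups must cover every dim");
  return result;
}

int main() {
  // vector<4x16x[8]x16xf32> collapsed over (d0, d1) and (d2, d3):
  // expected result is 64 x [128], i.e. vector<64x[128]xf32>.
  CollapsedShape c =
      collapseShape({4, 16, 8, 16}, {false, false, true, false}, {2, 2});
  for (size_t i = 0; i < c.shape.size(); ++i)
    std::printf(c.scalableFlags[i] ? "[%lld] " : "%lld ",
                (long long)c.shape[i]);
  std::printf("\n");
  return 0;
}

Note that the actual helper conservatively asserts that at most one dim of the
input vector type is scalable, so OR-ing the flags never needs to combine two
scalable extents.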
1 parent ebe6eba commit 77363fb

File tree

1 file changed: +52 -11 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

@@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
+/// Given the re-associations, "collapses" the input Vector type
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+///   * We can safely assume that there are no dynamic sizes.
+///   * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, conservatively avoids cases with two
+/// scalable dims. We could re-visit this in the future.
+///
+/// EXAMPLE:
+///   type = vector<4x16x[8]x16xf32>
+///   reassociation = [(d0, d1, d2, d3) -> (d0, d1),
+///                    (d0, d1, d2, d3) -> (d2, d3)]
+/// Result:
+///   vector<64x[128]xf32>
+static VectorType getCollapsedVecType(VectorType type,
+                                      ArrayRef<AffineMap> reassociation) {
+  assert(type.getNumScalableDims() < 2 &&
+         "Collapsing more than 1 scalable dim is not supported ATM");
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+  auto shape = type.getShape();
+  auto scalableFlags = type.getScalableDims();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableFlags;
+
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    bool flag = false;
+    for (unsigned d = 0; d < dim; ++d) {
+      size *= shape[currentDim + d];
+      flag |= scalableFlags[currentDim + d];
+    }
+    newShape.push_back(size);
+    newScalableFlags.push_back(flag);
+    currentDim += dim;
+  }
+
+  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
 ///   Vector::TransferReadOp - Reads a vector from the source tensor
 ///   vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1975,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::get(stripMineShape, stripMineElemType);
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
   // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  VectorType collapsedVecType = getCollapsedVecType(
+      transposeOp.getType(),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
-      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+      rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
