Skip to content

Commit 119a230

Browse files
committed
[mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack
This patch introduces a new helper, `getCollapsedVecType`, and updates `vectorizeAsTensorUnpackOp` to use it. The motivation stems from improving how `vector.shape_cast` operations are generated when vectorizing `linalg.unpack`. Previously, the vectorizer relied on `tensor::CollapseShapeOp::inferCollapsedType` to compute the collapsed vector type. This approach is suboptimal because: * `inferCollapsedType` lacks awareness of scalable vector flags. * Linalg vectorization should not depend on Tensor dialect utilities. Instead of relocating `inferCollapsedType`, we introduce `getCollapsedVecType` — a lightweight, specialized hook that: * Assumes no dynamic sizes. * Handles scalable flags alongside shape dimensions. This change also reduces temporary variables in `vectorizeAsTensorUnpackOp` and paves the way for a cleaner update in #149293.
1 parent b40e535 commit 119a230

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,46 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18311831
return success();
18321832
}
18331833

1834+
/// Given the re-associations, "collapses" the input Vector type
1835+
///
1836+
/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
1837+
/// differences:
1838+
/// * We can safely assume that there are no dynamic sizes.
1839+
/// * Scalable flags are updated alongside regular dims.
1840+
///
1841+
/// When collapsing scalable flags, conservatively avoids cases with two
1842+
/// scalable dims. We could re-visit this in the future.
1843+
static VectorType getCollapsedVecType(VectorType type,
1844+
ArrayRef<AffineMap> reassociation) {
1845+
assert(type.getNumScalableDims() < 2 &&
1846+
"Collapsing more than 1 scalable dim is not supported ATM");
1847+
1848+
// Use the fact that reassociation is valid to simplify the logic: only use
1849+
// each map's rank.
1850+
assert(isReassociationValid(reassociation) && "invalid reassociation");
1851+
1852+
auto shape = type.getShape();
1853+
auto scalableFlags = type.getScalableDims();
1854+
SmallVector<int64_t> newShape;
1855+
SmallVector<bool> newScalableFlags;
1856+
1857+
unsigned currentDim = 0;
1858+
for (AffineMap m : reassociation) {
1859+
unsigned dim = m.getNumResults();
1860+
int64_t size = 1;
1861+
bool flag = false;
1862+
for (unsigned d = 0; d < dim; ++d) {
1863+
size *= shape[currentDim + d];
1864+
flag |= scalableFlags[currentDim + d];
1865+
}
1866+
newShape.push_back(size);
1867+
newScalableFlags.push_back(flag);
1868+
currentDim += dim;
1869+
}
1870+
1871+
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
1872+
}
1873+
18341874
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
18351875
/// Vector::TransferReadOp - Reads a vector from the source tensor
18361876
/// vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1968,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19281968
PackingMetadata packMetadata;
19291969
SmallVector<int64_t> lastDimToInsertPosPerm =
19301970
getUnPackInverseSrcPerm(unpackOp, packMetadata);
1931-
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1932-
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
1933-
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1934-
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
1935-
RankedTensorType stripMineTensorType =
1936-
RankedTensorType::get(stripMineShape, stripMineElemType);
19371971
// Transpose the appropriate rows to match output.
19381972
vector::TransposeOp transposeOp = vector::TransposeOp::create(
19391973
rewriter, loc, readResult, lastDimToInsertPosPerm);
19401974

19411975
// Collapse the vector to the size required by result.
1942-
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
1943-
stripMineTensorType, packMetadata.reassociations);
1944-
mlir::VectorType vecCollapsedType =
1945-
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1976+
VectorType collapsedVecType = getCollapsedVecType(
1977+
transposeOp.getType(),
1978+
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1979+
rewriter.getContext(), packMetadata.reassociations)));
19461980
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
1947-
rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
1981+
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
19481982

19491983
Operation *write = createWriteOrMaskedWrite(
19501984
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),

0 commit comments

Comments
 (0)