Skip to content

Commit 17c1921

Browse files
authored
[mlir][spirv] Add support for structs decorations (#149793)
An alternative implementation could use `ArrayRef` of `NamedAttribute`s or `NamedAttrList` to store structs decorations, as the deserializer uses `NamedAttribute`s for decorations. However, using a custom struct allows us to store the `spirv::Decoration`s directly rather than its name in a `StringRef`/`StringAttr`.
1 parent 330b40e commit 17c1921

File tree

10 files changed

+270
-78
lines changed

10 files changed

+270
-78
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,34 @@ class StructType
330330
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
331331
};
332332

333+
// Type for specifying the decoration(s) on the struct itself.
334+
struct StructDecorationInfo {
335+
Decoration decoration;
336+
Attribute decorationValue;
337+
338+
StructDecorationInfo(Decoration decoration, Attribute decorationValue)
339+
: decoration(decoration), decorationValue(decorationValue) {}
340+
341+
friend bool operator==(const StructDecorationInfo &lhs,
342+
const StructDecorationInfo &rhs) {
343+
return lhs.decoration == rhs.decoration &&
344+
lhs.decorationValue == rhs.decorationValue;
345+
}
346+
347+
friend bool operator<(const StructDecorationInfo &lhs,
348+
const StructDecorationInfo &rhs) {
349+
return llvm::to_underlying(lhs.decoration) <
350+
llvm::to_underlying(rhs.decoration);
351+
}
352+
353+
bool hasValue() const { return !isa<UnitAttr>(decorationValue); }
354+
};
355+
333356
/// Construct a literal StructType with at least one member.
334357
static StructType get(ArrayRef<Type> memberTypes,
335358
ArrayRef<OffsetInfo> offsetInfo = {},
336-
ArrayRef<MemberDecorationInfo> memberDecorations = {});
359+
ArrayRef<MemberDecorationInfo> memberDecorations = {},
360+
ArrayRef<StructDecorationInfo> structDecorations = {});
337361

338362
/// Construct an identified StructType. This creates a StructType whose body
339363
/// (member types, offset info, and decorations) is not set yet. A call to
@@ -367,6 +391,9 @@ class StructType
367391

368392
bool hasOffset() const;
369393

394+
/// Returns true if the struct has a specified decoration.
395+
bool hasDecoration(spirv::Decoration decoration) const;
396+
370397
uint64_t getMemberOffset(unsigned) const;
371398

372399
// Returns in `memberDecorations` the Decorations (apart from Offset)
@@ -380,12 +407,18 @@ class StructType
380407
unsigned i,
381408
SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const;
382409

410+
// Returns in `structDecorations` the Decorations associated with the
411+
// StructType.
412+
void getStructDecorations(SmallVectorImpl<StructType::StructDecorationInfo>
413+
&structDecorations) const;
414+
383415
/// Sets the contents of an incomplete identified StructType. This method must
384416
/// be called only for identified StructTypes and it must be called only once
385417
/// per instance. Otherwise, failure() is returned.
386418
LogicalResult
387419
trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo = {},
388-
ArrayRef<MemberDecorationInfo> memberDecorations = {});
420+
ArrayRef<MemberDecorationInfo> memberDecorations = {},
421+
ArrayRef<StructDecorationInfo> structDecorations = {});
389422

390423
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
391424
std::optional<StorageClass> storage = std::nullopt);
@@ -396,6 +429,9 @@ class StructType
396429
llvm::hash_code
397430
hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
398431

432+
llvm::hash_code
433+
hash_value(const StructType::StructDecorationInfo &structDecorationInfo);
434+
399435
// SPIR-V KHR cooperative matrix type
400436
class CooperativeMatrixType
401437
: public Type::TypeBase<CooperativeMatrixType, CompositeType,

mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
693693
// `!spirv.struct<` (id `,`)?
694694
// `(`
695695
// (spirv-type (`[` struct-member-decoration `]`)?)*
696-
// `)>`
696+
// `)`
697+
// (`,` struct-decoration)?
698+
// `>`
697699
static Type parseStructType(SPIRVDialect const &dialect,
698700
DialectAsmParser &parser) {
699701
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
767769
return Type();
768770
}
769771

770-
if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
772+
if (failed(parser.parseRParen()))
773+
return Type();
774+
775+
SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
776+
777+
auto parseStructDecoration = [&]() {
778+
std::optional<spirv::Decoration> decoration =
779+
parseAndVerify<spirv::Decoration>(dialect, parser);
780+
if (!decoration)
781+
return failure();
782+
783+
// Parse decoration value if it exists.
784+
if (succeeded(parser.parseOptionalEqual())) {
785+
Attribute decorationValue;
786+
if (failed(parser.parseAttribute(decorationValue)))
787+
return failure();
788+
789+
structDecorationInfo.emplace_back(decoration.value(), decorationValue);
790+
} else {
791+
structDecorationInfo.emplace_back(decoration.value(),
792+
UnitAttr::get(dialect.getContext()));
793+
}
794+
return success();
795+
};
796+
797+
while (succeeded(parser.parseOptionalComma()))
798+
if (failed(parseStructDecoration()))
799+
return Type();
800+
801+
if (failed(parser.parseGreater()))
771802
return Type();
772803

773804
if (!identifier.empty()) {
774805
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
775-
memberDecorationInfo)))
806+
memberDecorationInfo,
807+
structDecorationInfo)))
776808
return Type();
777809
return idStructTy;
778810
}
779811

780-
return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
812+
return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
813+
structDecorationInfo);
781814
}
782815

783816
// spirv-type ::= array-type
@@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
893926
};
894927
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
895928
printMember);
896-
os << ")>";
929+
os << ")";
930+
931+
SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
932+
type.getStructDecorations(decorations);
933+
if (!decorations.empty()) {
934+
os << ", ";
935+
auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
936+
os << stringifyDecoration(decoration.decoration);
937+
if (decoration.hasValue()) {
938+
os << "=";
939+
os.printAttributeWithoutType(decoration.decorationValue);
940+
}
941+
};
942+
llvm::interleaveComma(decorations, os, eachFn);
943+
}
944+
945+
os << ">";
897946
}
898947

899948
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {

0 commit comments

Comments
 (0)