Skip to content

Commit 149d4b5

Browse files
authored
[mlir][EmitC]Allow Fields to have initial values (#151437)
This will ensure that: - The `field` of a class can have an initial value - The `field` op is emitted correctly - The `getfield` op is emitted correctly
1 parent b36f05c commit 149d4b5

File tree

5 files changed

+119
-50
lines changed

5 files changed

+119
-50
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,13 +1659,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
16591659
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
16601660
// Example with no attribute:
16611661
emitc.field @fieldName0 : !emitc.array<1xf32>
1662+
// Example with an initial value:
1663+
emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
1664+
// Example with an initial value and attributes:
1665+
emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> {
1666+
emitc.opaque = "input_tensor"}
16621667
```
16631668
}];
16641669

16651670
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
1666-
OptionalAttr<AnyAttr>:$attrs);
1671+
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value);
16671672

1668-
let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
1673+
let assemblyFormat = [{
1674+
$sym_name
1675+
`:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value)
1676+
attr-dict
1677+
}];
16691678

16701679
let hasVerifier = 1;
16711680
}
@@ -1686,7 +1695,7 @@ def EmitC_GetFieldOp
16861695
}];
16871696

16881697
let arguments = (ins FlatSymbolRefAttr:$field_name);
1689-
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
1698+
let results = (outs EmitCType:$result);
16901699
let assemblyFormat = "$field_name `:` type($result) attr-dict";
16911700
}
16921701

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
13981398
//===----------------------------------------------------------------------===//
13991399
// FieldOp
14001400
//===----------------------------------------------------------------------===//
1401+
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
1402+
TypeAttr type,
1403+
Attribute initialValue) {
1404+
p << type;
1405+
if (initialValue) {
1406+
p << " = ";
1407+
p.printAttributeWithoutType(initialValue);
1408+
}
1409+
}
1410+
1411+
static Type getInitializerTypeForField(Type type) {
1412+
if (auto array = llvm::dyn_cast<ArrayType>(type))
1413+
return RankedTensorType::get(array.getShape(), array.getElementType());
1414+
return type;
1415+
}
1416+
1417+
static ParseResult
1418+
parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1419+
Attribute &initialValue) {
1420+
Type type;
1421+
if (parser.parseType(type))
1422+
return failure();
1423+
1424+
typeAttr = TypeAttr::get(type);
1425+
1426+
if (parser.parseOptionalEqual())
1427+
return success();
1428+
1429+
if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
1430+
return failure();
1431+
1432+
if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1433+
initialValue))
1434+
return parser.emitError(parser.getNameLoc())
1435+
<< "initial value should be a integer, float, elements or opaque "
1436+
"attribute";
1437+
return success();
1438+
}
1439+
14011440
LogicalResult FieldOp::verify() {
14021441
if (!isSupportedEmitCType(getType()))
14031442
return emitOpError("expected valid emitc type");
@@ -1410,9 +1449,6 @@ LogicalResult FieldOp::verify() {
14101449
if (!symName || symName.getValue().empty())
14111450
return emitOpError("field must have a non-empty symbol name");
14121451

1413-
if (!getAttrs())
1414-
return success();
1415-
14161452
return success();
14171453
}
14181454

mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
5858

5959
auto argAttrs = funcOp.getArgAttrs();
6060
for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
61-
StringAttr fieldName;
62-
Attribute argAttr = nullptr;
63-
64-
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
65-
if (argAttrs && idx < argAttrs->size())
66-
argAttr = (*argAttrs)[idx];
61+
StringAttr fieldName =
62+
rewriter.getStringAttr("fieldName" + std::to_string(idx));
6763

6864
TypeAttr typeAttr = TypeAttr::get(val.getType());
6965
fields.push_back({fieldName, typeAttr});
70-
emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
71-
argAttr);
66+
67+
FieldOp fieldop = rewriter.create<emitc::FieldOp>(
68+
funcOp->getLoc(), fieldName, typeAttr, nullptr);
69+
70+
if (argAttrs && idx < argAttrs->size()) {
71+
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
72+
}
7273
}
7374

7475
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,8 @@ struct CppEmitter {
333333
/// Determine whether expression \p op should be emitted in a deferred way.
334334
static bool hasDeferredEmission(Operation *op) {
335335
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
336-
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
336+
emitc::MemberOfPtrOp, emitc::SubscriptOp,
337+
emitc::GetFieldOp>(op);
337338
}
338339

339340
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
@@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
10491050

10501051
static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
10511052
raw_ostream &os = emitter.ostream();
1052-
if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
1053+
if (failed(emitter.emitVariableDeclaration(
1054+
fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
10531055
return failure();
1054-
os << " " << fieldOp.getSymName() << ";";
1055-
return success();
1056-
}
1057-
1058-
static LogicalResult printOperation(CppEmitter &emitter,
1059-
GetFieldOp getFieldOp) {
1060-
raw_indented_ostream &os = emitter.ostream();
1061-
1062-
Value result = getFieldOp.getResult();
1063-
if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
1064-
return failure();
1065-
os << " ";
1066-
if (failed(emitter.emitOperand(result)))
1067-
return failure();
1068-
os << " = ";
1056+
std::optional<Attribute> initialValue = fieldOp.getInitialValue();
1057+
if (initialValue) {
1058+
os << " = ";
1059+
if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
1060+
return failure();
1061+
}
10691062

1070-
os << getFieldOp.getFieldName().str();
1063+
os << ";";
10711064
return success();
10721065
}
10731066

@@ -1204,7 +1197,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12041197
os << ") {\n";
12051198
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
12061199
return failure();
1207-
os << "}\n";
1200+
os << "}";
12081201

12091202
return success();
12101203
}
@@ -1245,7 +1238,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12451238
os << ") {\n";
12461239
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
12471240
return failure();
1248-
os << "}\n";
1241+
os << "}";
12491242

12501243
return success();
12511244
}
@@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17001693
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
17011694
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
17021695
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
1703-
emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
1704-
emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
1705-
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1706-
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
1707-
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1708-
emitc::VerbatimOp>(
1696+
emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
1697+
emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1698+
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1699+
emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1700+
emitc::VariableOp, emitc::VerbatimOp>(
17091701

17101702
[&](auto op) { return printOperation(*this, op); })
17111703
// Func ops.
@@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17151707
cacheDeferredOpResult(op.getResult(), op.getName());
17161708
return success();
17171709
})
1710+
.Case<emitc::GetFieldOp>([&](auto op) {
1711+
cacheDeferredOpResult(op.getResult(), op.getFieldName());
1712+
return success();
1713+
})
17181714
.Case<emitc::LiteralOp>([&](auto op) {
17191715
cacheDeferredOpResult(op.getResult(), op.getValue());
17201716
return success();

mlir/test/mlir-translate/emitc_classops.mlir

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,12 @@ emitc.class @modelClass {
1414

1515
// CHECK-LABEL: class modelClass {
1616
// CHECK-NEXT: public:
17-
// CHECK-NEXT: float[1] fieldName0;
18-
// CHECK-NEXT: float[1] fieldName1;
17+
// CHECK-NEXT: float fieldName0[1];
18+
// CHECK-NEXT: float fieldName1[1];
1919
// CHECK-NEXT: void execute() {
2020
// CHECK-NEXT: size_t v1 = 0;
21-
// CHECK-NEXT: float[1] v2 = fieldName0;
22-
// CHECK-NEXT: float[1] v3 = fieldName1;
2321
// CHECK-NEXT: return;
2422
// CHECK-NEXT: }
25-
// CHECK-EMPTY:
2623
// CHECK-NEXT: };
2724

2825
emitc.class final @finalClass {
@@ -39,13 +36,43 @@ emitc.class final @finalClass {
3936

4037
// CHECK-LABEL: class finalClass final {
4138
// CHECK-NEXT: public:
42-
// CHECK-NEXT: float[1] fieldName0;
43-
// CHECK-NEXT: float[1] fieldName1;
39+
// CHECK-NEXT: float fieldName0[1];
40+
// CHECK-NEXT: float fieldName1[1];
4441
// CHECK-NEXT: void execute() {
4542
// CHECK-NEXT: size_t v1 = 0;
46-
// CHECK-NEXT: float[1] v2 = fieldName0;
47-
// CHECK-NEXT: float[1] v3 = fieldName1;
4843
// CHECK-NEXT: return;
4944
// CHECK-NEXT: }
50-
// CHECK-EMPTY:
5145
// CHECK-NEXT: };
46+
47+
emitc.class @mainClass {
48+
emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
49+
emitc.func @get_fieldName0() {
50+
%0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
51+
return
52+
}
53+
}
54+
55+
// CHECK-LABEL: class mainClass {
56+
// CHECK-NEXT: public:
57+
// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f};
58+
// CHECK-NEXT: void get_fieldName0() {
59+
// CHECK-NEXT: return;
60+
// CHECK-NEXT: }
61+
// CHECK-NEXT: };
62+
63+
emitc.class @reflectionClass {
64+
emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }">
65+
emitc.func @get_reflectionMap() {
66+
%0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>">
67+
return
68+
}
69+
}
70+
71+
// CHECK-LABEL: class reflectionClass {
72+
// CHECK-NEXT: public:
73+
// CHECK-NEXT: const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } };
74+
// CHECK-NEXT: void get_reflectionMap() {
75+
// CHECK-NEXT: return;
76+
// CHECK-NEXT: }
77+
// CHECK-NEXT: };
78+

0 commit comments

Comments
 (0)