-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][EmitC]Allow Fields to have initial values #151437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc Author: Jaden Angella (Jaddyen) ChangesThis will ensure that:
Full diff: https://github.com/llvm/llvm-project/pull/151437.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7fe2da8f7e044..f06a24c61454a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1659,13 +1659,23 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
// Example with no attribute:
emitc.field @fieldName0 : !emitc.array<1xf32>
+ // Example with an initial value:
+ emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
+ // Example with an initial value and attributes:
+ emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> {
+ emitc.opaque = "input_tensor"}
```
}];
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+ OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
OptionalAttr<AnyAttr>:$attrs);
- let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
+ let assemblyFormat = [{
+ $sym_name
+ `:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value)
+ attr-dict
+ }];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c0902293cbf9..0f5f054ed3a66 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
+ TypeAttr type,
+ Attribute initialValue) {
+ p << type;
+ if (initialValue) {
+ p << " = ";
+ p.printAttributeWithoutType(initialValue);
+ }
+}
+
+static Type getInitializerTypeForField(Type type) {
+ if (auto array = llvm::dyn_cast<ArrayType>(type))
+ return RankedTensorType::get(array.getShape(), array.getElementType());
+ return type;
+}
+
+static ParseResult
+parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &initialValue) {
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+
+ typeAttr = TypeAttr::get(type);
+
+ if (parser.parseOptionalEqual())
+ return success();
+
+ if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
+ return failure();
+
+ if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+ initialValue))
+ return parser.emitError(parser.getNameLoc())
+ << "initial value should be a integer, float, elements or opaque "
+ "attribute";
+ return success();
+}
+
LogicalResult FieldOp::verify() {
if (!isSupportedEmitCType(getType()))
return emitOpError("expected valid emitc type");
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fa05ad8063b99..79d9d7428e914 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -68,7 +68,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
- argAttr);
+ /*initial_value=*/nullptr, argAttr);
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index dcd2e11e83c6a..c4c9a597189a5 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -333,7 +333,8 @@ struct CppEmitter {
/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
- emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
+ emitc::MemberOfPtrOp, emitc::SubscriptOp,
+ emitc::GetFieldOp>(op);
}
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
@@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
raw_ostream &os = emitter.ostream();
- if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
+ if (failed(emitter.emitVariableDeclaration(
+ fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
return failure();
- os << " " << fieldOp.getSymName() << ";";
- return success();
-}
-
-static LogicalResult printOperation(CppEmitter &emitter,
- GetFieldOp getFieldOp) {
- raw_indented_ostream &os = emitter.ostream();
-
- Value result = getFieldOp.getResult();
- if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
- return failure();
- os << " ";
- if (failed(emitter.emitOperand(result)))
- return failure();
- os << " = ";
+ std::optional<Attribute> initialValue = fieldOp.getInitialValue();
+ if (initialValue) {
+ os << " = ";
+ if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
+ return failure();
+ }
- os << getFieldOp.getFieldName().str();
+ os << ";";
return success();
}
@@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
- emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
- emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
- emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
- emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
- emitc::VerbatimOp>(
+ emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
+ emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+ emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+ emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+ emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
@@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
cacheDeferredOpResult(op.getResult(), op.getName());
return success();
})
+ .Case<emitc::GetFieldOp>([&](auto op) {
+ cacheDeferredOpResult(op.getResult(), op.getFieldName());
+ return success();
+ })
.Case<emitc::LiteralOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 029fa78a3f528..2f6627f234fe7 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -19,9 +19,9 @@ module attributes { } {
// CHECK: module {
// CHECK-NEXT: emitc.class @modelClass {
-// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
-// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
-// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
+// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "another_feature"}}
+// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "some_feature"}}
+// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "output_0"}}
// CHECK-NEXT: emitc.func @execute() {
// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32>
// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>
diff --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir
index 4b7ddf4630d55..d5f18ac63dd36 100644
--- a/mlir/test/mlir-translate/emitc_classops.mlir
+++ b/mlir/test/mlir-translate/emitc_classops.mlir
@@ -14,12 +14,10 @@ emitc.class @modelClass {
// CHECK-LABEL: class modelClass {
// CHECK-NEXT: public:
-// CHECK-NEXT: float[1] fieldName0;
-// CHECK-NEXT: float[1] fieldName1;
+// CHECK-NEXT: float fieldName0[1];
+// CHECK-NEXT: float fieldName1[1];
// CHECK-NEXT: void execute() {
// CHECK-NEXT: size_t v1 = 0;
-// CHECK-NEXT: float[1] v2 = fieldName0;
-// CHECK-NEXT: float[1] v3 = fieldName1;
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-EMPTY:
@@ -39,12 +37,27 @@ emitc.class final @finalClass {
// CHECK-LABEL: class finalClass final {
// CHECK-NEXT: public:
-// CHECK-NEXT: float[1] fieldName0;
-// CHECK-NEXT: float[1] fieldName1;
+// CHECK-NEXT: float fieldName0[1];
+// CHECK-NEXT: float fieldName1[1];
// CHECK-NEXT: void execute() {
// CHECK-NEXT: size_t v1 = 0;
-// CHECK-NEXT: float[1] v2 = fieldName0;
-// CHECK-NEXT: float[1] v3 = fieldName1;
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-EMPTY:
+// CHECK-NEXT: };
+
+emitc.class @mainClass {
+ emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
+ emitc.func @get_fieldName0() {
+ %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
+ return
+ }
+}
+
+// CHECK-LABEL: class mainClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f};
+// CHECK-NEXT: void get_fieldName0() {
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-EMPTY:
|
// CHECK-NEXT: }; | ||
|
||
emitc.class @mainClass { | ||
emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a test also where the type is e.g. a std::map, with an initializer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, addressed in changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, just a comment.
|
||
static Type getInitializerTypeForField(Type type) { | ||
if (auto array = llvm::dyn_cast<ArrayType>(type)) | ||
return RankedTensorType::get(array.getShape(), array.getElementType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment here? (It feels a bit weird to go from ArrayType to RankedTensorType, if this is convention followed elsewhere or just here, good to document)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is convention that is already documented since this is directly from how we have it in GlobalOp
. And we have this:https://mlir.llvm.org/docs/Dialects/EmitC/#pointertype:~:text=If%20tensors%20are%20used%2C%20C%2B%2B%20is%20generated
|
||
if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>( | ||
initialValue)) | ||
return parser.emitError(parser.getNameLoc()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the difference in output if emitError() is used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use emitError()
``` | ||
}]; | ||
|
||
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, | ||
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value, | ||
OptionalAttr<AnyAttr>:$attrs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having an explicit attrs attribute is different from attr-dict which is something all ops have. They are called discardable attributes (attr-dict), but they aren't discarded randomly, it mostly means "not semantically meaningful to the op". Here I think you are actually using it for a short lived ID between two passes you control (or maybe not even more?), so probably good to remove as it doesn't carry any semantics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, thanks for the pointer!
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} | ||
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} | ||
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} | ||
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "another_feature"}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And if you remove above, then this gets reset to what it was before (here it ended up that it is needed as the two things in your asm is different and this resulted in populating the one you wanted, but can also just reference the Operation*'s attributes for the name_hint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, thanks for the pointer!
// CHECK-NEXT: size_t v1 = 0; | ||
// CHECK-NEXT: float[1] v2 = fieldName0; | ||
// CHECK-NEXT: float[1] v3 = fieldName1; | ||
// CHECK-NEXT: return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These seem to go away, but I'm not sure what in this change causes that, given that you didn't change the input IR w/ an initial value. I can see that you're changing how the fields are printed, but that doesn't seem to be related to these statements. What part of this changes removed these? Should these assignments have been removed(I assume yes)? Do you need some calculation in execute()
to keep those alive, or show the usage of the fields (e.g. float x = v3[0] + v2[0]
)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They go away because they aren't initialized.
The change in this patch ensures that we only emit get_field
in the case that we initialized the field.
This is in line with how get_global
works.
Yes to the two last questions!
// CHECK-EMPTY: | ||
// CHECK-NEXT: }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to enforce the empty line? I'd assume you'd be fine if the output was fixed to not emit that.
// CHECK-EMPTY: | |
// CHECK-NEXT: }; | |
// CHECK: }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, thanks for the pointer!
This will ensure that:
field
of a class can have an initial valuefield
op is emitted correctlygetfield
op is emitted correctly