Skip to content

[mlir][EmitC]Allow Fields to have initial values #151437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 1, 2025

Conversation

Jaddyen
Copy link
Contributor

@Jaddyen Jaddyen commented Jul 31, 2025

This will ensure that:

  • The field of a class can have an initial value
  • The field op is emitted correctly
  • The getfield op is emitted correctly

@Jaddyen Jaddyen requested review from jpienaar, mtrofin and ilovepi July 31, 2025 02:34
@Jaddyen Jaddyen changed the title [mlir][EmitC] [mlir][EmitC]Allow Fields to have initial values Jul 31, 2025
@Jaddyen Jaddyen marked this pull request as ready for review July 31, 2025 02:39
@llvmbot
Copy link
Member

llvmbot commented Jul 31, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-emitc

Author: Jaden Angella (Jaddyen)

Changes

This will ensure that:

  • The field of a class can have an initial value
  • The field op is emitted correctly
  • The getfield op is emitted correctly

Full diff: https://github.com/llvm/llvm-project/pull/151437.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+11-1)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+39)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp (+1-1)
  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+20-24)
  • (modified) mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir (+3-3)
  • (modified) mlir/test/mlir-translate/emitc_classops.mlir (+21-8)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7fe2da8f7e044..f06a24c61454a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1659,13 +1659,23 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
     emitc.field @fieldName0 : !emitc.array<1xf32>  {emitc.opaque = "another_feature"}
     // Example with no attribute:
     emitc.field @fieldName0 : !emitc.array<1xf32>
+    // Example with an initial value:
+    emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
+    // Example with an initial value and attributes:
+    emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> {
+      emitc.opaque = "input_tensor"}
     ```
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+      OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
       OptionalAttr<AnyAttr>:$attrs);
 
-  let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
+  let assemblyFormat = [{
+       $sym_name
+       `:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value)
+       attr-dict
+  }];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c0902293cbf9..0f5f054ed3a66 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
 //===----------------------------------------------------------------------===//
 // FieldOp
 //===----------------------------------------------------------------------===//
+static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
+                                                 TypeAttr type,
+                                                 Attribute initialValue) {
+  p << type;
+  if (initialValue) {
+    p << " = ";
+    p.printAttributeWithoutType(initialValue);
+  }
+}
+
+static Type getInitializerTypeForField(Type type) {
+  if (auto array = llvm::dyn_cast<ArrayType>(type))
+    return RankedTensorType::get(array.getShape(), array.getElementType());
+  return type;
+}
+
+static ParseResult
+parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                     Attribute &initialValue) {
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  typeAttr = TypeAttr::get(type);
+
+  if (parser.parseOptionalEqual())
+    return success();
+
+  if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
+    return failure();
+
+  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+          initialValue))
+    return parser.emitError(parser.getNameLoc())
+           << "initial value should be a integer, float, elements or opaque "
+              "attribute";
+  return success();
+}
+
 LogicalResult FieldOp::verify() {
   if (!isSupportedEmitCType(getType()))
     return emitOpError("expected valid emitc type");
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fa05ad8063b99..79d9d7428e914 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -68,7 +68,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       TypeAttr typeAttr = TypeAttr::get(val.getType());
       fields.push_back({fieldName, typeAttr});
       emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
-                             argAttr);
+                             /*initial_value=*/nullptr, argAttr);
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index dcd2e11e83c6a..c4c9a597189a5 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -333,7 +333,8 @@ struct CppEmitter {
 /// Determine whether expression \p op should be emitted in a deferred way.
 static bool hasDeferredEmission(Operation *op) {
   return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
-                         emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
+                         emitc::MemberOfPtrOp, emitc::SubscriptOp,
+                         emitc::GetFieldOp>(op);
 }
 
 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
@@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
 
 static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
   raw_ostream &os = emitter.ostream();
-  if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
+  if (failed(emitter.emitVariableDeclaration(
+          fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
     return failure();
-  os << " " << fieldOp.getSymName() << ";";
-  return success();
-}
-
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    GetFieldOp getFieldOp) {
-  raw_indented_ostream &os = emitter.ostream();
-
-  Value result = getFieldOp.getResult();
-  if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
-    return failure();
-  os << " ";
-  if (failed(emitter.emitOperand(result)))
-    return failure();
-  os << " = ";
+  std::optional<Attribute> initialValue = fieldOp.getInitialValue();
+  if (initialValue) {
+    os << " = ";
+    if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
+      return failure();
+  }
 
-  os << getFieldOp.getFieldName().str();
+  os << ";";
   return success();
 }
 
@@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
                 emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
                 emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
-                emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
-                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
-                emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
-                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
-                emitc::VerbatimOp>(
+                emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
+                emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+                emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+                emitc::VariableOp, emitc::VerbatimOp>(
 
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
@@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
             cacheDeferredOpResult(op.getResult(), op.getName());
             return success();
           })
+          .Case<emitc::GetFieldOp>([&](auto op) {
+            cacheDeferredOpResult(op.getResult(), op.getFieldName());
+            return success();
+          })
           .Case<emitc::LiteralOp>([&](auto op) {
             cacheDeferredOpResult(op.getResult(), op.getValue());
             return success();
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 029fa78a3f528..2f6627f234fe7 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -19,9 +19,9 @@ module attributes { } {
 
 // CHECK: module {
 // CHECK-NEXT:   emitc.class @modelClass {
-// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
-// CHECK-NEXT:     emitc.field @fieldName1 : !emitc.array<1xf32>  {emitc.name_hint = "some_feature"}
-// CHECK-NEXT:     emitc.field @fieldName2 : !emitc.array<1xf32>  {emitc.name_hint = "output_0"}
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "another_feature"}}
+// CHECK-NEXT:     emitc.field @fieldName1 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "some_feature"}}
+// CHECK-NEXT:     emitc.field @fieldName2 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "output_0"}}
 // CHECK-NEXT:     emitc.func @execute() {
 // CHECK-NEXT:       get_field @fieldName0 : !emitc.array<1xf32>
 // CHECK-NEXT:       get_field @fieldName1 : !emitc.array<1xf32>
diff --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir
index 4b7ddf4630d55..d5f18ac63dd36 100644
--- a/mlir/test/mlir-translate/emitc_classops.mlir
+++ b/mlir/test/mlir-translate/emitc_classops.mlir
@@ -14,12 +14,10 @@ emitc.class @modelClass {
 
 // CHECK-LABEL: class modelClass {
 // CHECK-NEXT: public:
-// CHECK-NEXT:  float[1] fieldName0;
-// CHECK-NEXT:  float[1] fieldName1;
+// CHECK-NEXT:  float fieldName0[1];
+// CHECK-NEXT:  float fieldName1[1];
 // CHECK-NEXT:  void execute() {
 // CHECK-NEXT:    size_t v1 = 0;
-// CHECK-NEXT:    float[1] v2 = fieldName0;
-// CHECK-NEXT:    float[1] v3 = fieldName1;
 // CHECK-NEXT:    return;
 // CHECK-NEXT:  }
 // CHECK-EMPTY:
@@ -39,12 +37,27 @@ emitc.class final @finalClass {
 
 // CHECK-LABEL: class finalClass final {
 // CHECK-NEXT: public:
-// CHECK-NEXT:  float[1] fieldName0;
-// CHECK-NEXT:  float[1] fieldName1;
+// CHECK-NEXT:  float fieldName0[1];
+// CHECK-NEXT:  float fieldName1[1];
 // CHECK-NEXT:  void execute() {
 // CHECK-NEXT:    size_t v1 = 0;
-// CHECK-NEXT:    float[1] v2 = fieldName0;
-// CHECK-NEXT:    float[1] v3 = fieldName1;
+// CHECK-NEXT:    return;
+// CHECK-NEXT:  }
+// CHECK-EMPTY:
+// CHECK-NEXT: };
+
+emitc.class @mainClass {
+  emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
+  emitc.func @get_fieldName0() {
+    %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
+    return 
+  }
+}
+
+// CHECK-LABEL: class mainClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT:  float fieldName0[2] = {0.0e+00f, 0.0e+00f};
+// CHECK-NEXT:  void get_fieldName0() {
 // CHECK-NEXT:    return;
 // CHECK-NEXT:  }
 // CHECK-EMPTY:

// CHECK-NEXT: };

emitc.class @mainClass {
emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a test also where the type is e.g. a std::map, with an initializer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, addressed in changes.

Copy link
Member

@mtrofin mtrofin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, just a comment.


static Type getInitializerTypeForField(Type type) {
if (auto array = llvm::dyn_cast<ArrayType>(type))
return RankedTensorType::get(array.getShape(), array.getElementType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment here? (It feels a bit weird to go from ArrayType to RankedTensorType, if this is convention followed elsewhere or just here, good to document)

Copy link
Contributor Author

@Jaddyen Jaddyen Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is convention that is already documented since this is directly from how we have it in GlobalOp. And we have this:https://mlir.llvm.org/docs/Dialects/EmitC/#pointertype:~:text=If%20tensors%20are%20used%2C%20C%2B%2B%20is%20generated


if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
initialValue))
return parser.emitError(parser.getNameLoc())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference in output if emitError() is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use emitError()

```
}];

let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
OptionalAttr<AnyAttr>:$attrs);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having an explicit attrs attribute is different from attr-dict which is something all ops have. They are called discardable attributes (attr-dict), but they aren't discarded randomly, it mostly means "not semantically meaningful to the op". Here I think you are actually using it for a short lived ID between two passes you control (or maybe not even more?), so probably good to remove as it doesn't carry any semantics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the pointer!

// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {attrs = {emitc.name_hint = "another_feature"}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if you remove above, then this gets reset to what it was before (here it ended up that it is needed as the two things in your asm is different and this resulted in populating the one you wanted, but can also just reference the Operation*'s attributes for the name_hint.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the pointer!

Comment on lines 43 to 44
// CHECK-NEXT: size_t v1 = 0;
// CHECK-NEXT: float[1] v2 = fieldName0;
// CHECK-NEXT: float[1] v3 = fieldName1;
// CHECK-NEXT: return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to go away, but I'm not sure what in this change causes that, given that you didn't change the input IR w/ an initial value. I can see that you're changing how the fields are printed, but that doesn't seem to be related to these statements. What part of this changes removed these? Should these assignments have been removed(I assume yes)? Do you need some calculation in execute() to keep those alive, or show the usage of the fields (e.g. float x = v3[0] + v2[0])?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They go away because they aren't initialized.
The change in this patch ensures that we only emit get_field in the case that we initialized the field.
This is in line with how get_global works.
Yes to the two last questions!

Comment on lines 63 to 64
// CHECK-EMPTY:
// CHECK-NEXT: };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you want to enforce the empty line? I'd assume you'd be fine if the output was fixed to not emit that.

Suggested change
// CHECK-EMPTY:
// CHECK-NEXT: };
// CHECK: };

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the pointer!

@Jaddyen Jaddyen requested a review from ilovepi August 1, 2025 00:51
@Jaddyen Jaddyen merged commit 149d4b5 into llvm:main Aug 1, 2025
9 checks passed
@Jaddyen Jaddyen deleted the emit-fieldop-getfieldop branch August 1, 2025 22:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants