Skip to content

[mlir][EmitC]Allow Fields to have initial values #151437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1659,13 +1659,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
// Example with no attribute:
emitc.field @fieldName0 : !emitc.array<1xf32>
// Example with an initial value:
emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
// Example with an initial value and attributes:
emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0> {
emitc.opaque = "input_tensor"}
```
}];

let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
OptionalAttr<AnyAttr>:$attrs);
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value);

let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
let assemblyFormat = [{
$sym_name
`:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value)
attr-dict
}];

let hasVerifier = 1;
}
Expand All @@ -1686,7 +1695,7 @@ def EmitC_GetFieldOp
}];

let arguments = (ins FlatSymbolRefAttr:$field_name);
let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
let results = (outs EmitCType:$result);
let assemblyFormat = "$field_name `:` type($result) attr-dict";
}

Expand Down
42 changes: 39 additions & 3 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
TypeAttr type,
Attribute initialValue) {
p << type;
if (initialValue) {
p << " = ";
p.printAttributeWithoutType(initialValue);
}
}

static Type getInitializerTypeForField(Type type) {
if (auto array = llvm::dyn_cast<ArrayType>(type))
return RankedTensorType::get(array.getShape(), array.getElementType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment here? (It feels a bit weird to go from ArrayType to RankedTensorType, if this is convention followed elsewhere or just here, good to document)

Copy link
Contributor Author

@Jaddyen Jaddyen Jul 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is convention that is already documented since this is directly from how we have it in GlobalOp. And we have this:https://mlir.llvm.org/docs/Dialects/EmitC/#pointertype:~:text=If%20tensors%20are%20used%2C%20C%2B%2B%20is%20generated

return type;
}

static ParseResult
parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &initialValue) {
Type type;
if (parser.parseType(type))
return failure();

typeAttr = TypeAttr::get(type);

if (parser.parseOptionalEqual())
return success();

if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
return failure();

if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
initialValue))
return parser.emitError(parser.getNameLoc())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference in output if emitError() is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use emitError()

<< "initial value should be a integer, float, elements or opaque "
"attribute";
return success();
}

LogicalResult FieldOp::verify() {
if (!isSupportedEmitCType(getType()))
return emitOpError("expected valid emitc type");
Expand All @@ -1410,9 +1449,6 @@ LogicalResult FieldOp::verify() {
if (!symName || symName.getValue().empty())
return emitOpError("field must have a non-empty symbol name");

if (!getAttrs())
return success();

return success();
}

Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {

auto argAttrs = funcOp.getArgAttrs();
for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
StringAttr fieldName;
Attribute argAttr = nullptr;

fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
if (argAttrs && idx < argAttrs->size())
argAttr = (*argAttrs)[idx];
StringAttr fieldName =
rewriter.getStringAttr("fieldName" + std::to_string(idx));

TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
argAttr);

FieldOp fieldop = rewriter.create<emitc::FieldOp>(
funcOp->getLoc(), fieldName, typeAttr, nullptr);

if (argAttrs && idx < argAttrs->size()) {
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
}
}

rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
Expand Down
48 changes: 22 additions & 26 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ struct CppEmitter {
/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
emitc::MemberOfPtrOp, emitc::SubscriptOp,
emitc::GetFieldOp>(op);
}

/// Determine whether expression \p expressionOp should be emitted inline, i.e.
Expand Down Expand Up @@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {

static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
raw_ostream &os = emitter.ostream();
if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
if (failed(emitter.emitVariableDeclaration(
fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
return failure();
os << " " << fieldOp.getSymName() << ";";
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
GetFieldOp getFieldOp) {
raw_indented_ostream &os = emitter.ostream();

Value result = getFieldOp.getResult();
if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
return failure();
os << " ";
if (failed(emitter.emitOperand(result)))
return failure();
os << " = ";
std::optional<Attribute> initialValue = fieldOp.getInitialValue();
if (initialValue) {
os << " = ";
if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
return failure();
}

os << getFieldOp.getFieldName().str();
os << ";";
return success();
}

Expand Down Expand Up @@ -1204,7 +1197,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
os << "}\n";
os << "}";

return success();
}
Expand Down Expand Up @@ -1245,7 +1238,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
os << "}\n";
os << "}";

return success();
}
Expand Down Expand Up @@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
emitc::VariableOp, emitc::VerbatimOp>(

[&](auto op) { return printOperation(*this, op); })
// Func ops.
Expand All @@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
cacheDeferredOpResult(op.getResult(), op.getName());
return success();
})
.Case<emitc::GetFieldOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getFieldName());
return success();
})
.Case<emitc::LiteralOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
Expand Down
47 changes: 37 additions & 10 deletions mlir/test/mlir-translate/emitc_classops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@ emitc.class @modelClass {

// CHECK-LABEL: class modelClass {
// CHECK-NEXT: public:
// CHECK-NEXT: float[1] fieldName0;
// CHECK-NEXT: float[1] fieldName1;
// CHECK-NEXT: float fieldName0[1];
// CHECK-NEXT: float fieldName1[1];
// CHECK-NEXT: void execute() {
// CHECK-NEXT: size_t v1 = 0;
// CHECK-NEXT: float[1] v2 = fieldName0;
// CHECK-NEXT: float[1] v3 = fieldName1;
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-EMPTY:
// CHECK-NEXT: };

emitc.class final @finalClass {
Expand All @@ -39,13 +36,43 @@ emitc.class final @finalClass {

// CHECK-LABEL: class finalClass final {
// CHECK-NEXT: public:
// CHECK-NEXT: float[1] fieldName0;
// CHECK-NEXT: float[1] fieldName1;
// CHECK-NEXT: float fieldName0[1];
// CHECK-NEXT: float fieldName1[1];
// CHECK-NEXT: void execute() {
// CHECK-NEXT: size_t v1 = 0;
// CHECK-NEXT: float[1] v2 = fieldName0;
// CHECK-NEXT: float[1] v3 = fieldName1;
// CHECK-NEXT: return;
Comment on lines 42 to 43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to go away, but I'm not sure what in this change causes that, given that you didn't change the input IR w/ an initial value. I can see that you're changing how the fields are printed, but that doesn't seem to be related to these statements. What part of this changes removed these? Should these assignments have been removed(I assume yes)? Do you need some calculation in execute() to keep those alive, or show the usage of the fields (e.g. float x = v3[0] + v2[0])?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They go away because they aren't initialized.
The change in this patch ensures that we only emit get_field in the case that we initialized the field.
This is in line with how get_global works.
Yes to the two last questions!

// CHECK-NEXT: }
// CHECK-EMPTY:
// CHECK-NEXT: };

emitc.class @mainClass {
emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {attrs = {emitc.name_hint = "another_feature"}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a test also where the type is e.g. a std::map, with an initializer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, addressed in changes.

emitc.func @get_fieldName0() {
%0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
return
}
}

// CHECK-LABEL: class mainClass {
// CHECK-NEXT: public:
// CHECK-NEXT: float fieldName0[2] = {0.0e+00f, 0.0e+00f};
// CHECK-NEXT: void get_fieldName0() {
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-NEXT: };

emitc.class @reflectionClass {
emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }">
emitc.func @get_reflectionMap() {
%0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>">
return
}
}

// CHECK-LABEL: class reflectionClass {
// CHECK-NEXT: public:
// CHECK-NEXT: const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } };
// CHECK-NEXT: void get_reflectionMap() {
// CHECK-NEXT: return;
// CHECK-NEXT: }
// CHECK-NEXT: };