Skip to content

Commit 5949f45

Browse files
authored
[mlir][EmitC]Expand the MemRefToEmitC pass - Lowering AllocOp (#148257)
This aims to lower `memref.alloc` to `emitc.call_opaque “malloc” ` or `emitc.call_opaque “aligned_alloc” ` From: ``` module{ func.func @allocating() { %alloc_5 = memref.alloc() : memref<999xi32> return } } ``` To: ``` module { emitc.include <"stdlib.h"> func.func @allocating() { %0 = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t %1 = "emitc.constant"() <{value = 999 : index}> : () -> index %2 = emitc.mul %0, %1 : (!emitc.size_t, index) -> !emitc.size_t %3 = emitc.call_opaque "malloc"(%2) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">> %4 = emitc.cast %3 : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32> return } } ``` Which is then translated as: ``` #include <stdlib.h> void allocating() { size_t v1 = sizeof(int32_t); size_t v2 = 999; size_t v3 = v1 * v2; void* v4 = malloc(v3); int32_t* v5 = (int32_t*) v4; return; } ```
1 parent 0d05e55 commit 5949f45

File tree

6 files changed

+192
-7
lines changed

6 files changed

+192
-7
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ The following convention is followed:
1818
GCC or Clang.
1919
* If `emitc.array` with a dimension of size zero is used, then the code
2020
requires [a GCC extension](https://gcc.gnu.org/onlinedocs/gcc/Zero-Length.html).
21+
* If `aligned_alloc` is passed to an `emitc.call_opaque` operation, then C++17
22+
or C11 is required.
2123
* Else the generated code is compatible with C99.
2224

2325
These restrictions are neither inherent to the EmitC dialect itself nor to the

mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
99
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
1010

11+
constexpr const char *alignedAllocFunctionName = "aligned_alloc";
12+
constexpr const char *mallocFunctionName = "malloc";
13+
constexpr const char *cppStandardLibraryHeader = "cstdlib";
14+
constexpr const char *cStandardLibraryHeader = "stdlib.h";
15+
1116
namespace mlir {
1217
class DialectRegistry;
1318
class RewritePatternSet;

mlir/include/mlir/Conversion/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,13 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
841841
// MemRefToEmitC
842842
//===----------------------------------------------------------------------===//
843843

844-
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
844+
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc", "ModuleOp"> {
845845
let summary = "Convert MemRef dialect to EmitC dialect";
846846
let dependentDialects = ["emitc::EmitCDialect"];
847+
let options = [Option<
848+
"lowerToCpp", "lower-to-cpp", "bool",
849+
/*default=*/"false",
850+
/*description=*/"Target C++ (true) instead of C (false)">];
847851
}
848852

849853
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,18 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/IR/TypeRange.h"
22+
#include "mlir/IR/Value.h"
2223
#include "mlir/Transforms/DialectConversion.h"
24+
#include <cstdint>
2325

2426
using namespace mlir;
2527

28+
static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
29+
return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
30+
memRefType.getRank() != 0 &&
31+
!llvm::is_contained(memRefType.getShape(), 0);
32+
}
33+
2634
namespace {
2735
/// Implement the interface to convert MemRef to EmitC.
2836
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
8997
return resultTy;
9098
}
9199

100+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
101+
using OpConversionPattern::OpConversionPattern;
102+
LogicalResult
103+
matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
104+
ConversionPatternRewriter &rewriter) const override {
105+
Location loc = allocOp.getLoc();
106+
MemRefType memrefType = allocOp.getType();
107+
if (!isMemRefTypeLegalForEmitC(memrefType)) {
108+
return rewriter.notifyMatchFailure(
109+
loc, "incompatible memref type for EmitC conversion");
110+
}
111+
112+
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
113+
Type elementType = memrefType.getElementType();
114+
IndexType indexType = rewriter.getIndexType();
115+
emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
116+
loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
117+
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
118+
119+
int64_t numElements = 1;
120+
for (int64_t dimSize : memrefType.getShape()) {
121+
numElements *= dimSize;
122+
}
123+
Value numElementsValue = rewriter.create<emitc::ConstantOp>(
124+
loc, indexType, rewriter.getIndexAttr(numElements));
125+
126+
Value totalSizeBytes = rewriter.create<emitc::MulOp>(
127+
loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
128+
129+
emitc::CallOpaqueOp allocCall;
130+
StringAttr allocFunctionName;
131+
Value alignmentValue;
132+
SmallVector<Value, 2> argsVec;
133+
if (allocOp.getAlignment()) {
134+
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
135+
alignmentValue = rewriter.create<emitc::ConstantOp>(
136+
loc, sizeTType,
137+
rewriter.getIntegerAttr(indexType,
138+
allocOp.getAlignment().value_or(0)));
139+
argsVec.push_back(alignmentValue);
140+
} else {
141+
allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
142+
}
143+
144+
argsVec.push_back(totalSizeBytes);
145+
ValueRange args(argsVec);
146+
147+
allocCall = rewriter.create<emitc::CallOpaqueOp>(
148+
loc,
149+
emitc::PointerType::get(
150+
emitc::OpaqueType::get(rewriter.getContext(), "void")),
151+
allocFunctionName, args);
152+
153+
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
154+
emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
155+
loc, targetPointerType, allocCall.getResult(0));
156+
157+
rewriter.replaceOp(allocOp, castOp);
158+
return success();
159+
}
160+
};
161+
92162
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
93163
using OpConversionPattern::OpConversionPattern;
94164

@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
223293
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
224294
typeConverter.addConversion(
225295
[&](MemRefType memRefType) -> std::optional<Type> {
226-
if (!memRefType.hasStaticShape() ||
227-
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
228-
llvm::is_contained(memRefType.getShape(), 0)) {
296+
if (!isMemRefTypeLegalForEmitC(memRefType)) {
229297
return {};
230298
}
231299
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
252320

253321
void mlir::populateMemRefToEmitCConversionPatterns(
254322
RewritePatternSet &patterns, const TypeConverter &converter) {
255-
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
256-
ConvertStore>(converter, patterns.getContext());
323+
patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
324+
ConvertLoad, ConvertStore>(converter, patterns.getContext());
257325
}

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
1616
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/IR/Attributes.h"
1819
#include "mlir/Pass/Pass.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021

@@ -28,9 +29,11 @@ using namespace mlir;
2829
namespace {
2930
struct ConvertMemRefToEmitCPass
3031
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
32+
using Base::Base;
3133
void runOnOperation() override {
3234
TypeConverter converter;
33-
35+
ConvertMemRefToEmitCOptions options;
36+
options.lowerToCpp = this->lowerToCpp;
3437
// Fallback for other types.
3538
converter.addConversion([](Type type) -> std::optional<Type> {
3639
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
5053
if (failed(applyPartialConversion(getOperation(), target,
5154
std::move(patterns))))
5255
return signalPassFailure();
56+
57+
mlir::ModuleOp module = getOperation();
58+
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
59+
if (callOp.getCallee() != alignedAllocFunctionName &&
60+
callOp.getCallee() != mallocFunctionName) {
61+
return mlir::WalkResult::advance();
62+
}
63+
64+
for (auto &op : *module.getBody()) {
65+
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
66+
if (!includeOp) {
67+
continue;
68+
}
69+
if (includeOp.getIsStandardInclude() &&
70+
((options.lowerToCpp &&
71+
includeOp.getInclude() == cppStandardLibraryHeader) ||
72+
(!options.lowerToCpp &&
73+
includeOp.getInclude() == cStandardLibraryHeader))) {
74+
return mlir::WalkResult::interrupt();
75+
}
76+
}
77+
78+
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
79+
StringAttr includeAttr =
80+
builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
81+
: cStandardLibraryHeader);
82+
builder.create<mlir::emitc::IncludeOp>(
83+
module.getLoc(), includeAttr,
84+
/*is_standard_include=*/builder.getUnitAttr());
85+
return mlir::WalkResult::interrupt();
86+
});
5387
}
5488
};
5589
} // namespace
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP
2+
// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP
3+
4+
func.func @alloc() {
5+
%alloc = memref.alloc() : memref<999xi32>
6+
return
7+
}
8+
9+
// CPP: module {
10+
// CPP-NEXT: emitc.include <"cstdlib">
11+
// CPP-LABEL: alloc()
12+
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
13+
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
14+
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
15+
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
16+
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
17+
// CPP-NEXT: return
18+
19+
// NOCPP: module {
20+
// NOCPP-NEXT: emitc.include <"stdlib.h">
21+
// NOCPP-LABEL: alloc()
22+
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
23+
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
24+
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
25+
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
26+
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
27+
// NOCPP-NEXT: return
28+
29+
func.func @alloc_aligned() {
30+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<999xf32>
31+
return
32+
}
33+
34+
// CPP-LABEL: alloc_aligned
35+
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
36+
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
37+
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
38+
// CPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
39+
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
40+
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
41+
// CPP-NEXT: return
42+
43+
// NOCPP-LABEL: alloc_aligned
44+
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
45+
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
46+
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
47+
// NOCPP-NEXT: %[[ALIGNMENT:.*]] = "emitc.constant"() <{value = 64 : index}> : () -> !emitc.size_t
48+
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "aligned_alloc"(%[[ALIGNMENT]], %[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t, !emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
49+
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<f32>
50+
// NOCPP-NEXT: return
51+
52+
func.func @allocating_multi() {
53+
%alloc_5 = memref.alloc() : memref<7x999xi32>
54+
return
55+
}
56+
57+
// CPP-LABEL: allocating_multi
58+
// CPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
59+
// CPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
60+
// CPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
61+
// CPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">
62+
// CPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
63+
// CPP-NEXT: return
64+
65+
// NOCPP-LABEL: allocating_multi
66+
// NOCPP-NEXT: %[[ALLOC:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
67+
// NOCPP-NEXT: %[[ALLOC_SIZE:.*]] = "emitc.constant"() <{value = 6993 : index}> : () -> index
68+
// NOCPP-NEXT: %[[ALLOC_TOTAL_SIZE:.*]] = emitc.mul %[[ALLOC]], %[[ALLOC_SIZE]] : (!emitc.size_t, index) -> !emitc.size_t
69+
// NOCPP-NEXT: %[[ALLOC_PTR:.*]] = emitc.call_opaque "malloc"(%[[ALLOC_TOTAL_SIZE]]) : (!emitc.size_t) -> !emitc.ptr<!emitc.opaque<"void">>
70+
// NOCPP-NEXT: %[[ALLOC_CAST:.*]] = emitc.cast %[[ALLOC_PTR]] : !emitc.ptr<!emitc.opaque<"void">> to !emitc.ptr<i32>
71+
// NOCPP-NEXT: return
72+

0 commit comments

Comments
 (0)