From 3768def1fc42b74efd8517add59f8b32b00e895a Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Fri, 25 Jul 2025 16:43:57 +0000 Subject: [PATCH 1/5] Adding lowering for copyop --- .../MemRefToEmitC/MemRefToEmitC.cpp | 60 +++++++++++++++++++ .../MemRefToEmitC/memref-to-emitc-copy.mlir | 25 ++++++++ 2 files changed, 85 insertions(+) create mode 100644 mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..34ea4989c8156 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -87,6 +87,66 @@ struct ConvertAlloca final : public OpConversionPattern { } }; +struct ConvertCopy final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + auto srcMemrefType = dyn_cast(copyOp.getSource().getType()); + auto targetMemrefType = dyn_cast(copyOp.getTarget().getType()); + + if (!srcMemrefType || !targetMemrefType) { + return failure(); + } + + // 1. Cast source memref to a pointer. + auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType()); + auto srcArrayValue = + dyn_cast>(operands.getSource()); + auto stcArrayPtr = + emitc::PointerType::get(srcArrayValue.getType().getElementType()); + auto srcPtr = rewriter.create(loc, srcPtrType, + stcArrayPtr.getPointee()); + + // 2. Cast target memref to a pointer. + auto targetPtrType = + emitc::PointerType::get(targetMemrefType.getElementType()); + + auto arrayValue = + dyn_cast>(operands.getTarget()); + + // Cast the target memref value to a pointer type. + auto targetPtr = + rewriter.create(loc, targetPtrType, arrayValue); + + // 3. Calculate the size in bytes of the memref. + auto elementSize = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"), + mlir::ValueRange{}, + mlir::ArrayAttr::get( + rewriter.getContext(), + {mlir::TypeAttr::get(srcMemrefType.getElementType())})); + + auto numElements = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), + srcMemrefType.getNumElements())); + auto byteSize = rewriter.create(loc, rewriter.getIndexType(), + elementSize.getResult(0), + numElements.getResult()); + + // 4. Emit the memcpy call. + rewriter.create(loc, TypeRange{}, "memcpy", + ValueRange{targetPtr.getResult(), + srcPtr.getResult(), + byteSize.getResult()}); + + return success(); + } +}; + Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { Type resultTy; if (opTy.getRank() == 0) { diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir new file mode 100644 index 0000000000000..d031d60508df2 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s + +func.func @copying(%arg0 : memref<2x4xf32>) { + memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> + return +} + +// func.func @copying_memcpy(%arg_0: !emitc.ptr) { +// %size = "emitc.constant"() <{value = 8 : index}> :() -> index +// %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index +// %total_bytes = emitc.mul %size, %element_size : (index, index) -> index + +// emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr, !emitc.ptr, index) -> () +// return +// } + +// CHECK-LABEL: copying_memcpy +// CHECK-SAME: %arg_0: !emitc.ptr +// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index +// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index +// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index +// CHECK-NEXT: emitc.call_opaque "memcpy" +// CHECK-SAME: (%arg_0, %arg_0, %total_bytes) +// CHECK-NEXT: : (!emitc.ptr, !emitc.ptr, index) -> () +// CHECK-NEXT: return \ No newline at end of file From bc4dd4fafacfdee0cd3d2bf3f4b8a9c2831cb8fa Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Tue, 29 Jul 2025 18:14:39 +0000 Subject: [PATCH 2/5] use subscript and apply ops --- .../Conversion/MemRefToEmitC/MemRefToEmitC.h | 2 + .../MemRefToEmitC/MemRefToEmitC.cpp | 137 ++++++++++-------- .../MemRefToEmitC/MemRefToEmitCPass.cpp | 27 +++- .../MemRefToEmitC/memref-to-emitc-copy.mlir | 35 +++-- .../MemRefToEmitC/memref-to-emitc-failed.mlir | 8 - 5 files changed, 114 insertions(+), 95 deletions(-) diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index b595b6a308bea..4ea6649d64a92 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -10,8 +10,10 @@ constexpr const char *alignedAllocFunctionName = "aligned_alloc"; constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *memcpyFunctionName = "memcpy"; constexpr const char *cppStandardLibraryHeader = "cstdlib"; constexpr const char *cStandardLibraryHeader = "stdlib.h"; +constexpr const char *stringLibraryHeader = "string.h"; namespace mlir { class DialectRegistry; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 34ea4989c8156..adb0eb77fdf35 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -87,66 +87,6 @@ struct ConvertAlloca final : public OpConversionPattern { } }; -struct ConvertCopy final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, - ConversionPatternRewriter &rewriter) const override { - Location loc = copyOp.getLoc(); - auto srcMemrefType = dyn_cast(copyOp.getSource().getType()); - auto targetMemrefType = dyn_cast(copyOp.getTarget().getType()); - - if (!srcMemrefType || !targetMemrefType) { - return failure(); - } - - // 1. Cast source memref to a pointer. - auto srcPtrType = emitc::PointerType::get(srcMemrefType.getElementType()); - auto srcArrayValue = - dyn_cast>(operands.getSource()); - auto stcArrayPtr = - emitc::PointerType::get(srcArrayValue.getType().getElementType()); - auto srcPtr = rewriter.create(loc, srcPtrType, - stcArrayPtr.getPointee()); - - // 2. Cast target memref to a pointer. - auto targetPtrType = - emitc::PointerType::get(targetMemrefType.getElementType()); - - auto arrayValue = - dyn_cast>(operands.getTarget()); - - // Cast the target memref value to a pointer type. - auto targetPtr = - rewriter.create(loc, targetPtrType, arrayValue); - - // 3. Calculate the size in bytes of the memref. - auto elementSize = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getStringAttr("sizeof"), - mlir::ValueRange{}, - mlir::ArrayAttr::get( - rewriter.getContext(), - {mlir::TypeAttr::get(srcMemrefType.getElementType())})); - - auto numElements = rewriter.create( - loc, rewriter.getIndexType(), - rewriter.getIntegerAttr(rewriter.getIndexType(), - srcMemrefType.getNumElements())); - auto byteSize = rewriter.create(loc, rewriter.getIndexType(), - elementSize.getResult(0), - numElements.getResult()); - - // 4. Emit the memcpy call. - rewriter.create(loc, TypeRange{}, "memcpy", - ValueRange{targetPtr.getResult(), - srcPtr.getResult(), - byteSize.getResult()}); - - return success(); - } -}; - Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { Type resultTy; if (opTy.getRank() == 0) { @@ -157,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, + ConversionPatternRewriter &rewriter) { + emitc::CallOpaqueOp elementSize = rewriter.create( + loc, emitc::SizeTType::get(rewriter.getContext()), + rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), + {TypeAttr::get(memrefType.getElementType())})); + + IndexType indexType = rewriter.getIndexType(); + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + emitc::ConstantOp numElementsValue = rewriter.create( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + emitc::MulOp totalSizeBytes = rewriter.create( + loc, sizeTType, elementSize.getResult(0), numElementsValue); + + return totalSizeBytes.getResult(); +} + struct ConvertAlloc final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -219,6 +182,55 @@ struct ConvertAlloc final : public OpConversionPattern { } }; +struct ConvertCopy final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + MemRefType srcMemrefType = + dyn_cast(copyOp.getSource().getType()); + MemRefType targetMemrefType = + dyn_cast(copyOp.getTarget().getType()); + + if (!isMemRefTypeLegalForEmitC(srcMemrefType) || + !isMemRefTypeLegalForEmitC(targetMemrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + auto srcArrayValue = + dyn_cast>(operands.getSource()); + + emitc::SubscriptOp srcSubPtr = rewriter.create( + loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex}); + emitc::ApplyOp srcPtr = rewriter.create( + loc, emitc::PointerType::get(srcMemrefType.getElementType()), + rewriter.getStringAttr("&"), srcSubPtr); + + auto arrayValue = + dyn_cast>(operands.getTarget()); + emitc::SubscriptOp targetSubPtr = rewriter.create( + loc, arrayValue, ValueRange{zeroIndex, zeroIndex}); + emitc::ApplyOp targetPtr = rewriter.create( + loc, emitc::PointerType::get(targetMemrefType.getElementType()), + rewriter.getStringAttr("&"), targetSubPtr); + + emitc::CallOpaqueOp memCpyCall = rewriter.create( + loc, TypeRange{}, "memcpy", + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + + rewriter.replaceOp(copyOp, memCpyCall.getResults()); + + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -380,6 +392,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index e78dd76d6e256..8e965b42f1043 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC @@ -27,6 +28,15 @@ namespace mlir { using namespace mlir; namespace { + +emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return builder.create( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { using Base::Base; @@ -57,7 +67,8 @@ struct ConvertMemRefToEmitCPass mlir::ModuleOp module = getOperation(); module.walk([&](mlir::emitc::CallOpaqueOp callOp) { if (callOp.getCallee() != alignedAllocFunctionName && - callOp.getCallee() != mallocFunctionName) { + callOp.getCallee() != mallocFunctionName && + callOp.getCallee() != memcpyFunctionName) { return mlir::WalkResult::advance(); } @@ -76,12 +87,14 @@ struct ConvertMemRefToEmitCPass } mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); - StringAttr includeAttr = - builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader - : cStandardLibraryHeader); - builder.create( - module.getLoc(), includeAttr, - /*is_standard_include=*/builder.getUnitAttr()); + StringRef headerName; + if (callOp.getCallee() == memcpyFunctionName) + headerName = stringLibraryHeader; + else + headerName = options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader; + + addHeader(builder, module, headerName); return mlir::WalkResult::interrupt(); }); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir index d031d60508df2..4b6eb50807513 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -1,25 +1,24 @@ -// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s +// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s func.func @copying(%arg0 : memref<2x4xf32>) { memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> return } -// func.func @copying_memcpy(%arg_0: !emitc.ptr) { -// %size = "emitc.constant"() <{value = 8 : index}> :() -> index -// %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index -// %total_bytes = emitc.mul %size, %element_size : (index, index) -> index - -// emitc.call_opaque "memcpy"(%arg_0, %arg_0, %total_bytes) : (!emitc.ptr, !emitc.ptr, index) -> () -// return -// } +// CHECK: module { +// CHECK-NEXT: emitc.include <"string.h"> +// CHECK-LABEL: copying +// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32> +// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue +// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue +// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index +// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT:} -// CHECK-LABEL: copying_memcpy -// CHECK-SAME: %arg_0: !emitc.ptr -// CHECK-NEXT: %size = "emitc.constant"() <{value = 8 : index}> :() -> index -// CHECK-NEXT: %element_size = "emitc.constant"() <{value = 4 : index}> :() -> index -// CHECK-NEXT: %total_bytes = emitc.mul %size, %element_size : (index, index) -> index -// CHECK-NEXT: emitc.call_opaque "memcpy" -// CHECK-SAME: (%arg_0, %arg_0, %total_bytes) -// CHECK-NEXT: : (!emitc.ptr, !emitc.ptr, index) -> () -// CHECK-NEXT: return \ No newline at end of file diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index fda01974d3fc8..b6eccfc8f0050 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -1,13 +1,5 @@ // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics -func.func @memref_op(%arg0 : memref<2x4xf32>) { - // expected-error@+1 {{failed to legalize operation 'memref.copy'}} - memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> - return -} - -// ----- - func.func @alloca_with_dynamic_shape() { %0 = index.constant 1 // expected-error@+1 {{failed to legalize operation 'memref.alloca'}} From 7e7f59c010d2f53714d629971d8d2dc3c3d17a2e Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Mon, 4 Aug 2025 18:11:17 +0000 Subject: [PATCH 3/5] allow for multi dimensional arays --- .../Conversion/MemRefToEmitC/MemRefToEmitC.h | 3 +- .../MemRefToEmitC/MemRefToEmitC.cpp | 88 ++++++++++++------- .../MemRefToEmitC/MemRefToEmitCPass.cpp | 19 ++-- 3 files changed, 68 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index 4ea6649d64a92..5abfb3d7e72dd 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -13,7 +13,8 @@ constexpr const char *mallocFunctionName = "malloc"; constexpr const char *memcpyFunctionName = "memcpy"; constexpr const char *cppStandardLibraryHeader = "cstdlib"; constexpr const char *cStandardLibraryHeader = "stdlib.h"; -constexpr const char *stringLibraryHeader = "string.h"; +constexpr const char *cppStringLibraryHeader = "cstring"; +constexpr const char *cStringLibraryHeader = "string.h"; namespace mlir { class DialectRegistry; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index adb0eb77fdf35..cabbfac4a1dca 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" @@ -98,23 +99,25 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { } Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, - ConversionPatternRewriter &rewriter) { - emitc::CallOpaqueOp elementSize = rewriter.create( - loc, emitc::SizeTType::get(rewriter.getContext()), - rewriter.getStringAttr("sizeof"), ValueRange{}, - ArrayAttr::get(rewriter.getContext(), + OpBuilder &builder) { + assert(isMemRefTypeLegalForEmitC(memrefType) && + "incompatible memref type for EmitC conversion"); + emitc::CallOpaqueOp elementSize = builder.create( + loc, emitc::SizeTType::get(builder.getContext()), + builder.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(builder.getContext(), {TypeAttr::get(memrefType.getElementType())})); - IndexType indexType = rewriter.getIndexType(); + IndexType indexType = builder.getIndexType(); int64_t numElements = 1; for (int64_t dimSize : memrefType.getShape()) { numElements *= dimSize; } - emitc::ConstantOp numElementsValue = rewriter.create( - loc, indexType, rewriter.getIndexAttr(numElements)); + emitc::ConstantOp numElementsValue = builder.create( + loc, indexType, builder.getIndexAttr(numElements)); - Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); - emitc::MulOp totalSizeBytes = rewriter.create( + Type sizeTType = emitc::SizeTType::get(builder.getContext()); + emitc::MulOp totalSizeBytes = builder.create( loc, sizeTType, elementSize.getResult(0), numElementsValue); return totalSizeBytes.getResult(); @@ -189,41 +192,64 @@ struct ConvertCopy final : public OpConversionPattern { matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { Location loc = copyOp.getLoc(); - MemRefType srcMemrefType = - dyn_cast(copyOp.getSource().getType()); + MemRefType srcMemrefType = cast(copyOp.getSource().getType()); MemRefType targetMemrefType = - dyn_cast(copyOp.getTarget().getType()); + cast(copyOp.getTarget().getType()); - if (!isMemRefTypeLegalForEmitC(srcMemrefType) || - !isMemRefTypeLegalForEmitC(targetMemrefType)) { + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) { return rewriter.notifyMatchFailure( - loc, "incompatible memref type for EmitC conversion"); + loc, "incompatible source memref type for EmitC conversion"); + } + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible target memref type for EmitC conversion"); } + auto createPointerFromEmitcArray = + [&](mlir::Location loc, mlir::OpBuilder &rewriter, + mlir::TypedValue arrayValue, + mlir::MemRefType memrefType, + emitc::ConstantOp zeroIndex) -> emitc::ApplyOp { + // Get the rank of the array to create the correct number of zero indices. + int64_t rank = arrayValue.getType().getRank(); + llvm::SmallVector indices; + for (int i = 0; i < rank; ++i) { + indices.push_back(zeroIndex); + } + + // Create a subscript operation to get the element at index [0, 0, ..., + // 0]. + emitc::SubscriptOp subPtr = rewriter.create( + loc, arrayValue, mlir::ValueRange(indices)); + + // Create an apply operation to take the address of the subscripted + // element. + emitc::ApplyOp ptr = rewriter.create( + loc, emitc::PointerType::get(memrefType.getElementType()), + rewriter.getStringAttr("&"), subPtr); + + return ptr; + }; + + // Create a constant zero index. emitc::ConstantOp zeroIndex = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + auto srcArrayValue = dyn_cast>(operands.getSource()); + emitc::ApplyOp srcPtr = createPointerFromEmitcArray( + loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex); - emitc::SubscriptOp srcSubPtr = rewriter.create( - loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex}); - emitc::ApplyOp srcPtr = rewriter.create( - loc, emitc::PointerType::get(srcMemrefType.getElementType()), - rewriter.getStringAttr("&"), srcSubPtr); - - auto arrayValue = + auto targetArrayValue = dyn_cast>(operands.getTarget()); - emitc::SubscriptOp targetSubPtr = rewriter.create( - loc, arrayValue, ValueRange{zeroIndex, zeroIndex}); - emitc::ApplyOp targetPtr = rewriter.create( - loc, emitc::PointerType::get(targetMemrefType.getElementType()), - rewriter.getStringAttr("&"), targetSubPtr); + emitc::ApplyOp targetPtr = createPointerFromEmitcArray( + loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex); + OpBuilder builder = rewriter; emitc::CallOpaqueOp memCpyCall = rewriter.create( loc, TypeRange{}, "memcpy", - ValueRange{ - targetPtr.getResult(), srcPtr.getResult(), - calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + ValueRange{targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)}); rewriter.replaceOp(copyOp, memCpyCall.getResults()); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index 8e965b42f1043..c60e7488fdb38 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -29,8 +29,8 @@ using namespace mlir; namespace { -emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module, - StringRef headerName) { +emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { StringAttr includeAttr = builder.getStringAttr(headerName); return builder.create( module.getLoc(), includeAttr, @@ -68,33 +68,32 @@ struct ConvertMemRefToEmitCPass module.walk([&](mlir::emitc::CallOpaqueOp callOp) { if (callOp.getCallee() != alignedAllocFunctionName && callOp.getCallee() != mallocFunctionName && - callOp.getCallee() != memcpyFunctionName) { + callOp.getCallee() != memcpyFunctionName) return mlir::WalkResult::advance(); - } for (auto &op : *module.getBody()) { emitc::IncludeOp includeOp = llvm::dyn_cast(op); - if (!includeOp) { + if (!includeOp) continue; - } + if (includeOp.getIsStandardInclude() && ((options.lowerToCpp && includeOp.getInclude() == cppStandardLibraryHeader) || (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) { + includeOp.getInclude() == cStandardLibraryHeader))) return mlir::WalkResult::interrupt(); - } } mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); StringRef headerName; if (callOp.getCallee() == memcpyFunctionName) - headerName = stringLibraryHeader; + headerName = + options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader; else headerName = options.lowerToCpp ? cppStandardLibraryHeader : cStandardLibraryHeader; - addHeader(builder, module, headerName); + addStandardHeader(builder, module, headerName); return mlir::WalkResult::interrupt(); }); } From 7650eaaa0c41e370a07581d8d2ccd5c9ddfc0824 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Mon, 4 Aug 2025 18:19:32 +0000 Subject: [PATCH 4/5] update test file --- .../MemRefToEmitC/MemRefToEmitC.cpp | 7 ----- .../MemRefToEmitC/memref-to-emitc-copy.mlir | 26 ++++++++++--------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index cabbfac4a1dca..4130e9be88a89 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -210,20 +210,14 @@ struct ConvertCopy final : public OpConversionPattern { mlir::TypedValue arrayValue, mlir::MemRefType memrefType, emitc::ConstantOp zeroIndex) -> emitc::ApplyOp { - // Get the rank of the array to create the correct number of zero indices. int64_t rank = arrayValue.getType().getRank(); llvm::SmallVector indices; for (int i = 0; i < rank; ++i) { indices.push_back(zeroIndex); } - // Create a subscript operation to get the element at index [0, 0, ..., - // 0]. emitc::SubscriptOp subPtr = rewriter.create( loc, arrayValue, mlir::ValueRange(indices)); - - // Create an apply operation to take the address of the subscripted - // element. emitc::ApplyOp ptr = rewriter.create( loc, emitc::PointerType::get(memrefType.getElementType()), rewriter.getStringAttr("&"), subPtr); @@ -231,7 +225,6 @@ struct ConvertCopy final : public OpConversionPattern { return ptr; }; - // Create a constant zero index. emitc::ConstantOp zeroIndex = rewriter.create( loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir index 4b6eb50807513..88325e57762d3 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -1,23 +1,25 @@ // RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s -func.func @copying(%arg0 : memref<2x4xf32>) { - memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> +func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { + memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> return } // CHECK: module { // CHECK-NEXT: emitc.include <"string.h"> // CHECK-LABEL: copying -// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32> -// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue -// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue) -> !emitc.ptr -// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue -// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue) -> !emitc.ptr -// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t -// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index -// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () // CHECK-NEXT: return // CHECK-NEXT: } // CHECK-NEXT:} From cc171204103df8d1d92dcd1377221720a532dcf8 Mon Sep 17 00:00:00 2001 From: Jaddyen Date: Tue, 5 Aug 2025 17:44:17 +0000 Subject: [PATCH 5/5] test cpp output --- .../MemRefToEmitC/MemRefToEmitC.cpp | 37 ++++++------ .../MemRefToEmitC/MemRefToEmitCPass.cpp | 15 +++-- .../MemRefToEmitC/memref-to-emitc-copy.mlir | 57 ++++++++++++------- 3 files changed, 65 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 4130e9be88a89..c8124d2f16943 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -196,20 +196,20 @@ struct ConvertCopy final : public OpConversionPattern { MemRefType targetMemrefType = cast(copyOp.getTarget().getType()); - if (!isMemRefTypeLegalForEmitC(srcMemrefType)) { + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) return rewriter.notifyMatchFailure( loc, "incompatible source memref type for EmitC conversion"); - } - if (!isMemRefTypeLegalForEmitC(targetMemrefType)) { + + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) return rewriter.notifyMatchFailure( loc, "incompatible target memref type for EmitC conversion"); - } + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); auto createPointerFromEmitcArray = - [&](mlir::Location loc, mlir::OpBuilder &rewriter, - mlir::TypedValue arrayValue, - mlir::MemRefType memrefType, - emitc::ConstantOp zeroIndex) -> emitc::ApplyOp { + [loc, &rewriter, &zeroIndex]( + mlir::TypedValue arrayValue) -> emitc::ApplyOp { int64_t rank = arrayValue.getType().getRank(); llvm::SmallVector indices; for (int i = 0; i < rank; ++i) { @@ -219,30 +219,25 @@ struct ConvertCopy final : public OpConversionPattern { emitc::SubscriptOp subPtr = rewriter.create( loc, arrayValue, mlir::ValueRange(indices)); emitc::ApplyOp ptr = rewriter.create( - loc, emitc::PointerType::get(memrefType.getElementType()), + loc, emitc::PointerType::get(arrayValue.getType().getElementType()), rewriter.getStringAttr("&"), subPtr); return ptr; }; - emitc::ConstantOp zeroIndex = rewriter.create( - loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); - auto srcArrayValue = - dyn_cast>(operands.getSource()); - emitc::ApplyOp srcPtr = createPointerFromEmitcArray( - loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex); + cast>(operands.getSource()); + emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue); auto targetArrayValue = - dyn_cast>(operands.getTarget()); - emitc::ApplyOp targetPtr = createPointerFromEmitcArray( - loc, rewriter, targetArrayValue, targetMemrefType, zeroIndex); + cast>(operands.getTarget()); + emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue); - OpBuilder builder = rewriter; emitc::CallOpaqueOp memCpyCall = rewriter.create( loc, TypeRange{}, "memcpy", - ValueRange{targetPtr.getResult(), srcPtr.getResult(), - calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)}); + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); rewriter.replaceOp(copyOp, memCpyCall.getResults()); diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index c60e7488fdb38..3ffff9fca106a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -37,6 +37,16 @@ emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, /*is_standard_include=*/builder.getUnitAttr()); } +bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options, + emitc::IncludeOp includeOp) { + return ((options.lowerToCpp && + (includeOp.getInclude() == cppStandardLibraryHeader || + includeOp.getInclude() == cppStringLibraryHeader)) || + (!options.lowerToCpp && + (includeOp.getInclude() == cStandardLibraryHeader || + includeOp.getInclude() == cStringLibraryHeader))); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { using Base::Base; @@ -77,10 +87,7 @@ struct ConvertMemRefToEmitCPass continue; if (includeOp.getIsStandardInclude() && - ((options.lowerToCpp && - includeOp.getInclude() == cppStandardLibraryHeader) || - (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) + isExpectedStandardInclude(options, includeOp)) return mlir::WalkResult::interrupt(); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir index 88325e57762d3..1b515ba02dd46 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -1,26 +1,45 @@ -// RUN: mlir-opt -convert-memref-to-emitc %s | FileCheck %s +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> return } -// CHECK: module { -// CHECK-NEXT: emitc.include <"string.h"> -// CHECK-LABEL: copying -// CHECK-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> -// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> -// CHECK-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index -// CHECK-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue -// CHECK-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr -// CHECK-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue -// CHECK-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr -// CHECK-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t -// CHECK-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index -// CHECK-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t -// CHECK-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () -// CHECK-NEXT: return -// CHECK-NEXT: } -// CHECK-NEXT:} +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"string.h"> +// NOCPP-LABEL: copying +// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// NOCPP-NEXT: return +// NOCPP-NEXT: } +// NOCPP-NEXT:} +// CPP: module { +// CPP-NEXT: emitc.include <"cstring"> +// CPP-LABEL: copying +// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CPP-NEXT: return +// CPP-NEXT: } +// CPP-NEXT:}