Skip to content

[mlir][EmitC] Expand the MemRefToEmitC pass - Lowering CopyOp #151206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

Jaddyen
Copy link
Contributor

@Jaddyen Jaddyen commented Jul 29, 2025

This patch lowers memref.copy to emitc.call_opaque "memcpy".
From:

func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
  memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32>
  return
}

To:

#include <cstring>
void copying(float v1[9][4][5][7], float v2[9][4][5][7]) {
  size_t v3 = 0;
  float* v4 = &v2[v3][v3][v3][v3];
  float* v5 = &v1[v3][v3][v3][v3];
  size_t v6 = sizeof(float);
  size_t v7 = 1260;
  size_t v8 = v6 * v7;
  memcpy(v5, v4, v8);
  return;
}

@Jaddyen Jaddyen marked this pull request as ready for review July 29, 2025 18:23
@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-emitc

Author: Jaden Angella (Jaddyen)

Changes

This patch lowers memref.copy to emitc.call_opaque "memcpy".
From:

func.func @<!-- -->copying(%arg0 : memref&lt;9x4xf32&gt;) {
  memref.copy %arg0, %arg0 : memref&lt;9x4xf32&gt; to memref&lt;9x4xf32&gt;
  return
}

To:

void copying(float v1[9][4]) {
  size_t v2 = 0;
  size_t v3 = 0;
  float* v4 = &amp;v1[v2][v3];
  size_t v5 = 0;
  size_t v6 = 0;
  float* v7 = &amp;v1[v5][v6];
  size_t v8 = sizeof(float);
  size_t v9 = 36;
  size_t v10 = v8 * v9;
  memcpy(v7, v4, v10);
  return;
}

Full diff: https://github.com/llvm/llvm-project/pull/151206.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h (+2)
  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+75-2)
  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp (+20-7)
  • (added) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir (+24)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir (-8)
diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
index b595b6a308bea..4ea6649d64a92 100644
--- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
+++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
@@ -10,8 +10,10 @@
 
 constexpr const char *alignedAllocFunctionName = "aligned_alloc";
 constexpr const char *mallocFunctionName = "malloc";
+constexpr const char *memcpyFunctionName = "memcpy";
 constexpr const char *cppStandardLibraryHeader = "cstdlib";
 constexpr const char *cStandardLibraryHeader = "stdlib.h";
+constexpr const char *stringLibraryHeader = "string.h";
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d4d4b08..adb0eb77fdf35 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -97,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
   return resultTy;
 }
 
+Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+                                    ConversionPatternRewriter &rewriter) {
+  emitc::CallOpaqueOp elementSize = rewriter.create<emitc::CallOpaqueOp>(
+      loc, emitc::SizeTType::get(rewriter.getContext()),
+      rewriter.getStringAttr("sizeof"), ValueRange{},
+      ArrayAttr::get(rewriter.getContext(),
+                     {TypeAttr::get(memrefType.getElementType())}));
+
+  IndexType indexType = rewriter.getIndexType();
+  int64_t numElements = 1;
+  for (int64_t dimSize : memrefType.getShape()) {
+    numElements *= dimSize;
+  }
+  emitc::ConstantOp numElementsValue = rewriter.create<emitc::ConstantOp>(
+      loc, indexType, rewriter.getIndexAttr(numElements));
+
+  Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+  emitc::MulOp totalSizeBytes = rewriter.create<emitc::MulOp>(
+      loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+  return totalSizeBytes.getResult();
+}
+
 struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -159,6 +182,55 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
   }
 };
 
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = copyOp.getLoc();
+    MemRefType srcMemrefType =
+        dyn_cast<MemRefType>(copyOp.getSource().getType());
+    MemRefType targetMemrefType =
+        dyn_cast<MemRefType>(copyOp.getTarget().getType());
+
+    if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
+        !isMemRefTypeLegalForEmitC(targetMemrefType)) {
+      return rewriter.notifyMatchFailure(
+          loc, "incompatible memref type for EmitC conversion");
+    }
+
+    emitc::ConstantOp zeroIndex = rewriter.create<emitc::ConstantOp>(
+        loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
+    auto srcArrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+
+    emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp srcPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(srcMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), srcSubPtr);
+
+    auto arrayValue =
+        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+    emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
+        loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
+    emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
+        loc, emitc::PointerType::get(targetMemrefType.getElementType()),
+        rewriter.getStringAttr("&"), targetSubPtr);
+
+    emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
+        loc, TypeRange{}, "memcpy",
+        ValueRange{
+            targetPtr.getResult(), srcPtr.getResult(),
+            calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+    rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+    return success();
+  }
+};
+
 struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -320,6 +392,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(
     RewritePatternSet &patterns, const TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
-               ConvertLoad, ConvertStore>(converter, patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+               ConvertGetGlobal, ConvertLoad, ConvertStore>(
+      converter, patterns.getContext());
 }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76d6e256..8e965b42f1043 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +28,15 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
+
+emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
+                           StringRef headerName) {
+  StringAttr includeAttr = builder.getStringAttr(headerName);
+  return builder.create<emitc::IncludeOp>(
+      module.getLoc(), includeAttr,
+      /*is_standard_include=*/builder.getUnitAttr());
+}
+
 struct ConvertMemRefToEmitCPass
     : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
   using Base::Base;
@@ -57,7 +67,8 @@ struct ConvertMemRefToEmitCPass
     mlir::ModuleOp module = getOperation();
     module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
       if (callOp.getCallee() != alignedAllocFunctionName &&
-          callOp.getCallee() != mallocFunctionName) {
+          callOp.getCallee() != mallocFunctionName &&
+          callOp.getCallee() != memcpyFunctionName) {
         return mlir::WalkResult::advance();
       }
 
@@ -76,12 +87,14 @@ struct ConvertMemRefToEmitCPass
       }
 
       mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
-      StringAttr includeAttr =
-          builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
-                                                   : cStandardLibraryHeader);
-      builder.create<mlir::emitc::IncludeOp>(
-          module.getLoc(), includeAttr,
-          /*is_standard_include=*/builder.getUnitAttr());
+      StringRef headerName;
+      if (callOp.getCallee() == memcpyFunctionName)
+        headerName = stringLibraryHeader;
+      else
+        headerName = options.lowerToCpp ? cppStandardLibraryHeader
+                                        : cStandardLibraryHeader;
+
+      addHeader(builder, module, headerName);
       return mlir::WalkResult::interrupt();
     });
   }
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
new file mode 100644
index 0000000000000..4b6eb50807513
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt -convert-memref-to-emitc %s  | FileCheck %s
+
+func.func @copying(%arg0 : memref<2x4xf32>) {
+  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
+  return
+}
+
+// CHECK: module {
+// CHECK-NEXT:  emitc.include <"string.h">
+// CHECK-LABEL:  copying
+// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %arg0 : memref<2x4xf32> to !emitc.array<2x4xf32>
+// CHECK-NEXT: %1 = "emitc.constant"() <{value = 0 : index}> : () -> index
+// CHECK-NEXT: %2 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %3 = emitc.apply "&"(%2) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %4 = emitc.subscript %0[%1, %1] : (!emitc.array<2x4xf32>, index, index) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %5 = emitc.apply "&"(%4) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
+// CHECK-NEXT: %6 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
+// CHECK-NEXT: %7 = "emitc.constant"() <{value = 8 : index}> : () -> index
+// CHECK-NEXT: %8 = emitc.mul %6, %7 : (!emitc.size_t, index) -> !emitc.size_t
+// CHECK-NEXT: emitc.call_opaque "memcpy"(%5, %3, %8) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+// CHECK-NEXT:}
+
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index fda01974d3fc8..b6eccfc8f0050 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -1,13 +1,5 @@
 // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
 
-func.func @memref_op(%arg0 : memref<2x4xf32>) {
-  // expected-error@+1 {{failed to legalize operation 'memref.copy'}}
-  memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
-  return
-}
-
-// -----
-
 func.func @alloca_with_dynamic_shape() {
   %0 = index.constant 1
   // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}

@Jaddyen Jaddyen requested review from jpienaar, mtrofin and ilovepi July 29, 2025 18:24
@@ -97,6 +97,29 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}

Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
ConversionPatternRewriter &rewriter) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can take the more generic OpBuilder here instead of a Rewriter?

Comment on lines 197 to 201
if (!isMemRefTypeLegalForEmitC(srcMemrefType) ||
!isMemRefTypeLegalForEmitC(targetMemrefType)) {
return rewriter.notifyMatchFailure(
loc, "incompatible memref type for EmitC conversion");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd split this to two checks to specify in the message which memref was incompatible.

dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());

emitc::SubscriptOp srcSubPtr = rewriter.create<emitc::SubscriptOp>(
loc, srcArrayValue, ValueRange{zeroIndex, zeroIndex});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works for 2D arrays

Comment on lines 216 to 220
emitc::SubscriptOp targetSubPtr = rewriter.create<emitc::SubscriptOp>(
loc, arrayValue, ValueRange{zeroIndex, zeroIndex});
emitc::ApplyOp targetPtr = rewriter.create<emitc::ApplyOp>(
loc, emitc::PointerType::get(targetMemrefType.getElementType()),
rewriter.getStringAttr("&"), targetSubPtr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicates the work for srcArrayValue, can be folded into a lambda.

Comment on lines 192 to 195
MemRefType srcMemrefType =
dyn_cast<MemRefType>(copyOp.getSource().getType());
MemRefType targetMemrefType =
dyn_cast<MemRefType>(copyOp.getTarget().getType());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use cast (these arguments must be MemRefs by op definition).

@@ -27,6 +28,15 @@ namespace mlir {
using namespace mlir;

namespace {

emitc::IncludeOp addHeader(OpBuilder &builder, ModuleOp module,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should be probably be named addStandardHeader as it always sets is_standard_include.

/*is_standard_include=*/builder.getUnitAttr());
StringRef headerName;
if (callOp.getCallee() == memcpyFunctionName)
headerName = stringLibraryHeader;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not follow the C/C++ distinction as with stdlib, i.e. cstring vs string.h?

}

// CHECK: module {
// CHECK-NEXT: emitc.include <"string.h">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should probably be a c++ test too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you only have a test for <string.h>, right? in the c++ mode, you'd have , so you probably need a test for the cpp mode too. Sorry for not being more explicit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the feedback!

@Jaddyen Jaddyen requested a review from aniragil August 4, 2025 22:05
Comment on lines 209 to 212
[&](mlir::Location loc, mlir::OpBuilder &rewriter,
mlir::TypedValue<emitc::ArrayType> arrayValue,
mlir::MemRefType memrefType,
emitc::ConstantOp zeroIndex) -> emitc::ApplyOp {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • You're already capturing rewriter and zeroIndex, no need to pass them as arguments (if you move zeroIndex creation above the lambda).
  • Usually better to capture exactly lambda uses than capture everything (&).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the feedback!

emitc::SubscriptOp subPtr = rewriter.create<emitc::SubscriptOp>(
loc, arrayValue, mlir::ValueRange(indices));
emitc::ApplyOp ptr = rewriter.create<emitc::ApplyOp>(
loc, emitc::PointerType::get(memrefType.getElementType()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not take element type from arrayValue (as done for the rank)? Would save passing memrefType as arg.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, thanks for the feedback!

loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));

auto srcArrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getSource());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use cast rather than dyn_cast when you expect cast to succeed (it will assert on failure).

loc, rewriter, srcArrayValue, srcMemrefType, zeroIndex);

auto targetArrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

emitc::CallOpaqueOp memCpyCall = rewriter.create<emitc::CallOpaqueOp>(
loc, TypeRange{}, "memcpy",
ValueRange{targetPtr.getResult(), srcPtr.getResult(),
calculateMemrefTotalSizeBytes(loc, srcMemrefType, builder)});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewriter is a builder, you can just pass it as argument (and should, as you're working within a conversion pattern).


for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change itself is good (removing braces from single-line blocks) but should be done on a separate PR to avoid cluttering this one with unrelated modifications.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideally, yes.
but im already modifying this portion of code and this would be a single line change.

if (includeOp.getIsStandardInclude() &&
((options.lowerToCpp &&
includeOp.getInclude() == cppStandardLibraryHeader) ||
(!options.lowerToCpp &&
includeOp.getInclude() == cStandardLibraryHeader))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.
Also, shouldn't the code here also check for c/cppStringLibraryHeader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, i should!
thanks for the pointer!

@Jaddyen Jaddyen requested review from ilovepi and aniragil August 5, 2025 17:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants