diff --git a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h index b595b6a308bea..5abfb3d7e72dd 100644 --- a/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h +++ b/mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h @@ -10,8 +10,11 @@ constexpr const char *alignedAllocFunctionName = "aligned_alloc"; constexpr const char *mallocFunctionName = "malloc"; +constexpr const char *memcpyFunctionName = "memcpy"; constexpr const char *cppStandardLibraryHeader = "cstdlib"; constexpr const char *cStandardLibraryHeader = "stdlib.h"; +constexpr const char *cppStringLibraryHeader = "cstring"; +constexpr const char *cStringLibraryHeader = "string.h"; namespace mlir { class DialectRegistry; diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d4d4b08..c8124d2f16943 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" @@ -97,6 +98,31 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, + OpBuilder &builder) { + assert(isMemRefTypeLegalForEmitC(memrefType) && + "incompatible memref type for EmitC conversion"); + emitc::CallOpaqueOp elementSize = builder.create( + loc, emitc::SizeTType::get(builder.getContext()), + builder.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(builder.getContext(), + {TypeAttr::get(memrefType.getElementType())})); + + IndexType indexType = builder.getIndexType(); + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + emitc::ConstantOp numElementsValue = builder.create( + loc, indexType, builder.getIndexAttr(numElements)); + + Type sizeTType = emitc::SizeTType::get(builder.getContext()); + emitc::MulOp totalSizeBytes = builder.create( + loc, sizeTType, elementSize.getResult(0), numElementsValue); + + return totalSizeBytes.getResult(); +} + struct ConvertAlloc final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -159,6 +185,66 @@ struct ConvertAlloc final : public OpConversionPattern { } }; +struct ConvertCopy final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + MemRefType srcMemrefType = cast(copyOp.getSource().getType()); + MemRefType targetMemrefType = + cast(copyOp.getTarget().getType()); + + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible source memref type for EmitC conversion"); + + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible target memref type for EmitC conversion"); + + emitc::ConstantOp zeroIndex = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(0)); + + auto createPointerFromEmitcArray = + [loc, &rewriter, &zeroIndex]( + mlir::TypedValue arrayValue) -> emitc::ApplyOp { + int64_t rank = arrayValue.getType().getRank(); + llvm::SmallVector indices; + for (int i = 0; i < rank; ++i) { + indices.push_back(zeroIndex); + } + + emitc::SubscriptOp subPtr = rewriter.create( + loc, arrayValue, mlir::ValueRange(indices)); + emitc::ApplyOp ptr = rewriter.create( + loc, emitc::PointerType::get(arrayValue.getType().getElementType()), + rewriter.getStringAttr("&"), subPtr); + + return ptr; + }; + + auto srcArrayValue = + cast>(operands.getSource()); + emitc::ApplyOp srcPtr = createPointerFromEmitcArray(srcArrayValue); + + auto targetArrayValue = + cast>(operands.getTarget()); + emitc::ApplyOp targetPtr = createPointerFromEmitcArray(targetArrayValue); + + emitc::CallOpaqueOp memCpyCall = rewriter.create( + loc, TypeRange{}, "memcpy", + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + + rewriter.replaceOp(copyOp, memCpyCall.getResults()); + + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -320,6 +406,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add(converter, patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index e78dd76d6e256..3ffff9fca106a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringRef.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC @@ -27,6 +28,25 @@ namespace mlir { using namespace mlir; namespace { + +emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return builder.create( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + +bool isExpectedStandardInclude(ConvertMemRefToEmitCOptions options, + emitc::IncludeOp includeOp) { + return ((options.lowerToCpp && + (includeOp.getInclude() == cppStandardLibraryHeader || + includeOp.getInclude() == cppStringLibraryHeader)) || + (!options.lowerToCpp && + (includeOp.getInclude() == cStandardLibraryHeader || + includeOp.getInclude() == cStringLibraryHeader))); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase { using Base::Base; @@ -57,31 +77,30 @@ struct ConvertMemRefToEmitCPass mlir::ModuleOp module = getOperation(); module.walk([&](mlir::emitc::CallOpaqueOp callOp) { if (callOp.getCallee() != alignedAllocFunctionName && - callOp.getCallee() != mallocFunctionName) { + callOp.getCallee() != mallocFunctionName && + callOp.getCallee() != memcpyFunctionName) return mlir::WalkResult::advance(); - } for (auto &op : *module.getBody()) { emitc::IncludeOp includeOp = llvm::dyn_cast(op); - if (!includeOp) { + if (!includeOp) continue; - } + if (includeOp.getIsStandardInclude() && - ((options.lowerToCpp && - includeOp.getInclude() == cppStandardLibraryHeader) || - (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) { + isExpectedStandardInclude(options, includeOp)) return mlir::WalkResult::interrupt(); - } } mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); - StringAttr includeAttr = - builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader - : cStandardLibraryHeader); - builder.create( - module.getLoc(), includeAttr, - /*is_standard_include=*/builder.getUnitAttr()); + StringRef headerName; + if (callOp.getCallee() == memcpyFunctionName) + headerName = + options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader; + else + headerName = options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader; + + addStandardHeader(builder, module, headerName); return mlir::WalkResult::interrupt(); }); } diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir new file mode 100644 index 0000000000000..1b515ba02dd46 --- /dev/null +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=true" %s -split-input-file | FileCheck %s --check-prefix=CPP +// RUN: mlir-opt -convert-memref-to-emitc="lower-to-cpp=false" %s -split-input-file | FileCheck %s --check-prefix=NOCPP + +func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) { + memref.copy %arg0, %arg1 : memref<9x4x5x7xf32> to memref<9x4x5x7xf32> + return +} + +// NOCPP: module { +// NOCPP-NEXT: emitc.include <"string.h"> +// NOCPP-LABEL: copying +// NOCPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// NOCPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// NOCPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// NOCPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// NOCPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// NOCPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// NOCPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// NOCPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// NOCPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// NOCPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// NOCPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// NOCPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// NOCPP-NEXT: return +// NOCPP-NEXT: } +// NOCPP-NEXT:} + +// CPP: module { +// CPP-NEXT: emitc.include <"cstring"> +// CPP-LABEL: copying +// CPP-SAME: %[[arg0:.*]]: memref<9x4x5x7xf32>, %[[arg1:.*]]: memref<9x4x5x7xf32> +// CPP-NEXT: %0 = builtin.unrealized_conversion_cast %arg1 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CPP-NEXT: %1 = builtin.unrealized_conversion_cast %arg0 : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32> +// CPP-NEXT: %2 = "emitc.constant"() <{value = 0 : index}> : () -> index +// CPP-NEXT: %3 = emitc.subscript %1[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CPP-NEXT: %4 = emitc.apply "&"(%3) : (!emitc.lvalue) -> !emitc.ptr +// CPP-NEXT: %5 = emitc.subscript %0[%2, %2, %2, %2] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue +// CPP-NEXT: %6 = emitc.apply "&"(%5) : (!emitc.lvalue) -> !emitc.ptr +// CPP-NEXT: %7 = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t +// CPP-NEXT: %8 = "emitc.constant"() <{value = 1260 : index}> : () -> index +// CPP-NEXT: %9 = emitc.mul %7, %8 : (!emitc.size_t, index) -> !emitc.size_t +// CPP-NEXT: emitc.call_opaque "memcpy"(%6, %4, %9) : (!emitc.ptr, !emitc.ptr, !emitc.size_t) -> () +// CPP-NEXT: return +// CPP-NEXT: } +// CPP-NEXT:} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index fda01974d3fc8..b6eccfc8f0050 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -1,13 +1,5 @@ // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics -func.func @memref_op(%arg0 : memref<2x4xf32>) { - // expected-error@+1 {{failed to legalize operation 'memref.copy'}} - memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32> - return -} - -// ----- - func.func @alloca_with_dynamic_shape() { %0 = index.constant 1 // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}