From f4664aa5d2723babef4d11ebba6616a019ff3ed0 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 1 Aug 2025 11:51:48 -0400 Subject: [PATCH 1/5] [mlir] Use memref's alignment attribute directly. --- mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 7a705336bf11c..0411589ed583d 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -511,8 +511,7 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { Operation *memrefAccessOp = loadOrStoreOp.getOperation(); auto memrefMemAccess = memrefAccessOp->getAttrOfType( spirv::attributeName()); - auto memrefAlignment = - memrefAccessOp->getAttrOfType("alignment"); + auto memrefAlignment = loadOrStoreOp.getAlignmentAttr(); if (memrefMemAccess && memrefAlignment) return MemoryRequirements{memrefMemAccess, memrefAlignment}; From 1c3455ce071a6a0de24bc32147bcc144827f8c87 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 1 Aug 2025 12:17:59 -0400 Subject: [PATCH 2/5] [mlir] MemRefToSPIRV propagate alignment attribute. --- .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 9 ++++++--- .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 0411589ed583d..9b8e39fdd0335 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -465,7 +465,8 @@ struct MemoryRequirements { /// Given an accessed SPIR-V pointer, calculates its alignment requirements, if /// any. static FailureOr -calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { +calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, + uint64_t preferredAlignment) { MLIRContext *ctx = accessedPtr.getContext(); auto memoryAccess = spirv::MemoryAccess::None; @@ -494,7 +495,8 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) { memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); - auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes); + auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes; + auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue); return MemoryRequirements{memAccessAttr, alignment}; } @@ -516,7 +518,8 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { return MemoryRequirements{memrefMemAccess, memrefAlignment}; return calculateMemoryRequirements(accessedPtr, - loadOrStoreOp.getNontemporal()); + loadOrStoreOp.getNontemporal(), + loadOrStoreOp.getAlignment().value_or(0)); } LogicalResult diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index d0ddac8cd801c..a00a6e0cbfe8a 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -85,6 +85,21 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : return %0: i1 } +// CHECK-LABEL: func @load_aligned +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, PhysicalStorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "PhysicalStorageBuffer" %[[ADDR]] ["Aligned", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + // CHECK-LABEL: func @store_i1 // CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spirv.storage_class>, // CHECK-SAME: %[[IDX:.+]]: index From 526787f4fb8a87830db1061e8f781ce73e9f74f1 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 1 Aug 2025 15:04:50 -0400 Subject: [PATCH 3/5] [mlir] Fix calculateMemoryRequirements in MemRefToSPIRV. There was an early return in calculateMemoryRequirements that looked explicitly for alignment and only set the alignment attribute. However, this was not correct for the following reasons: * Alignment was set only if both the alignment and the memory_access attributes were both present in the memref operation, without handling the case when only the alignment was exclusively present. * In the case alignment and memory_access attributes were both present, the memory_access attribute would not be updated to aligned if the memory_access attribute was not marked aligned. * In the case alignment and memory_access attributes were both present, other memory requirements (e.g., non_temporal) would not be added as attributes. --- .../Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 13 +++++-------- .../MemRefToSPIRV/memref-to-spirv.mlir | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 9b8e39fdd0335..2204cacf959ce 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -475,7 +475,10 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, } auto ptrType = cast(accessedPtr.getType()); - if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) { + bool mayOmitAlignment = + !preferredAlignment && + ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer; + if (mayOmitAlignment) { if (memoryAccess == spirv::MemoryAccess::None) { return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}}; } @@ -484,6 +487,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, } // PhysicalStorageBuffers require the `Aligned` attribute. + // Other storage types may show an `Aligned` attribute. auto pointeeType = dyn_cast(ptrType.getPointeeType()); if (!pointeeType) return failure(); @@ -510,13 +514,6 @@ calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) { llvm::is_one_of::value, "Must be called on either memref::LoadOp or memref::StoreOp"); - Operation *memrefAccessOp = loadOrStoreOp.getOperation(); - auto memrefMemAccess = memrefAccessOp->getAttrOfType( - spirv::attributeName()); - auto memrefAlignment = loadOrStoreOp.getAlignmentAttr(); - if (memrefMemAccess && memrefAlignment) - return MemoryRequirements{memrefMemAccess, memrefAlignment}; - return calculateMemoryRequirements(accessedPtr, loadOrStoreOp.getNontemporal(), loadOrStoreOp.getAlignment().value_or(0)); diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index a00a6e0cbfe8a..95c7349476230 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -86,8 +86,23 @@ func.func @load_i1(%src: memref<4xi1, #spirv.storage_class>, %i : } // CHECK-LABEL: func @load_aligned +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32 } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + +// CHECK-LABEL: func @load_aligned_psb // CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) -func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { +func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, PhysicalStorageBuffer> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 From 04a811b820b47bb9c5713639648a3680db068457 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 1 Aug 2025 15:32:03 -0400 Subject: [PATCH 4/5] [mlir] Ensure memref's alignment is within SPIRV's alignment bounds. --- mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 2204cacf959ce..e730998f153b0 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Visitors.h" #include +#include #include #define DEBUG_TYPE "memref-to-spirv-pattern" @@ -467,6 +468,11 @@ struct MemoryRequirements { static FailureOr calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, uint64_t preferredAlignment) { + + if (std::numeric_limits::max() < preferredAlignment) { + return failure(); + } + MLIRContext *ctx = accessedPtr.getContext(); auto memoryAccess = spirv::MemoryAccess::None; From 0337e226ce41bbcc4b57c2e70186c510361afe80 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 1 Aug 2025 15:37:02 -0400 Subject: [PATCH 5/5] [mlir] Add aligned nontemporal load test to MemRefToSPIRV. --- .../Conversion/MemRefToSPIRV/memref-to-spirv.mlir | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 95c7349476230..7c765f70136bb 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -100,6 +100,21 @@ func.func @load_aligned(%src: memref<4xi1, #spirv.storage_class>, return %0: i1 } +// CHECK-LABEL: func @load_aligned_nontemporal +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) +func.func @load_aligned_nontemporal(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spirv.storage_class> to !spirv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32 + // CHECK: %[[ADDR:.+]] = spirv.AccessChain %[[SRC_CAST]][%[[ZERO]], %[[IDX_CAST]]] + // CHECK: %[[VAL:.+]] = spirv.Load "StorageBuffer" %[[ADDR]] ["Aligned|Nontemporal", 32] : i8 + // CHECK: %[[ZERO_I8:.+]] = spirv.Constant 0 : i8 + // CHECK: %[[BOOL:.+]] = spirv.INotEqual %[[VAL]], %[[ZERO_I8]] : i8 + %0 = memref.load %src[%i] { alignment = 32, nontemporal = true } : memref<4xi1, #spirv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + // CHECK-LABEL: func @load_aligned_psb // CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spirv.storage_class>, %[[IDX:.+]]: index) func.func @load_aligned_psb(%src: memref<4xi1, #spirv.storage_class>, %i : index) -> i1 {