Skip to content

Commit 9de4970

Browse files
farzonlllvm-beanz
andauthored
[SPIRV] Preserve implicit bitcast (#151041)
fixes #146942 ## Issue The cause of the bug is in InstCombine, which is converting our load of float vec4 and bitcast to i32 vec4 into one load of i32 vec4. That means we have to do a legalization in the SPIR-V backend to convert back: ```diff - %3 = load <4 x i32>, ptr addrspace(11) %2, align 16 + %3 = load <4 x float>, ptr addrspace(11) %2, align 16 + %4 = bitcast <4 x float> %3 to <4 x i32> ``` <img width="2566" height="548" alt="Image" src="https://github.com/user-attachments/assets/0bf8813c-70f8-47df-8207-ab7da54f5382" /> https://godbolt.org/z/K4GeM4fKT ## The Fix Just removing the assert isn't enough to fix this bug. If we do so, we get an assert later: `Assertion failed: (!storageClassRequiresExplictLayout(SC)), function getOrCreateSPIRVPointerType, file SPIRVGlobalRegistry.cpp, line 1806.` If we just remove the assert, the `CreateShuffleVector` uses the source type via the `NewLoad` when the `Output` type needs to be the `TargetType`. We also can't use `CreateBitCast`. That will feed the right types for the `ShuffleVector`, but it doesn't emit OpBitcast. The LLVM IR isn't translated over to MIR. The fix then is to emit `spv_bitcast`, just like what `SPIRVEmitIntrinsics::visitBitCastInst` does. --------- Co-authored-by: Chris B <[email protected]>
1 parent 3b5aff5 commit 9de4970

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,20 @@ class SPIRVLegalizePointerCast : public FunctionPass {
7474
// Returns the loaded value.
7575
Value *loadVectorFromVector(IRBuilder<> &B, FixedVectorType *SourceType,
7676
FixedVectorType *TargetType, Value *Source) {
77-
// We expect the codegen to avoid doing implicit bitcast from a load.
78-
assert(TargetType->getElementType() == SourceType->getElementType());
79-
assert(TargetType->getNumElements() < SourceType->getNumElements());
80-
77+
assert(TargetType->getNumElements() <= SourceType->getNumElements());
8178
LoadInst *NewLoad = B.CreateLoad(SourceType, Source);
8279
buildAssignType(B, SourceType, NewLoad);
80+
Value *AssignValue = NewLoad;
81+
if (TargetType->getElementType() != SourceType->getElementType()) {
82+
AssignValue = B.CreateIntrinsic(Intrinsic::spv_bitcast,
83+
{TargetType, SourceType}, {NewLoad});
84+
buildAssignType(B, TargetType, AssignValue);
85+
}
8386

8487
SmallVector<int> Mask(/* Size= */ TargetType->getNumElements());
8588
for (unsigned I = 0; I < TargetType->getNumElements(); ++I)
8689
Mask[I] = I;
87-
Value *Output = B.CreateShuffleVector(NewLoad, NewLoad, Mask);
90+
Value *Output = B.CreateShuffleVector(AssignValue, AssignValue, Mask);
8891
buildAssignType(B, TargetType, Output);
8992
return Output;
9093
}
@@ -135,8 +138,9 @@ class SPIRVLegalizePointerCast : public FunctionPass {
135138
Output = loadFirstValueFromAggregate(B, SVT->getElementType(),
136139
OriginalOperand, LI);
137140
}
138-
// Destination is a smaller vector than source.
141+
// Destination is a smaller vector than source or different vector type.
139142
// - float3 v3 = vector4;
143+
// - float4 v2 = int4;
140144
else if (SVT && DVT)
141145
Output = loadVectorFromVector(B, SVT, DVT, OriginalOperand);
142146
// Destination is the scalar type stored at the start of an aggregate.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
3+
4+
@.str = private unnamed_addr constant [4 x i8] c"In3\00", align 1
5+
@.str.2 = private unnamed_addr constant [5 x i8] c"Out4\00", align 1
6+
@.str.3 = private unnamed_addr constant [5 x i8] c"Out3\00", align 1
7+
8+
9+
; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
10+
; CHECK-DAG: %[[#INT4:]] = OpTypeVector %[[#INT32]] 4
11+
; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
12+
; CHECK-DAG: %[[#FLOAT4:]] = OpTypeVector %[[#FLOAT]] 4
13+
; CHECK-DAG: %[[#INT3:]] = OpTypeVector %[[#INT32]] 3
14+
; CHECK-DAG: %[[#UNDEF_INT4:]] = OpUndef %[[#INT4]]
15+
16+
define void @case1() local_unnamed_addr {
17+
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
18+
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
19+
; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
20+
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str)
21+
%2 = tail call target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4i32_12_1t(i32 0, i32 5, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
22+
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
23+
%4 = load <4 x i32>, ptr addrspace(11) %3, align 16
24+
%5 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4i32_12_1t(target("spirv.VulkanBuffer", [0 x <4 x i32>], 12, 1) %2, i32 0)
25+
store <4 x i32> %4, ptr addrspace(11) %5, align 16
26+
ret void
27+
}
28+
29+
define void @case2() local_unnamed_addr {
30+
; CHECK: %[[#BUFFER_LOAD:]] = OpLoad %[[#FLOAT4]] %{{[0-9]+}} Aligned 16
31+
; CHECK: %[[#CAST_LOAD:]] = OpBitcast %[[#INT4]] %[[#BUFFER_LOAD]]
32+
; CHECK: %[[#VEC_SHUFFLE:]] = OpVectorShuffle %[[#INT4]] %[[#CAST_LOAD]] %[[#CAST_LOAD]] 0 1 2 3
33+
; CHECK: %[[#VEC_TRUNCATE:]] = OpVectorShuffle %[[#INT3]] %[[#VEC_SHUFFLE]] %[[#UNDEF_INT4]] 0 1 2
34+
%1 = tail call target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v4f32_12_0t(i32 0, i32 2, i32 1, i32 0, i1 false, ptr nonnull @.str)
35+
%2 = tail call target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0v3i32_12_1t(i32 0, i32 5, i32 1, i32 0, i1 false, ptr nonnull @.str.3)
36+
%3 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v4f32_12_0t(target("spirv.VulkanBuffer", [0 x <4 x float>], 12, 0) %1, i32 0)
37+
%4 = load <4 x i32>, ptr addrspace(11) %3, align 16
38+
%5 = shufflevector <4 x i32> %4, <4 x i32> poison, <3 x i32> <i32 0, i32 1, i32 2>
39+
%6 = tail call noundef align 16 dereferenceable(16) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0v3i32_12_1t(target("spirv.VulkanBuffer", [0 x <3 x i32>], 12, 1) %2, i32 0)
40+
store <3 x i32> %5, ptr addrspace(11) %6, align 16
41+
ret void
42+
}

0 commit comments

Comments
 (0)