-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[WebAssembly] Add fold support for dot #151775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-webassembly Author: Jasmine Tang (badumbatish) ChangesFixes #50154 Full diff: https://github.com/llvm/llvm-project/pull/151775.diff 2 Files Affected:
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
+ // Combine add with vector shuffle of muls to dots
+ setTargetDAGCombine(ISD::ADD);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::ADD);
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+ if (VT != MVT::v4i32)
+ return SDValue();
+
+ auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+ if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+ return SDValue();
+ if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+ return SDValue();
+ return V;
+ };
+ auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+ auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+ // two SDValues must be muls
+ if (!ShuffleA || !ShuffleB)
+ return SDValue();
+
+ if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+ ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+ return SDValue();
+
+ auto IsMulExtend =
+ [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+ if (V.getOpcode() != ISD::MUL)
+ return {};
+
+ auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+ if (V0.getOpcode() != I || V1.getOpcode() != I)
+ return {};
+ return {V0.getOperand(0), V1.getOperand(0)};
+ };
+
+ auto [LowA, LowB] =
+ IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+ auto [HighA, HighB] =
+ IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+ if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+ return SDValue();
+
+ return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
+ case ISD::ADD:
+ return performAddCombine(N, DCI.DAG);
}
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..7ac49794491a1
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,21 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK: .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.dot_i16x8_s
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul nsw <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!
; CHECK-NEXT: # fallthrough-return | ||
%sext1 = sext <8 x i16> %a to <8 x i32> | ||
%sext2 = sext <8 x i16> %b to <8 x i32> | ||
%mul = mul nsw <8 x i32> %sext1, %sext2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should work without the nsw flag?
I actually didn't think that tablegen can handle identical arguments in a pattern, i just tried it out just now and i think i might be able to make it work |
I'm assuming the 'illegal' types will make this more difficult in tablegen? This approach looks good to me. |
}; | ||
auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6}); | ||
auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7}); | ||
// two SDValues must be muls |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment in the wrong place?
And it looks like you're assuming an order of these shuffles, is there some canonicalization that you're relying on?
}; | ||
|
||
auto [LowA, LowB] = | ||
IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of checking for EXTEND_LOW
/ EXTEND_HIGH
here, would it not be more simple to just look for SIGN_EXTEND
nodes?
; RUN: llc < %s -mattr=+simd128 | FileCheck %s | ||
|
||
target triple = "wasm32-unknown-unknown" | ||
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some negative tests, just so we can be sure a zext
isn't slipping in, or the wrong kind of shuffle.
Fixes #50154