[WebAssembly] Add fold support for dot #151775

Open: wants to merge 3 commits into base: main

badumbatish
Contributor

Fixes #50154

@llvmbot
Member

llvmbot commented Aug 1, 2025

@llvm/pr-subscribers-backend-webassembly

Author: Jasmine Tang (badumbatish)

Changes

Fixes #50154


Full diff: https://github.com/llvm/llvm-project/pull/151775.diff

2 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+51)
  • (added) llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll (+21)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine wide-vector muls, with extend inputs, to extmul_half.
     setTargetDAGCombine(ISD::MUL);
 
+    // Combine add with vector shuffle of muls to dots
+    setTargetDAGCombine(ISD::ADD);
+
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::ADD);
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+  if (VT != MVT::v4i32)
+    return SDValue();
+
+  auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+    if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+      return SDValue();
+    if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+      return SDValue();
+    return V;
+  };
+  auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+  auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+  // two SDValues must be muls
+  if (!ShuffleA || !ShuffleB)
+    return SDValue();
+
+  if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+      ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+    return SDValue();
+
+  auto IsMulExtend =
+      [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+    if (V.getOpcode() != ISD::MUL)
+      return {};
+
+    auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+    if (V0.getOpcode() != I || V1.getOpcode() != I)
+      return {};
+    return {V0.getOperand(0), V1.getOperand(0)};
+  };
+
+  auto [LowA, LowB] =
+      IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+  auto [HighA, HighB] =
+      IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+  if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+    return SDValue();
+
+  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::MUL);
   EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
+  case ISD::ADD:
+    return performAddCombine(N, DCI.DAG);
   }
 }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..7ac49794491a1
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,21 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK:         .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.dot_i16x8_s
+; CHECK-NEXT:    # fallthrough-return
+  %sext1 = sext <8 x i16> %a to <8 x i32>
+  %sext2 = sext <8 x i16> %b to <8 x i32>
+  %mul = mul nsw <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %res = add <4 x i32> %shuffle1, %shuffle2
+  ret <4 x i32> %res
+}
+
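
As a reading aid for the test above, here is a minimal scalar sketch of why the matched IR shape (sext, mul, even/odd shuffles, add) computes the same lanes as `i32x4.dot_i16x8_s`. The names `dotReference`, `patternFromTest`, and `wrapAdd` are hypothetical and exist only for this illustration; they are not part of the patch or of LLVM.

```cpp
#include <array>
#include <cstdint>

using V8i16 = std::array<int16_t, 8>;
using V4i32 = std::array<int32_t, 4>;

// 32-bit add that wraps modulo 2^32, mirroring a plain IR `add` without nsw.
// The unsigned-to-signed conversion is modular on mainstream compilers
// (and guaranteed to be since C++20).
int32_t wrapAdd(int32_t X, int32_t Y) {
  return static_cast<int32_t>(static_cast<uint32_t>(X) +
                              static_cast<uint32_t>(Y));
}

// Reference semantics of i32x4.dot_i16x8_s: sign-extend both inputs,
// multiply lane-wise, then add adjacent pairs of products.
V4i32 dotReference(const V8i16 &A, const V8i16 &B) {
  V4i32 Out{};
  for (int I = 0; I < 4; ++I)
    Out[I] = wrapAdd(int32_t(A[2 * I]) * int32_t(B[2 * I]),
                     int32_t(A[2 * I + 1]) * int32_t(B[2 * I + 1]));
  return Out;
}

// Scalar model of the IR in the @dot test function.
V4i32 patternFromTest(const V8i16 &A, const V8i16 &B) {
  std::array<int32_t, 8> Mul{};
  for (int I = 0; I < 8; ++I)
    Mul[I] = int32_t(A[I]) * int32_t(B[I]); // %sext1, %sext2, %mul
  V4i32 Out{};
  for (int I = 0; I < 4; ++I) {
    int32_t Even = Mul[2 * I];     // %shuffle1, mask <0, 2, 4, 6>
    int32_t Odd = Mul[2 * I + 1];  // %shuffle2, mask <1, 3, 5, 7>
    Out[I] = wrapAdd(Even, Odd);   // %res
  }
  return Out;
}
```

The two functions agree lane for lane because the even and odd shuffles select exactly the adjacent product pairs that the dot instruction sums.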

Contributor

@lukel97 lukel97 left a comment

Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!

; CHECK-NEXT:    # fallthrough-return
  %sext1 = sext <8 x i16> %a to <8 x i32>
  %sext2 = sext <8 x i16> %b to <8 x i32>
  %mul = mul nsw <8 x i32> %sext1, %sext2
Contributor

I think this should work without the nsw flag?
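
One way to see why the flag should not matter, sketched under the assumption that both multiplicands are sign-extended from i16: the widened product always fits in 32 bits, so the multiply can never signed-wrap whether or not `nsw` is present. The helper names below are hypothetical and only serve this illustration.

```cpp
#include <cstdint>
#include <limits>

// Product of two i16 values, computed in 64 bits so the check itself
// cannot overflow.
constexpr int64_t widenedProduct(int16_t A, int16_t B) {
  return int64_t(A) * int64_t(B);
}

constexpr bool fitsInI32(int64_t V) {
  return V >= std::numeric_limits<int32_t>::min() &&
         V <= std::numeric_limits<int32_t>::max();
}

// The product is linear in each argument, so its extremes over the
// i16 x i16 range are reached at the corner values checked here.
constexpr bool mulNeverWraps() {
  constexpr int16_t Lo = std::numeric_limits<int16_t>::min();
  constexpr int16_t Hi = std::numeric_limits<int16_t>::max();
  return fitsInI32(widenedProduct(Lo, Lo)) &&
         fitsInI32(widenedProduct(Lo, Hi)) &&
         fitsInI32(widenedProduct(Hi, Lo)) &&
         fitsInI32(widenedProduct(Hi, Hi));
}
```

The largest product is (-32768) * (-32768) = 2^30. The sum of two such products is 2^31, which no longer fits in i32, so the final add can wrap in that single corner case even though the multiplies cannot.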

@badumbatish
Contributor Author

Out of curiosity, is it possible to do this as a tablegen pattern? Not that it's necessarily the right thing to do, just wondering if it's easy to do or not!

I didn't actually think TableGen could handle identical arguments in a pattern. I just tried it out, and I think I might be able to make it work.

@sparker-arm
Contributor

I think I might be able to make it work

I'm assuming the 'illegal' types will make this more difficult in TableGen? This approach looks good to me.

  };
  auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
  auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
  // two SDValues must be muls
Contributor

Comment in the wrong place?

It also looks like you're assuming an order for these shuffles. Is there some canonicalization that you're relying on?
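
Since `add` is commutative, a matcher that does not want to depend on operand order can accept the even and odd masks on either side. The sketch below shows that idea on plain mask arrays. `isEvenOddPair` is a hypothetical helper, not code from the patch, and whether the DAG already canonicalizes the operand order here is not established in this thread. Undefined (-1) mask elements are deliberately not handled.

```cpp
#include <array>

using Mask = std::array<int, 4>;

constexpr Mask EvenMask{0, 2, 4, 6};
constexpr Mask OddMask{1, 3, 5, 7};

// Hypothetical stand-in for the operand check in performAddCombine:
// true if (Lhs, Rhs) is an even/odd mask pair in either order.
bool isEvenOddPair(const Mask &Lhs, const Mask &Rhs) {
  return (Lhs == EvenMask && Rhs == OddMask) ||
         (Lhs == OddMask && Rhs == EvenMask);
}
```

In the combine itself, the equivalent would be to retry the match with `N0` and `N1` swapped before giving up.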

  };

  auto [LowA, LowB] =
      IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
Contributor

Instead of checking for EXTEND_LOW / EXTEND_HIGH here, would it not be simpler to just look for SIGN_EXTEND nodes?
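
For context, the two formulations describe the same values: the low and high extends of a `v8i16`, placed side by side, equal a sign-extension of the whole vector. The scalar sketch below illustrates only that equivalence, using hypothetical helper names. Which node kinds are actually present when the combine runs depends on how the wide type has been legalized by then, which this sketch does not model.

```cpp
#include <array>
#include <cstdint>

using V8i16 = std::array<int16_t, 8>;
using V4i32 = std::array<int32_t, 4>;
using V8i32 = std::array<int32_t, 8>;

// i32x4.extend_low_i16x8_s: sign-extend lanes 0..3.
V4i32 extendLowS(const V8i16 &V) {
  V4i32 Out{};
  for (int I = 0; I < 4; ++I)
    Out[I] = V[I];
  return Out;
}

// i32x4.extend_high_i16x8_s: sign-extend lanes 4..7.
V4i32 extendHighS(const V8i16 &V) {
  V4i32 Out{};
  for (int I = 0; I < 4; ++I)
    Out[I] = V[I + 4];
  return Out;
}

// Generic sign-extension of all eight lanes (the v8i32 form in the IR).
V8i32 signExtend(const V8i16 &V) {
  V8i32 Out{};
  for (int I = 0; I < 8; ++I)
    Out[I] = V[I];
  return Out;
}

// Place the low half in lanes 0..3 and the high half in lanes 4..7.
V8i32 concat(const V4i32 &Lo, const V4i32 &Hi) {
  V8i32 Out{};
  for (int I = 0; I < 4; ++I) {
    Out[I] = Lo[I];
    Out[I + 4] = Hi[I];
  }
  return Out;
}
```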

; RUN: llc < %s -mattr=+simd128 | FileCheck %s

target triple = "wasm32-unknown-unknown"
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
Contributor

Please add some negative tests, just so we can be sure a zext isn't slipping in, or the wrong kind of shuffle.
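
To illustrate why those two cases must not fold, here is a scalar sketch showing that zero-extended inputs and a differently shaped shuffle both produce lanes that differ from the signed dot product. The helpers `products` and `addShuffled` are hypothetical and model only the arithmetic, with lanes held as 32-bit patterns so that wraparound is well defined.

```cpp
#include <array>
#include <cstdint>

using V8i16 = std::array<int16_t, 8>;
using Lanes = std::array<uint32_t, 4>;
using Mask = std::array<int, 4>;

// Lane-wise 32-bit products after sign- or zero-extending the i16 inputs.
std::array<uint32_t, 8> products(const V8i16 &A, const V8i16 &B, bool Signed) {
  std::array<uint32_t, 8> Mul{};
  for (int I = 0; I < 8; ++I) {
    uint32_t X = Signed ? uint32_t(int32_t(A[I])) : uint32_t(uint16_t(A[I]));
    uint32_t Y = Signed ? uint32_t(int32_t(B[I])) : uint32_t(uint16_t(B[I]));
    Mul[I] = X * Y; // wraps modulo 2^32
  }
  return Mul;
}

// add (shuffle Mul, MaskA), (shuffle Mul, MaskB)
Lanes addShuffled(const std::array<uint32_t, 8> &Mul, const Mask &MaskA,
                  const Mask &MaskB) {
  Lanes Out{};
  for (int I = 0; I < 4; ++I)
    Out[I] = Mul[MaskA[I]] + Mul[MaskB[I]];
  return Out;
}
```

With every lane of `a` set to -1 and every lane of `b` set to 1, the signed form yields -2 per lane (bit pattern 0xFFFFFFFE) while the zero-extended form yields 131070. Likewise, adding masks `<0, 1, 2, 3>` and `<4, 5, 6, 7>` pairs lane `i` with lane `i + 4` instead of adjacent lanes, so neither variant may be folded to `i32x4.dot_i16x8_s`.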

Successfully merging this pull request may close these issues.

[SIMD] pattern for i32x4.dot_i16x8_s not recognized