From 4d304c888e4aecac25ee4a17e52ab5e4861e1a6a Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Fri, 1 Aug 2025 14:03:04 -0700 Subject: [PATCH 1/4] Precommit test --- .../WebAssembly/simd-dot-reductions.ll | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll new file mode 100644 index 0000000000000..76c20c404e6f0 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll @@ -0,0 +1,32 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -mattr=+simd128 | FileCheck %s + +target triple = "wasm32-unknown-unknown" +define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: dot: +; CHECK: .functype dot (v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_low_i16x8_s +; CHECK-NEXT: local.tee 2 +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_high_i16x8_s +; CHECK-NEXT: local.tee 1 +; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 +; CHECK-NEXT: i32x4.add +; CHECK-NEXT: # fallthrough-return + %sext1 = sext <8 x i16> %a to <8 x i32> + %sext2 = sext <8 x i16> %b to <8 x i32> + %mul = mul nsw <8 x i32> %sext1, %sext2 + %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %res = add <4 x i32> %shuffle1, %shuffle2 + ret <4 x i32> %res +} + From cb9aac0407cb67fbf705a7c18c2b842bb4623466 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Fri, 1 Aug 2025 14:21:51 -0700 Subject: [PATCH 2/4] Added combine support for dot --- .../WebAssembly/WebAssemblyISelLowering.cpp | 51 +++++++++++++++++++ .../WebAssembly/simd-dot-reductions.ll | 13 +---- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index cd434f7a331e4..648e3b6b2b440 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // Combine wide-vector muls, with extend inputs, to extmul_half. setTargetDAGCombine(ISD::MUL); + // Combine add with vector shuffle of muls to dots + setTargetDAGCombine(ISD::ADD); + // Combine vector mask reductions into alltrue/anytrue setTargetDAGCombine(ISD::SETCC); @@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N, return SDValue(); } +static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) { + assert(N->getOpcode() == ISD::ADD); + EVT VT = N->getValueType(0); + SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); + + if (VT != MVT::v4i32) + return SDValue(); + + auto IsShuffleWithMask = [](SDValue V, ArrayRef ShuffleValue) { + if (V.getOpcode() != ISD::VECTOR_SHUFFLE) + return SDValue(); + if (cast(V)->getMask() != ShuffleValue) + return SDValue(); + return V; + }; + auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6}); + auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7}); + // two SDValues must be muls + if (!ShuffleA || !ShuffleB) + return SDValue(); + + if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) || + ShuffleA.getOperand(1) != ShuffleB.getOperand(1)) + return SDValue(); + + auto IsMulExtend = + [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair { + if (V.getOpcode() != ISD::MUL) + return {}; + + auto V0 = V.getOperand(0), V1 = V.getOperand(1); + if (V0.getOpcode() != I || V1.getOpcode() != I) + return {}; + return {V0.getOperand(0), V1.getOperand(0)}; + }; + + auto [LowA, LowB] = + IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S); + auto [HighA, HighB] = + IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S); + + if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB) + return SDValue(); + + return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB); +} static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) { assert(N->getOpcode() == ISD::MUL); EVT VT = N->getValueType(0); @@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, } case ISD::MUL: return performMulCombine(N, DCI.DAG); + case ISD::ADD: + return performAddCombine(N, DCI.DAG); } } diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll index 76c20c404e6f0..7ac49794491a1 100644 --- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll +++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll @@ -5,21 +5,10 @@ target triple = "wasm32-unknown-unknown" define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) { ; CHECK-LABEL: dot: ; CHECK: .functype dot (v128, v128) -> (v128) -; CHECK-NEXT: .local v128 ; CHECK-NEXT: # %bb.0: ; CHECK-NEXT: local.get 0 ; CHECK-NEXT: local.get 1 -; CHECK-NEXT: i32x4.extmul_low_i16x8_s -; CHECK-NEXT: local.tee 2 -; CHECK-NEXT: local.get 0 -; CHECK-NEXT: local.get 1 -; CHECK-NEXT: i32x4.extmul_high_i16x8_s -; CHECK-NEXT: local.tee 1 -; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 -; CHECK-NEXT: local.get 2 -; CHECK-NEXT: local.get 1 -; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 -; CHECK-NEXT: i32x4.add +; CHECK-NEXT: i32x4.dot_i16x8_s ; CHECK-NEXT: # fallthrough-return %sext1 = sext <8 x i16> %a to <8 x i32> %sext2 = sext <8 x i16> %b to <8 x i32> From 86fe99b07c58ebd696bd6bd24ae4e74a728c336c Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 5 Aug 2025 10:32:51 -0700 Subject: [PATCH 3/4] Transition to tablegen for pattern --- .../WebAssembly/WebAssemblyISelLowering.cpp | 52 ------------------- .../WebAssembly/WebAssemblyInstrSIMD.td | 21 ++++++++ 2 files changed, 21 insertions(+), 52 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 0955e2d2f39b0..3f80b2ab2bd6d 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -192,9 +192,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( // Combine wide-vector muls, with extend inputs, to extmul_half. setTargetDAGCombine(ISD::MUL); - // Combine add with vector shuffle of muls to dots - setTargetDAGCombine(ISD::ADD); - // Combine vector mask reductions into alltrue/anytrue setTargetDAGCombine(ISD::SETCC); @@ -3439,53 +3436,6 @@ static SDValue performSETCCCombine(SDNode *N, return SDValue(); } -static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) { - assert(N->getOpcode() == ISD::ADD); - EVT VT = N->getValueType(0); - SDValue N0 = N->getOperand(0), N1 = N->getOperand(1); - - if (VT != MVT::v4i32) - return SDValue(); - - auto IsShuffleWithMask = [](SDValue V, ArrayRef ShuffleValue) { - if (V.getOpcode() != ISD::VECTOR_SHUFFLE) - return SDValue(); - if (cast(V)->getMask() != ShuffleValue) - return SDValue(); - return V; - }; - auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6}); - auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7}); - // two SDValues must be muls - if (!ShuffleA || !ShuffleB) - return SDValue(); - - if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) || - ShuffleA.getOperand(1) != ShuffleB.getOperand(1)) - return SDValue(); - - auto IsMulExtend = - [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair { - if (V.getOpcode() != ISD::MUL) - return {}; - - auto V0 = V.getOperand(0), V1 = V.getOperand(1); - if (V0.getOpcode() != I || V1.getOpcode() != I) - return {}; - return {V0.getOperand(0), V1.getOperand(0)}; - }; - - auto [LowA, LowB] = - IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S); - auto [HighA, HighB] = - IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S); - - if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB) - return SDValue(); - - return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB); -} - static SDValue TryWideExtMulCombine(SDNode *N, SelectionDAG &DAG) { EVT VT = N->getValueType(0); if (VT != MVT::v8i32 && VT != MVT::v16i32) @@ -3647,7 +3597,5 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, } case ISD::MUL: return performMulCombine(N, DCI); - case ISD::ADD: - return performAddCombine(N, DCI.DAG); } } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index 143298b700928..15da6567af6f4 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1210,6 +1210,27 @@ defm EXTMUL_LOW_U : defm EXTMUL_HIGH_U : SIMDExtBinary; +// Pattern for dot +def : Pat< + (v4i32 (add + (wasm_shuffle + (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)), + (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)), + (i32 0), (i32 1), (i32 2), (i32 3), + (i32 8), (i32 9), (i32 10), (i32 11), + (i32 16), (i32 17), (i32 18), (i32 19), + (i32 24), (i32 25), (i32 26), (i32 27)), + (wasm_shuffle + (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)), + (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)), + (i32 4), (i32 5), (i32 6), (i32 7), + (i32 12), (i32 13), (i32 14), (i32 15), + (i32 20), (i32 21), (i32 22), (i32 23), + (i32 28), (i32 29), (i32 30), (i32 31))) + ), + (v4i32 (DOT v8i16:$lhs, v8i16:$rhs)) +>; + //===----------------------------------------------------------------------===// // Floating-point unary arithmetic //===----------------------------------------------------------------------===// From 34f58f17590368f6c55a6a40b0f023a8ef1ce351 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 5 Aug 2025 11:13:28 -0700 Subject: [PATCH 4/4] Addresses PR reviews --- .../WebAssembly/simd-dot-reductions.ll | 75 ++++++++++++++++++- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll index 7ac49794491a1..fd50287a231d3 100644 --- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll +++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll @@ -2,9 +2,10 @@ ; RUN: llc < %s -mattr=+simd128 | FileCheck %s target triple = "wasm32-unknown-unknown" -define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) { -; CHECK-LABEL: dot: -; CHECK: .functype dot (v128, v128) -> (v128) + +define <4 x i32> @dot_sext_1(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: dot_sext_1: +; CHECK: .functype dot_sext_1 (v128, v128) -> (v128) ; CHECK-NEXT: # %bb.0: ; CHECK-NEXT: local.get 0 ; CHECK-NEXT: local.get 1 @@ -12,10 +13,76 @@ define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) { ; CHECK-NEXT: # fallthrough-return %sext1 = sext <8 x i16> %a to <8 x i32> %sext2 = sext <8 x i16> %b to <8 x i32> - %mul = mul nsw <8 x i32> %sext1, %sext2 + %mul = mul <8 x i32> %sext1, %sext2 + %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %res = add <4 x i32> %shuffle1, %shuffle2 + ret <4 x i32> %res +} + + +define <4 x i32> @dot_sext_2(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: dot_sext_2: +; CHECK: .functype dot_sext_2 (v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.dot_i16x8_s +; CHECK-NEXT: # fallthrough-return + %sext1 = sext <8 x i16> %a to <8 x i32> + %sext2 = sext <8 x i16> %b to <8 x i32> + %mul = mul <8 x i32> %sext1, %sext2 + %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %res = add <4 x i32> %shuffle2, %shuffle1 + ret <4 x i32> %res +} + +define <4 x i32> @dot_zext(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: dot_zext: +; CHECK: .functype dot_zext (v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_low_i16x8_u +; CHECK-NEXT: local.tee 2 +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_high_i16x8_u +; CHECK-NEXT: local.tee 1 +; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27 +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 +; CHECK-NEXT: i32x4.add +; CHECK-NEXT: # fallthrough-return + %zext1 = zext <8 x i16> %a to <8 x i32> + %zext2 = zext <8 x i16> %b to <8 x i32> + %mul = mul <8 x i32> %zext1, %zext2 %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> %res = add <4 x i32> %shuffle1, %shuffle2 ret <4 x i32> %res } +define <4 x i32> @dot_wrong_shuffle(<8 x i16> %a, <8 x i16> %b) { +; CHECK-LABEL: dot_wrong_shuffle: +; CHECK: .functype dot_wrong_shuffle (v128, v128) -> (v128) +; CHECK-NEXT: # %bb.0: +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_low_i16x8_s +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: i32x4.extmul_high_i16x8_s +; CHECK-NEXT: i32x4.add +; CHECK-NEXT: # fallthrough-return + %sext1 = sext <8 x i16> %a to <8 x i32> + %sext2 = sext <8 x i16> %b to <8 x i32> + %mul = mul <8 x i32> %sext1, %sext2 + %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> + %res = add <4 x i32> %shuffle1, %shuffle2 + ret <4 x i32> %res +}