Skip to content

Commit 9b9e2da

Browse files
committed
[Analysis] add optional index parameter to isSplatValue()
We want to allow splat value transforms to improve PR44588 and related bugs: https://bugs.llvm.org/show_bug.cgi?id=44588 ...but to do that, we need to know if values are splatted from the same, specific index (lane) rather than splatted from an arbitrary index. We can improve the undef handling with 1-liner follow-ups because the Constant API optionally allow undefs now. Differential Revision: https://reviews.llvm.org/D73549
1 parent a9ab01a commit 9b9e2da

File tree

3 files changed

+142
-13
lines changed

3 files changed

+142
-13
lines changed

llvm/include/llvm/Analysis/VectorUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,13 @@ Value *findScalarElement(Value *V, unsigned EltNo);
306306
/// a sequence of instructions that broadcast a single value into a vector.
307307
const Value *getSplatValue(const Value *V);
308308

309-
/// Return true if the input value is known to be a vector with all identical
310-
/// elements (potentially including undefined elements).
309+
/// Return true if each element of the vector value \p V is poisoned or equal to
310+
/// every other non-poisoned element. If an index element is specified, either
311+
/// every element of the vector is poisoned or the element at that index is not
312+
/// poisoned and equal to every other non-poisoned element.
311313
/// This may be more powerful than the related getSplatValue() because it is
312314
/// not limited by finding a scalar source value to a splatted vector.
313-
bool isSplatValue(const Value *V, unsigned Depth = 0);
315+
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
314316

315317
/// Compute a map of integer instructions to their minimum legal type
316318
/// size.

llvm/lib/Analysis/VectorUtils.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -330,21 +330,32 @@ const llvm::Value *llvm::getSplatValue(const Value *V) {
330330
// adjusted if needed.
331331
const unsigned MaxDepth = 6;
332332

333-
bool llvm::isSplatValue(const Value *V, unsigned Depth) {
333+
bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
334334
assert(Depth <= MaxDepth && "Limit Search Depth");
335335

336336
if (isa<VectorType>(V->getType())) {
337337
if (isa<UndefValue>(V))
338338
return true;
339-
// FIXME: Constant splat analysis does not allow undef elements.
339+
// FIXME: We can allow undefs, but if Index was specified, we may want to
340+
// check that the constant is defined at that index.
340341
if (auto *C = dyn_cast<Constant>(V))
341342
return C->getSplatValue() != nullptr;
342343
}
343344

344-
// FIXME: Constant splat analysis does not allow undef elements.
345-
Constant *Mask;
346-
if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask))))
347-
return Mask->getSplatValue() != nullptr;
345+
if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
346+
// FIXME: We can safely allow undefs here. If Index was specified, we will
347+
// check that the mask elt is defined at the required index.
348+
if (!Shuf->getMask()->getSplatValue())
349+
return false;
350+
351+
// Match any index.
352+
if (Index == -1)
353+
return true;
354+
355+
// Match a specific element. The mask should be defined at and match the
356+
// specified index.
357+
return Shuf->getMaskValue(Index) == Index;
358+
}
348359

349360
// The remaining tests are all recursive, so bail out if we hit the limit.
350361
if (Depth++ == MaxDepth)
@@ -353,12 +364,12 @@ bool llvm::isSplatValue(const Value *V, unsigned Depth) {
353364
// If both operands of a binop are splats, the result is a splat.
354365
Value *X, *Y, *Z;
355366
if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
356-
return isSplatValue(X, Depth) && isSplatValue(Y, Depth);
367+
return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth);
357368

358369
// If all operands of a select are splats, the result is a splat.
359370
if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
360-
return isSplatValue(X, Depth) && isSplatValue(Y, Depth) &&
361-
isSplatValue(Z, Depth);
371+
return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) &&
372+
isSplatValue(Z, Index, Depth);
362373

363374
// TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
364375

llvm/unittests/Analysis/VectorUtilsTest.cpp

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,24 @@ TEST_F(VectorUtilsTest, isSplatValue_00) {
107107
EXPECT_TRUE(isSplatValue(A));
108108
}
109109

110+
TEST_F(VectorUtilsTest, isSplatValue_00_index0) {
111+
parseAssembly(
112+
"define <2 x i8> @test(<2 x i8> %x) {\n"
113+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
114+
" ret <2 x i8> %A\n"
115+
"}\n");
116+
EXPECT_TRUE(isSplatValue(A, 0));
117+
}
118+
119+
TEST_F(VectorUtilsTest, isSplatValue_00_index1) {
120+
parseAssembly(
121+
"define <2 x i8> @test(<2 x i8> %x) {\n"
122+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
123+
" ret <2 x i8> %A\n"
124+
"}\n");
125+
EXPECT_FALSE(isSplatValue(A, 1));
126+
}
127+
110128
TEST_F(VectorUtilsTest, isSplatValue_11) {
111129
parseAssembly(
112130
"define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -116,6 +134,24 @@ TEST_F(VectorUtilsTest, isSplatValue_11) {
116134
EXPECT_TRUE(isSplatValue(A));
117135
}
118136

137+
TEST_F(VectorUtilsTest, isSplatValue_11_index0) {
138+
parseAssembly(
139+
"define <2 x i8> @test(<2 x i8> %x) {\n"
140+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
141+
" ret <2 x i8> %A\n"
142+
"}\n");
143+
EXPECT_FALSE(isSplatValue(A, 0));
144+
}
145+
146+
TEST_F(VectorUtilsTest, isSplatValue_11_index1) {
147+
parseAssembly(
148+
"define <2 x i8> @test(<2 x i8> %x) {\n"
149+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
150+
" ret <2 x i8> %A\n"
151+
"}\n");
152+
EXPECT_TRUE(isSplatValue(A, 1));
153+
}
154+
119155
TEST_F(VectorUtilsTest, isSplatValue_01) {
120156
parseAssembly(
121157
"define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -125,7 +161,25 @@ TEST_F(VectorUtilsTest, isSplatValue_01) {
125161
EXPECT_FALSE(isSplatValue(A));
126162
}
127163

128-
// FIXME: Constant (mask) splat analysis does not allow undef elements.
164+
TEST_F(VectorUtilsTest, isSplatValue_01_index0) {
165+
parseAssembly(
166+
"define <2 x i8> @test(<2 x i8> %x) {\n"
167+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
168+
" ret <2 x i8> %A\n"
169+
"}\n");
170+
EXPECT_FALSE(isSplatValue(A, 0));
171+
}
172+
173+
TEST_F(VectorUtilsTest, isSplatValue_01_index1) {
174+
parseAssembly(
175+
"define <2 x i8> @test(<2 x i8> %x) {\n"
176+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
177+
" ret <2 x i8> %A\n"
178+
"}\n");
179+
EXPECT_FALSE(isSplatValue(A, 1));
180+
}
181+
182+
// FIXME: Allow undef matching with Constant (mask) splat analysis.
129183

130184
TEST_F(VectorUtilsTest, isSplatValue_0u) {
131185
parseAssembly(
@@ -136,6 +190,26 @@ TEST_F(VectorUtilsTest, isSplatValue_0u) {
136190
EXPECT_FALSE(isSplatValue(A));
137191
}
138192

193+
// FIXME: Allow undef matching with Constant (mask) splat analysis.
194+
195+
TEST_F(VectorUtilsTest, isSplatValue_0u_index0) {
196+
parseAssembly(
197+
"define <2 x i8> @test(<2 x i8> %x) {\n"
198+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
199+
" ret <2 x i8> %A\n"
200+
"}\n");
201+
EXPECT_FALSE(isSplatValue(A, 0));
202+
}
203+
204+
TEST_F(VectorUtilsTest, isSplatValue_0u_index1) {
205+
parseAssembly(
206+
"define <2 x i8> @test(<2 x i8> %x) {\n"
207+
" %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
208+
" ret <2 x i8> %A\n"
209+
"}\n");
210+
EXPECT_FALSE(isSplatValue(A, 1));
211+
}
212+
139213
TEST_F(VectorUtilsTest, isSplatValue_Binop) {
140214
parseAssembly(
141215
"define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -147,6 +221,28 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop) {
147221
EXPECT_TRUE(isSplatValue(A));
148222
}
149223

224+
TEST_F(VectorUtilsTest, isSplatValue_Binop_index0) {
225+
parseAssembly(
226+
"define <2 x i8> @test(<2 x i8> %x) {\n"
227+
" %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
228+
" %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
229+
" %A = udiv <2 x i8> %v0, %v1\n"
230+
" ret <2 x i8> %A\n"
231+
"}\n");
232+
EXPECT_FALSE(isSplatValue(A, 0));
233+
}
234+
235+
TEST_F(VectorUtilsTest, isSplatValue_Binop_index1) {
236+
parseAssembly(
237+
"define <2 x i8> @test(<2 x i8> %x) {\n"
238+
" %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
239+
" %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
240+
" %A = udiv <2 x i8> %v0, %v1\n"
241+
" ret <2 x i8> %A\n"
242+
"}\n");
243+
EXPECT_FALSE(isSplatValue(A, 1));
244+
}
245+
150246
TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
151247
parseAssembly(
152248
"define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -157,6 +253,26 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
157253
EXPECT_TRUE(isSplatValue(A));
158254
}
159255

256+
TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index0) {
257+
parseAssembly(
258+
"define <2 x i8> @test(<2 x i8> %x) {\n"
259+
" %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
260+
" %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
261+
" ret <2 x i8> %A\n"
262+
"}\n");
263+
EXPECT_FALSE(isSplatValue(A, 0));
264+
}
265+
266+
TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index1) {
267+
parseAssembly(
268+
"define <2 x i8> @test(<2 x i8> %x) {\n"
269+
" %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
270+
" %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
271+
" ret <2 x i8> %A\n"
272+
"}\n");
273+
EXPECT_TRUE(isSplatValue(A, 1));
274+
}
275+
160276
TEST_F(VectorUtilsTest, isSplatValue_Binop_Not_Op0) {
161277
parseAssembly(
162278
"define <2 x i8> @test(<2 x i8> %x) {\n"

0 commit comments

Comments
 (0)