Skip to content

[codegen] assume the tag, not the relative discriminant #144764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,35 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
//
// This was previously asserted on `tagged_discr` below, where the
// impossible value is more obvious, but that caused an intermediate
// value to become multi-use and thus not optimize, so instead this
// assumes on the original input which is always multi-use. See
// <https://github.com/llvm/llvm-project/issues/134024#issuecomment-3131782555>
//
// FIXME: If we ever get range assume operand bundles in LLVM (so we
// don't need the `icmp`s in the instruction stream any more), it
// might be worth moving this back to being on the switch argument
// where it's more obviously applicable.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let impossible = niche_start
.wrapping_add(u128::from(untagged_variant.as_u32()))
.wrapping_sub(u128::from(niche_variants.start().as_u32()));
let impossible = bx.cx().const_uint_big(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, tag, impossible);
bx.assume(ne);
}

// With multiple niched variants we'll have to actually compute
// the variant index from the stored tag.
//
Expand Down Expand Up @@ -588,20 +617,6 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
let untagged_variant_const =
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case this helps review, here's how you can see this just from the diff, rather than from https://doc.rust-lang.org/nightly/nightly-rustc/rustc_abi/enum.TagEncoding.html#variant.Niche:

Previously we were emitting an `assume` of `tagged_discr != untagged_variant` here.

But

relative_discr = tag - niche_start
delta = niche_variants.start()
tagged_discr  = relative_discr + delta

so

   tagged_discr != untagged_variant
=> relative_discr + delta != untagged_variant
=> (tag - niche_start) + niche_variants.start() != untagged_variant
=> tag != niche_start + untagged_variant - niche_variants.start()

which is the calculation of `impossible` on line 522 (done there with wrapping arithmetic).

bx.assume(ne);
}

let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);

// In principle we could insert assumes on the possible range of `discr`, but
Expand Down
46 changes: 25 additions & 21 deletions tests/codegen-llvm/enum/enum-discriminant-eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ pub enum Mid<T> {
pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
// CHECK-LABEL: @mid_bool_eq_discr(

// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// CHECK: ret i1 %[[R]]
Expand All @@ -111,16 +111,16 @@ pub fn mid_bool_eq_discr(a: Mid<bool>, b: Mid<bool>) -> bool {
pub fn mid_ord_eq_discr(a: Mid<Ordering>, b: Mid<Ordering>) -> bool {
// CHECK-LABEL: @mid_ord_eq_discr(

// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2
// CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2
// CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]]
Expand All @@ -140,16 +140,16 @@ pub fn mid_nz32_eq_discr(a: Mid<NonZero<u32>>, b: Mid<NonZero<u32>>) -> bool {
pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
// CHECK-LABEL: @mid_ac_eq_discr(

// LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128
// CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1

// LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128
// CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1

// LLVM21: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %a, i8 -127
Expand All @@ -166,21 +166,25 @@ pub fn mid_ac_eq_discr(a: Mid<AC>, b: Mid<AC>) -> bool {
pub fn mid_giant_eq_discr(a: Mid<Giant>, b: Mid<Giant>) -> bool {
// CHECK-LABEL: @mid_giant_eq_discr(

// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i128 %a, 6
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64
// CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
// LLVM20: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5
// CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4
// CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]])
// CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1
// LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1

// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i128 %b, 6
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64
// CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
// LLVM20: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5
// CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4
// CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1
// CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]])
// CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1
// LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1

// LLVM21: %[[A_MODIFIED_TAG:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_TRUNC]], i64 6
// LLVM21: %[[B_MODIFIED_TAG:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_TRUNC]], i64 6
// LLVM21: %[[R:.+]] = icmp eq i64 %[[A_MODIFIED_TAG]], %[[B_MODIFIED_TAG]]

// CHECK: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
// LLVM20: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]]
// CHECK: ret i1 %[[R]]
discriminant_value(&a) == discriminant_value(&b)
}
Expand Down
24 changes: 12 additions & 12 deletions tests/codegen-llvm/enum/enum-match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,18 @@ pub fn match3(e: Option<&u8>) -> i16 {

#[derive(PartialEq)]
pub enum MiddleNiche {
A,
B,
C(bool),
D,
E,
A, // tag 2
B, // tag 3
C(bool), // untagged
D, // tag 5
E, // tag 6
}

// CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 -?[0-9]+, -?[0-9]+\))?}} i8 @match4(i8{{.+}}%0)
// CHECK-NEXT: start:
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 2
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 4
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[NOT_NICHE:.+]] = icmp{{( samesign)?}} ult i8 %0, 2
// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[NOT_NICHE]], i8 2, i8 %[[REL_VAR]]
// CHECK-NEXT: switch i8 %[[DISCR]]
Expand Down Expand Up @@ -443,19 +443,19 @@ pub enum HugeVariantIndex {
V255(Never),
V256(Never),

Possible257,
Bool258(bool),
Possible259,
Possible257, // tag 2
Bool258(bool), // untagged
Possible259, // tag 4
}

// CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 [0-9]+, [0-9]+\))?}} i8 @match5(i8{{.+}}%0)
// CHECK-NEXT: start:
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 3
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2
// CHECK-NEXT: %[[REL_VAR_WIDE:.+]] = zext i8 %[[REL_VAR]] to i64
// CHECK-NEXT: %[[IS_NICHE:.+]] = icmp{{( samesign)?}} ugt i8 %0, 1
// CHECK-NEXT: %[[NICHE_DISCR:.+]] = add nuw nsw i64 %[[REL_VAR_WIDE]], 257
// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i64 %[[NICHE_DISCR]], 258
// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]])
// CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i64 %[[NICHE_DISCR]], i64 258
// CHECK-NEXT: switch i64 %[[DISCR]],
// CHECK-NEXT: i64 257,
Expand Down
Loading