From b90e0fb67c2edf611b7a23705e647a146be80af5 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Thu, 31 Jul 2025 19:28:48 -0700 Subject: [PATCH] [codegen] assume the tag, not the relative discriminant --- compiler/rustc_codegen_ssa/src/mir/operand.rs | 43 +++++++++++------ .../codegen-llvm/enum/enum-discriminant-eq.rs | 46 ++++++++++--------- tests/codegen-llvm/enum/enum-match.rs | 24 +++++----- 3 files changed, 66 insertions(+), 47 deletions(-) diff --git a/compiler/rustc_codegen_ssa/src/mir/operand.rs b/compiler/rustc_codegen_ssa/src/mir/operand.rs index 5459f95c18600..d851c3329802c 100644 --- a/compiler/rustc_codegen_ssa/src/mir/operand.rs +++ b/compiler/rustc_codegen_ssa/src/mir/operand.rs @@ -498,6 +498,35 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64); (is_niche, tagged_discr, 0) } else { + // Thanks to parameter attributes and load metadata, LLVM already knows + // the general valid range of the tag. It's possible, though, for there + // to be an impossible value *in the middle*, which those ranges don't + // communicate, so it's worth an `assume` to let the optimizer know. + // Most importantly, this means when optimizing a variant test like + // `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that + // to `!is_niche` because the `complex` part can't possibly match. + // + // This was previously asserted on `tagged_discr` below, where the + // impossible value is more obvious, but that caused an intermediate + // value to become multi-use and thus not optimize, so instead this + // assumes on the original input which is always multi-use. See + // + // + // FIXME: If we ever get range assume operand bundles in LLVM (so we + // don't need the `icmp`s in the instruction stream any more), it + // might be worth moving this back to being on the switch argument + // where it's more obviously applicable. + if niche_variants.contains(&untagged_variant) + && bx.cx().sess().opts.optimize != OptLevel::No + { + let impossible = niche_start + .wrapping_add(u128::from(untagged_variant.as_u32())) + .wrapping_sub(u128::from(niche_variants.start().as_u32())); + let impossible = bx.cx().const_uint_big(tag_llty, impossible); + let ne = bx.icmp(IntPredicate::IntNE, tag, impossible); + bx.assume(ne); + } + // With multiple niched variants we'll have to actually compute // the variant index from the stored tag. // @@ -588,20 +617,6 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> { let untagged_variant_const = bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32())); - // Thanks to parameter attributes and load metadata, LLVM already knows - // the general valid range of the tag. It's possible, though, for there - // to be an impossible value *in the middle*, which those ranges don't - // communicate, so it's worth an `assume` to let the optimizer know. - // Most importantly, this means when optimizing a variant test like - // `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that - // to `!is_niche` because the `complex` part can't possibly match. - if niche_variants.contains(&untagged_variant) - && bx.cx().sess().opts.optimize != OptLevel::No - { - let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const); - bx.assume(ne); - } - let discr = bx.select(is_niche, tagged_discr, untagged_variant_const); // In principle we could insert assumes on the possible range of `discr`, but diff --git a/tests/codegen-llvm/enum/enum-discriminant-eq.rs b/tests/codegen-llvm/enum/enum-discriminant-eq.rs index d599685c2e51b..214712fe911a4 100644 --- a/tests/codegen-llvm/enum/enum-discriminant-eq.rs +++ b/tests/codegen-llvm/enum/enum-discriminant-eq.rs @@ -91,16 +91,16 @@ pub enum Mid { pub fn mid_bool_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_bool_eq_discr( + // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3 + // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2 // CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i8 %a, 1 - // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 + // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3 + // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2 // CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i8 %b, 1 - // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 // CHECK: ret i1 %[[R]] @@ -111,16 +111,16 @@ pub fn mid_bool_eq_discr(a: Mid, b: Mid) -> bool { pub fn mid_ord_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_ord_eq_discr( + // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, 3 + // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_REL_DISCR:.+]] = add nsw i8 %a, -2 // CHECK: %[[A_IS_NICHE:.+]] = icmp sgt i8 %a, 1 - // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %[[A_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 + // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, 3 + // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_REL_DISCR:.+]] = add nsw i8 %b, -2 // CHECK: %[[B_IS_NICHE:.+]] = icmp sgt i8 %b, 1 - // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %[[B_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 // CHECK: %[[R:.+]] = icmp eq i8 %[[A_DISCR]], %[[B_DISCR]] @@ -140,16 +140,16 @@ pub fn mid_nz32_eq_discr(a: Mid>, b: Mid>) -> bool { pub fn mid_ac_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_ac_eq_discr( - // LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128 - // CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0 // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i8 %a, -127 // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) + // LLVM20: %[[A_REL_DISCR:.+]] = xor i8 %a, -128 + // CHECK: %[[A_IS_NICHE:.+]] = icmp slt i8 %a, 0 // LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %[[A_REL_DISCR]], i8 1 - // LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128 - // CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0 // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i8 %b, -127 // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) + // LLVM20: %[[B_REL_DISCR:.+]] = xor i8 %b, -128 + // CHECK: %[[B_IS_NICHE:.+]] = icmp slt i8 %b, 0 // LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i8 %[[B_REL_DISCR]], i8 1 // LLVM21: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i8 %a, i8 -127 @@ -166,21 +166,25 @@ pub fn mid_ac_eq_discr(a: Mid, b: Mid) -> bool { pub fn mid_giant_eq_discr(a: Mid, b: Mid) -> bool { // CHECK-LABEL: @mid_giant_eq_discr( + // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i128 %a, 6 + // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) // CHECK: %[[A_TRUNC:.+]] = trunc nuw nsw i128 %a to i64 - // CHECK: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5 + // LLVM20: %[[A_REL_DISCR:.+]] = add nsw i64 %[[A_TRUNC]], -5 // CHECK: %[[A_IS_NICHE:.+]] = icmp samesign ugt i128 %a, 4 - // CHECK: %[[A_NOT_HOLE:.+]] = icmp ne i64 %[[A_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[A_NOT_HOLE]]) - // CHECK: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1 + // LLVM20: %[[A_DISCR:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_REL_DISCR]], i64 1 + // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i128 %b, 6 + // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) // CHECK: %[[B_TRUNC:.+]] = trunc nuw nsw i128 %b to i64 - // CHECK: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5 + // LLVM20: %[[B_REL_DISCR:.+]] = add nsw i64 %[[B_TRUNC]], -5 // CHECK: %[[B_IS_NICHE:.+]] = icmp samesign ugt i128 %b, 4 - // CHECK: %[[B_NOT_HOLE:.+]] = icmp ne i64 %[[B_REL_DISCR]], 1 - // CHECK: tail call void @llvm.assume(i1 %[[B_NOT_HOLE]]) - // CHECK: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1 + // LLVM20: %[[B_DISCR:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_REL_DISCR]], i64 1 + + // LLVM21: %[[A_MODIFIED_TAG:.+]] = select i1 %[[A_IS_NICHE]], i64 %[[A_TRUNC]], i64 6 + // LLVM21: %[[B_MODIFIED_TAG:.+]] = select i1 %[[B_IS_NICHE]], i64 %[[B_TRUNC]], i64 6 + // LLVM21: %[[R:.+]] = icmp eq i64 %[[A_MODIFIED_TAG]], %[[B_MODIFIED_TAG]] - // CHECK: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]] + // LLVM20: %[[R:.+]] = icmp eq i64 %[[A_DISCR]], %[[B_DISCR]] // CHECK: ret i1 %[[R]] discriminant_value(&a) == discriminant_value(&b) } diff --git a/tests/codegen-llvm/enum/enum-match.rs b/tests/codegen-llvm/enum/enum-match.rs index 57db44ec74e8f..091c4e9adf49b 100644 --- a/tests/codegen-llvm/enum/enum-match.rs +++ b/tests/codegen-llvm/enum/enum-match.rs @@ -138,18 +138,18 @@ pub fn match3(e: Option<&u8>) -> i16 { #[derive(PartialEq)] pub enum MiddleNiche { - A, - B, - C(bool), - D, - E, + A, // tag 2 + B, // tag 3 + C(bool), // untagged + D, // tag 5 + E, // tag 6 } // CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 -?[0-9]+, -?[0-9]+\))?}} i8 @match4(i8{{.+}}%0) // CHECK-NEXT: start: -// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 -// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %[[REL_VAR]], 2 +// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 4 // CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) +// CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 // CHECK-NEXT: %[[NOT_NICHE:.+]] = icmp{{( samesign)?}} ult i8 %0, 2 // CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[NOT_NICHE]], i8 2, i8 %[[REL_VAR]] // CHECK-NEXT: switch i8 %[[DISCR]] @@ -443,19 +443,19 @@ pub enum HugeVariantIndex { V255(Never), V256(Never), - Possible257, - Bool258(bool), - Possible259, + Possible257, // tag 2 + Bool258(bool), // untagged + Possible259, // tag 4 } // CHECK-LABEL: define{{( dso_local)?}} noundef{{( range\(i8 [0-9]+, [0-9]+\))?}} i8 @match5(i8{{.+}}%0) // CHECK-NEXT: start: +// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i8 %0, 3 +// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) // CHECK-NEXT: %[[REL_VAR:.+]] = add{{( nsw)?}} i8 %0, -2 // CHECK-NEXT: %[[REL_VAR_WIDE:.+]] = zext i8 %[[REL_VAR]] to i64 // CHECK-NEXT: %[[IS_NICHE:.+]] = icmp{{( samesign)?}} ugt i8 %0, 1 // CHECK-NEXT: %[[NICHE_DISCR:.+]] = add nuw nsw i64 %[[REL_VAR_WIDE]], 257 -// CHECK-NEXT: %[[NOT_IMPOSSIBLE:.+]] = icmp ne i64 %[[NICHE_DISCR]], 258 -// CHECK-NEXT: call void @llvm.assume(i1 %[[NOT_IMPOSSIBLE]]) // CHECK-NEXT: %[[DISCR:.+]] = select i1 %[[IS_NICHE]], i64 %[[NICHE_DISCR]], i64 258 // CHECK-NEXT: switch i64 %[[DISCR]], // CHECK-NEXT: i64 257,