Skip to content

Patterns: represent constants as valtrees #144591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/mir/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ impl<'tcx> Const<'tcx> {
Self::Val(val, ty)
}

#[inline]
pub fn from_ty_value(tcx: TyCtxt<'tcx>, val: ty::Value<'tcx>) -> Self {
Self::Ty(val.ty, ty::Const::new_value(tcx, val.valtree, val.ty))
}

pub fn from_bits(
tcx: TyCtxt<'tcx>,
bits: u128,
Expand Down
70 changes: 26 additions & 44 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -832,15 +832,15 @@ pub enum PatKind<'tcx> {
},

/// One of the following:
/// * `&str` (represented as a valtree), which will be handled as a string pattern and thus
/// * `&str`, which will be handled as a string pattern and thus
/// exhaustiveness checking will detect if you use the same string twice in different
/// patterns.
/// * integer, bool, char or float (represented as a valtree), which will be handled by
/// * integer, bool, char or float, which will be handled by
/// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are
/// much simpler.
/// * `String`, if `string_deref_patterns` is enabled.
Constant {
value: mir::Const<'tcx>,
value: ty::Value<'tcx>,
},

/// Pattern obtained by converting a constant (inline or named) to its pattern
Expand Down Expand Up @@ -933,7 +933,7 @@ impl<'tcx> PatRange<'tcx> {
let lo_is_min = match self.lo {
PatRangeBoundary::NegInfinity => true,
PatRangeBoundary::Finite(value) => {
let lo = value.try_to_bits(size).unwrap() ^ bias;
let lo = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias;
lo <= min
}
PatRangeBoundary::PosInfinity => false,
Expand All @@ -942,7 +942,7 @@ impl<'tcx> PatRange<'tcx> {
let hi_is_max = match self.hi {
PatRangeBoundary::NegInfinity => false,
PatRangeBoundary::Finite(value) => {
let hi = value.try_to_bits(size).unwrap() ^ bias;
let hi = value.try_to_scalar_int().unwrap().to_bits(size) ^ bias;
hi > max || hi == max && self.end == RangeEnd::Included
}
PatRangeBoundary::PosInfinity => true,
Expand All @@ -955,22 +955,17 @@ impl<'tcx> PatRange<'tcx> {
}

#[inline]
pub fn contains(
&self,
value: mir::Const<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<bool> {
pub fn contains(&self, value: ty::Value<'tcx>, tcx: TyCtxt<'tcx>) -> Option<bool> {
use Ordering::*;
debug_assert_eq!(self.ty, value.ty());
debug_assert_eq!(value.ty, self.ty);
let ty = self.ty;
let value = PatRangeBoundary::Finite(value);
let value = PatRangeBoundary::Finite(value.valtree);
// For performance, it's important to only do the second comparison if necessary.
Some(
match self.lo.compare_with(value, ty, tcx, typing_env)? {
match self.lo.compare_with(value, ty, tcx)? {
Less | Equal => true,
Greater => false,
} && match value.compare_with(self.hi, ty, tcx, typing_env)? {
} && match value.compare_with(self.hi, ty, tcx)? {
Less => true,
Equal => self.end == RangeEnd::Included,
Greater => false,
Expand All @@ -979,21 +974,16 @@ impl<'tcx> PatRange<'tcx> {
}

#[inline]
pub fn overlaps(
&self,
other: &Self,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<bool> {
pub fn overlaps(&self, other: &Self, tcx: TyCtxt<'tcx>) -> Option<bool> {
use Ordering::*;
debug_assert_eq!(self.ty, other.ty);
// For performance, it's important to only do the second comparison if necessary.
Some(
match other.lo.compare_with(self.hi, self.ty, tcx, typing_env)? {
match other.lo.compare_with(self.hi, self.ty, tcx)? {
Less => true,
Equal => self.end == RangeEnd::Included,
Greater => false,
} && match self.lo.compare_with(other.hi, self.ty, tcx, typing_env)? {
} && match self.lo.compare_with(other.hi, self.ty, tcx)? {
Less => true,
Equal => other.end == RangeEnd::Included,
Greater => false,
Expand All @@ -1004,11 +994,13 @@ impl<'tcx> PatRange<'tcx> {

impl<'tcx> fmt::Display for PatRange<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let PatRangeBoundary::Finite(value) = &self.lo {
if let &PatRangeBoundary::Finite(valtree) = &self.lo {
let value = ty::Value { ty: self.ty, valtree };
write!(f, "{value}")?;
}
if let PatRangeBoundary::Finite(value) = &self.hi {
if let &PatRangeBoundary::Finite(valtree) = &self.hi {
write!(f, "{}", self.end)?;
let value = ty::Value { ty: self.ty, valtree };
write!(f, "{value}")?;
} else {
// `0..` is parsed as an inclusive range, we must display it correctly.
Expand All @@ -1022,7 +1014,8 @@ impl<'tcx> fmt::Display for PatRange<'tcx> {
/// If present, the const must be of a numeric type.
#[derive(Copy, Clone, Debug, PartialEq, HashStable, TypeVisitable)]
pub enum PatRangeBoundary<'tcx> {
Finite(mir::Const<'tcx>),
/// The type of this valtree is stored in the surrounding `PatRange`.
Finite(ty::ValTree<'tcx>),
NegInfinity,
PosInfinity,
}
Expand All @@ -1033,20 +1026,15 @@ impl<'tcx> PatRangeBoundary<'tcx> {
matches!(self, Self::Finite(..))
}
#[inline]
pub fn as_finite(self) -> Option<mir::Const<'tcx>> {
pub fn as_finite(self) -> Option<ty::ValTree<'tcx>> {
match self {
Self::Finite(value) => Some(value),
Self::NegInfinity | Self::PosInfinity => None,
}
}
pub fn eval_bits(
self,
ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> u128 {
pub fn eval_bits(self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> u128 {
match self {
Self::Finite(value) => value.eval_bits(tcx, typing_env),
Self::Finite(value) => value.try_to_scalar_int().unwrap().to_bits_unchecked(),
Self::NegInfinity => {
// Unwrap is ok because the type is known to be numeric.
ty.numeric_min_and_max_as_bits(tcx).unwrap().0
Expand All @@ -1058,14 +1046,8 @@ impl<'tcx> PatRangeBoundary<'tcx> {
}
}

#[instrument(skip(tcx, typing_env), level = "debug", ret)]
pub fn compare_with(
self,
other: Self,
ty: Ty<'tcx>,
tcx: TyCtxt<'tcx>,
typing_env: ty::TypingEnv<'tcx>,
) -> Option<Ordering> {
#[instrument(skip(tcx), level = "debug", ret)]
pub fn compare_with(self, other: Self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Ordering> {
use PatRangeBoundary::*;
match (self, other) {
// When comparing with infinities, we must remember that `0u8..` and `0u8..=255`
Expand Down Expand Up @@ -1093,8 +1075,8 @@ impl<'tcx> PatRangeBoundary<'tcx> {
_ => {}
}

let a = self.eval_bits(ty, tcx, typing_env);
let b = other.eval_bits(ty, tcx, typing_env);
let a = self.eval_bits(ty, tcx);
let b = other.eval_bits(ty, tcx);

match ty.kind() {
ty::Float(ty::FloatTy::F16) => {
Expand Down
15 changes: 15 additions & 0 deletions compiler/rustc_middle/src/ty/consts/valtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use std::fmt;
use std::ops::Deref;

use rustc_data_structures::intern::Interned;
use rustc_hir::def::Namespace;
use rustc_macros::{HashStable, Lift, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};

use super::ScalarInt;
use crate::mir::interpret::{ErrorHandled, Scalar};
use crate::ty::print::{FmtPrinter, PrettyPrinter};
use crate::ty::{self, Ty, TyCtxt};

/// This datastructure is used to represent the value of constants used in the type system.
Expand Down Expand Up @@ -133,6 +135,8 @@ pub type ConstToValTreeResult<'tcx> = Result<Result<ValTree<'tcx>, Ty<'tcx>>, Er
/// A type-level constant value.
///
/// Represents a typed, fully evaluated constant.
/// Note that this is used by pattern elaboration to represent values which cannot occur in types,
/// such as raw pointers and floats.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
#[derive(HashStable, TyEncodable, TyDecodable, TypeFoldable, TypeVisitable, Lift)]
pub struct Value<'tcx> {
Expand Down Expand Up @@ -203,3 +207,14 @@ impl<'tcx> rustc_type_ir::inherent::ValueConst<TyCtxt<'tcx>> for Value<'tcx> {
self.valtree
}
}

impl<'tcx> fmt::Display for Value<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
ty::tls::with(move |tcx| {
let cv = tcx.lift(*self).unwrap();
let mut cx = FmtPrinter::new(tcx, Namespace::ValueNS);
cx.pretty_print_const_valtree(cv, /*print_ty*/ true)?;
f.write_str(&cx.into_buffer())
})
}
}
13 changes: 4 additions & 9 deletions compiler/rustc_middle/src/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use rustc_hir::def_id::LocalDefId;
use rustc_span::source_map::Spanned;
use rustc_type_ir::{ConstKind, TypeFolder, VisitorResult, try_visit};

use super::print::PrettyPrinter;
use super::{GenericArg, GenericArgKind, Pattern, Region};
use crate::mir::PlaceElem;
use crate::ty::print::{FmtPrinter, Printer, with_no_trimmed_paths};
Expand Down Expand Up @@ -168,15 +167,11 @@ impl<'tcx> fmt::Debug for ty::Const<'tcx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// If this is a value, we spend some effort to make it look nice.
if let ConstKind::Value(cv) = self.kind() {
return ty::tls::with(move |tcx| {
let cv = tcx.lift(cv).unwrap();
let mut cx = FmtPrinter::new(tcx, Namespace::ValueNS);
cx.pretty_print_const_valtree(cv, /*print_ty*/ true)?;
f.write_str(&cx.into_buffer())
});
write!(f, "{}", cv)
} else {
// Fall back to something verbose.
write!(f, "{:?}", self.kind())
}
// Fall back to something verbose.
write!(f, "{:?}", self.kind())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl<'a, 'tcx> ParseCtxt<'a, 'tcx> {
});
}
};
values.push(value.eval_bits(self.tcx, self.typing_env));
values.push(value.valtree.unwrap_leaf().to_bits_unchecked());
targets.push(self.parse_block(arm.body)?);
}

Expand Down
18 changes: 9 additions & 9 deletions compiler/rustc_mir_build/src/builder/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_hir::{BindingMode, ByRef, LetStmt, LocalSource, Node};
use rustc_middle::bug;
use rustc_middle::middle::region;
use rustc_middle::mir::{self, *};
use rustc_middle::mir::*;
use rustc_middle::thir::{self, *};
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, Ty, ValTree, ValTreeKind};
use rustc_pattern_analysis::constructor::RangeEnd;
Expand Down Expand Up @@ -1260,7 +1260,7 @@ struct Ascription<'tcx> {
#[derive(Debug, Clone)]
enum TestCase<'tcx> {
Variant { adt_def: ty::AdtDef<'tcx>, variant_index: VariantIdx },
Constant { value: mir::Const<'tcx> },
Constant { value: ty::Value<'tcx> },
Range(Arc<PatRange<'tcx>>),
Slice { len: usize, variable_length: bool },
Deref { temp: Place<'tcx>, mutability: Mutability },
Expand Down Expand Up @@ -1333,11 +1333,11 @@ enum TestKind<'tcx> {
/// Test for equality with value, possibly after an unsizing coercion to
/// `ty`,
Eq {
value: Const<'tcx>,
value: ty::Value<'tcx>,
// Integer types are handled by `SwitchInt`, and constants with ADT
// types and `&[T]` types are converted back into patterns, so this can
// only be `&str`, `f32` or `f64`.
ty: Ty<'tcx>,
// only be `&str` or `f*`.
cast_ty: Ty<'tcx>,
},

/// Test whether the value falls within an inclusive or exclusive range.
Expand Down Expand Up @@ -1372,17 +1372,17 @@ pub(crate) struct Test<'tcx> {
enum TestBranch<'tcx> {
/// Success branch, used for tests with two possible outcomes.
Success,
/// Branch corresponding to this constant.
Constant(Const<'tcx>, u128),
/// Branch corresponding to this constant. Must be a scalar.
Constant(ty::Value<'tcx>),
/// Branch corresponding to this variant.
Variant(VariantIdx),
/// Failure branch for tests with two possible outcomes, and "otherwise" branch for other tests.
Failure,
}

impl<'tcx> TestBranch<'tcx> {
fn as_constant(&self) -> Option<&Const<'tcx>> {
if let Self::Constant(v, _) = self { Some(v) } else { None }
fn as_constant(&self) -> Option<ty::Value<'tcx>> {
if let Self::Constant(v) = self { Some(*v) } else { None }
}
}

Expand Down
Loading
Loading