From 34076b4648ff455586ca03c97f21bfeb49642cbc Mon Sep 17 00:00:00 2001 From: Bastian Kersting Date: Wed, 23 Jul 2025 09:45:56 +0000 Subject: [PATCH] Extend the enum check to pointer and union reads This change extends the previously added enum discriminant check to enums read through a union or pointer. At the moment we only insert the check when transmuting to an enum. Although I hoped for it, this check isn't yet inserted for calls to `MaybeUninit::assume_init`, because the pass is running on polymorphic MIR and thus doesn't have the information yet to know whether the type that is read is an enum. --- .../rustc_mir_transform/src/check_enums.rs | 272 +++++++++++++----- compiler/rustc_mir_transform/src/lib.rs | 2 +- .../enum/pointer_manually_drop_read_break.rs | 19 ++ .../mir/enum/pointer_manually_drop_read_ok.rs | 18 ++ tests/ui/mir/enum/pointer_read_break.rs | 18 ++ tests/ui/mir/enum/pointer_read_ok.rs | 17 ++ .../enum/union_manually_drop_read_break.rs | 20 ++ .../mir/enum/union_manually_drop_read_ok.rs | 19 ++ tests/ui/mir/enum/union_read_break.rs | 21 ++ tests/ui/mir/enum/union_read_ok.rs | 20 ++ 10 files changed, 359 insertions(+), 67 deletions(-) create mode 100644 tests/ui/mir/enum/pointer_manually_drop_read_break.rs create mode 100644 tests/ui/mir/enum/pointer_manually_drop_read_ok.rs create mode 100644 tests/ui/mir/enum/pointer_read_break.rs create mode 100644 tests/ui/mir/enum/pointer_read_ok.rs create mode 100644 tests/ui/mir/enum/union_manually_drop_read_break.rs create mode 100644 tests/ui/mir/enum/union_manually_drop_read_ok.rs create mode 100644 tests/ui/mir/enum/union_read_break.rs create mode 100644 tests/ui/mir/enum/union_read_ok.rs diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs index 33a87cb987306..2610472bc433a 100644 --- a/compiler/rustc_mir_transform/src/check_enums.rs +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -1,11 +1,11 @@ -use rustc_abi::{Scalar, Size, 
TagEncoding, Variants, WrappingRange}; +use rustc_abi::{HasDataLayout, Scalar, Size, TagEncoding, Variants, WrappingRange}; use rustc_hir::LangItem; use rustc_index::IndexVec; use rustc_middle::bug; -use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; -use rustc_middle::ty::layout::PrimitiveExt; -use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv}; +use rustc_middle::ty::layout::{IntegerExt, PrimitiveExt}; +use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv}; use rustc_session::Session; use tracing::debug; @@ -148,79 +148,195 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> { fn into_found_enums(self) -> Vec> { self.enums } + + /// Registers a new enum check in the finder. + fn register_new_check( + &mut self, + enum_ty: Ty<'tcx>, + enum_def: AdtDef<'tcx>, + source_op: Operand<'tcx>, + ) { + let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else { + return; + }; + // If the operand is a pointer, we want to pass on the size of the operand to the check, + // as we will dereference the pointer and look at the value directly. + let Ok(op_layout) = (if let ty::RawPtr(pointee_ty, _) = + source_op.ty(self.local_decls, self.tcx).kind() + { + self.tcx.layout_of(self.typing_env.as_query_input(*pointee_ty)) + } else { + self.tcx + .layout_of(self.typing_env.as_query_input(source_op.ty(self.local_decls, self.tcx))) + }) else { + return; + }; + + match enum_layout.variants { + Variants::Empty if op_layout.is_uninhabited() => return, + // An empty enum that tries to be constructed from an inhabited value, this + // is never correct. + Variants::Empty => { + // The enum layout is uninhabited but we construct it from sth inhabited. + // This is always UB. + self.enums.push(EnumCheckType::Uninhabited); + } + // Construction of Single value enums is always fine. + Variants::Single { .. 
} => {} + // Construction of an enum with multiple variants but no niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Direct, + tag: Scalar::Initialized { value, .. }, + .. + } => { + let valid_discrs = + enum_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect(); + + let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::Direct { + source_op: source_op.to_copy(), + discr, + op_size: op_layout.size, + valid_discrs, + }); + } + // Construction of an enum with multiple variants and niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + tag: Scalar::Initialized { value, valid_range, .. }, + tag_field, + .. + } => { + let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::WithNiche { + source_op: source_op.to_copy(), + discr, + op_size: op_layout.size, + offset: enum_layout.fields.offset(tag_field.as_usize()), + valid_range, + }); + } + _ => return, + } + } } impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> { - fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { - if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue { - let ty::Adt(adt_def, _) = ty.kind() else { - return; - }; - if !adt_def.is_enum() { - return; - } + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { + self.super_place(place, context, location); + // We only want to emit this check on pointer reads. + match context { + PlaceContext::NonMutatingUse( + NonMutatingUseContext::Copy + | NonMutatingUseContext::Move + | NonMutatingUseContext::SharedBorrow, + ) => (), + _ => return, + } + // Get the place and type we visit. + let pointer_ty = place.ty(self.local_decls, self.tcx).ty; - let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + // We only want to check places based on raw pointers to enums or ManuallyDrop. 
+ let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else { + return; + }; + let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else { + return; + }; + + let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() { + (pointee_ty, enum_adt_def) + } else if enum_adt_def.is_manually_drop() { + // Find the type contained in the ManuallyDrop and check whether it is an enum. + let Some((manual_drop_arg, adt_def)) = + pointee_ty.walk().skip(1).next().map_or(None, |arg| { + if let Some(ty) = arg.as_type() + && let ty::Adt(adt_def, _) = ty.kind() + { + Some((ty, adt_def)) + } else { + None + } + }) + else { return; }; - let Ok(op_layout) = self - .tcx - .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx))) + + (manual_drop_arg, adt_def) + } else { + return; + }; + // Exclude c_void. + if enum_ty.is_c_void(self.tcx) { + return; + } + + self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place)); + } + + fn visit_projection_elem( + &mut self, + place_ref: PlaceRef<'tcx>, + elem: PlaceElem<'tcx>, + context: visit::PlaceContext, + location: Location, + ) { + self.super_projection_elem(place_ref, elem, context, location); + // Check whether we are reading an enum or a ManuallyDrop from a union. + let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else { + return; + }; + if !union_adt_def.is_union() { + return; + } + let PlaceElem::Field(_, extracted_ty) = elem else { + return; + }; + let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else { + return; + }; + let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() { + (extracted_ty, enum_adt_def) + } else if enum_adt_def.is_manually_drop() { + // Find the type contained in the ManuallyDrop and check whether it is an enum. 
+ let Some((manual_drop_arg, adt_def)) = + extracted_ty.walk().skip(1).next().map_or(None, |arg| { + if let Some(ty) = arg.as_type() + && let ty::Adt(adt_def, _) = ty.kind() + { + Some((ty, adt_def)) + } else { + None + } + }) else { return; }; - match enum_layout.variants { - Variants::Empty if op_layout.is_uninhabited() => return, - // An empty enum that tries to be constructed from an inhabited value, this - // is never correct. - Variants::Empty => { - // The enum layout is uninhabited but we construct it from sth inhabited. - // This is always UB. - self.enums.push(EnumCheckType::Uninhabited); - } - // Construction of Single value enums is always fine. - Variants::Single { .. } => {} - // Construction of an enum with multiple variants but no niche optimizations. - Variants::Multiple { - tag_encoding: TagEncoding::Direct, - tag: Scalar::Initialized { value, .. }, - .. - } => { - let valid_discrs = - adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect(); - - let discr = - TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; - self.enums.push(EnumCheckType::Direct { - source_op: op.to_copy(), - discr, - op_size: op_layout.size, - valid_discrs, - }); - } - // Construction of an enum with multiple variants and niche optimizations. - Variants::Multiple { - tag_encoding: TagEncoding::Niche { .. }, - tag: Scalar::Initialized { value, valid_range, .. }, - tag_field, - .. 
- } => { - let discr = - TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; - self.enums.push(EnumCheckType::WithNiche { - source_op: op.to_copy(), - discr, - op_size: op_layout.size, - offset: enum_layout.fields.offset(tag_field.as_usize()), - valid_range, - }); - } - _ => return, + (manual_drop_arg, adt_def) + } else { + return; + }; + + self.register_new_check( + enum_ty, + *enum_adt_def, + Operand::Copy(place_ref.to_place(self.tcx)), + ); + } + + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue { + let ty::Adt(adt_def, _) = ty.kind() else { + return; + }; + if !adt_def.is_enum() { + return; } - self.super_rvalue(rvalue, location); + self.register_new_check(*ty, *adt_def, op.to_copy()); } + self.super_rvalue(rvalue, location); } } @@ -246,7 +362,7 @@ fn insert_discr_cast_to_u128<'tcx>( local_decls: &mut IndexVec>, block_data: &mut BasicBlockData<'tcx>, source_op: Operand<'tcx>, - discr: TyAndSize<'tcx>, + mut discr: TyAndSize<'tcx>, op_size: Size, offset: Option, source_info: SourceInfo, @@ -262,6 +378,29 @@ fn insert_discr_cast_to_u128<'tcx>( } }; + // If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit and then extract the discriminant through that. 
+ let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind() + && !discr.ty.is_raw_ptr() + { + let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl); + let mu_ptr_decl = + local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty); + block_data.statements.push(Statement::new( + source_info, + StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))), + )); + + Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx)) + } else { + source_op + }; + + // Correct the discriminant ty to an integer, to not screw up our casts to the discriminant ty. + if discr.ty.is_raw_ptr() { + discr.ty = tcx.data_layout().ptr_sized_integer().to_ty(tcx, false); + } + let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() { // The discriminant is less wide than the operand, cast the operand into // [MaybeUninit; N] and then index into it. @@ -335,7 +474,8 @@ fn insert_direct_enum_check<'tcx>( new_block: BasicBlock, ) { // Insert a new target block that is branched to in case of an invalid discriminant. - let invalid_discr_block_data = BasicBlockData::new(None, false); + let invalid_discr_block_data = + BasicBlockData::new(None, basic_blocks[current_block].is_cleanup); let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); let block_data = &mut basic_blocks[current_block]; let discr_place = insert_discr_cast_to_u128( diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 08f25276cecc1..2b5d5bafb3ced 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -665,9 +665,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<' body, &[ // Add some UB checks before any UB gets optimized away. 
+ &check_enums::CheckEnums, &check_alignment::CheckAlignment, &check_null::CheckNull, - &check_enums::CheckEnums, // Before inlining: trim down MIR with passes to reduce inlining work. // Has to be done before inlining, otherwise actual call will be almost always inlined. diff --git a/tests/ui/mir/enum/pointer_manually_drop_read_break.rs b/tests/ui/mir/enum/pointer_manually_drop_read_break.rs new file mode 100644 index 0000000000000..96e4f94a2eefc --- /dev/null +++ b/tests/ui/mir/enum/pointer_manually_drop_read_break.rs @@ -0,0 +1,19 @@ +//@ run-crash +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x1 + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +fn main() { + let illegal_val: u16 = 1; + let illegal_val_ptr = &raw const illegal_val; + let foo: *const std::mem::ManuallyDrop<Single> = + unsafe { std::mem::transmute(illegal_val_ptr) }; + + let val: Single = unsafe { foo.cast::<Single>().read() }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/pointer_manually_drop_read_ok.rs b/tests/ui/mir/enum/pointer_manually_drop_read_ok.rs new file mode 100644 index 0000000000000..2e9557e78460e --- /dev/null +++ b/tests/ui/mir/enum/pointer_manually_drop_read_ok.rs @@ -0,0 +1,18 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +fn main() { + let illegal_val: u16 = 0; + let illegal_val_ptr = &raw const illegal_val; + let foo: *const std::mem::ManuallyDrop<Single> = + unsafe { std::mem::transmute(illegal_val_ptr) }; + + let val: Single = unsafe { foo.cast::<Single>().read() }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/pointer_read_break.rs b/tests/ui/mir/enum/pointer_read_break.rs new file mode 100644 index 0000000000000..99bf3579c910f --- /dev/null +++ b/tests/ui/mir/enum/pointer_read_break.rs @@ -0,0 +1,18 @@ +//@ run-crash +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid
value 0x1 + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +fn main() { + let illegal_val: u16 = 1; + let illegal_val_ptr = &raw const illegal_val; + let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) }; + + let val: Single = unsafe { foo.read() }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/pointer_read_ok.rs b/tests/ui/mir/enum/pointer_read_ok.rs new file mode 100644 index 0000000000000..bcf8b02bc2d6f --- /dev/null +++ b/tests/ui/mir/enum/pointer_read_ok.rs @@ -0,0 +1,17 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +fn main() { + let illegal_val: u16 = 0; + let illegal_val_ptr = &raw const illegal_val; + let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) }; + + let val: Single = unsafe { foo.read() }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/union_manually_drop_read_break.rs b/tests/ui/mir/enum/union_manually_drop_read_break.rs new file mode 100644 index 0000000000000..7f75567b73ebc --- /dev/null +++ b/tests/ui/mir/enum/union_manually_drop_read_break.rs @@ -0,0 +1,20 @@ +//@ run-crash +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x1 + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +union Foo { + a: std::mem::ManuallyDrop<Single>, +} + +fn main() { + let foo = Foo { a: unsafe { std::mem::transmute(1_u16) } }; + + let val: Single = unsafe { std::mem::ManuallyDrop::into_inner(foo.a) }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/union_manually_drop_read_ok.rs b/tests/ui/mir/enum/union_manually_drop_read_ok.rs new file mode 100644 index 0000000000000..6687d362ce4a8 --- /dev/null +++ b/tests/ui/mir/enum/union_manually_drop_read_ok.rs @@ -0,0 +1,19 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A, +} + +union Foo { + a: std::mem::ManuallyDrop<Single>, +} +
+fn main() { + let foo = Foo { a: unsafe { std::mem::transmute(0_u16) } }; + + let val: Single = unsafe { std::mem::ManuallyDrop::into_inner(foo.a) }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/union_read_break.rs b/tests/ui/mir/enum/union_read_break.rs new file mode 100644 index 0000000000000..cf745d53396bc --- /dev/null +++ b/tests/ui/mir/enum/union_read_break.rs @@ -0,0 +1,21 @@ +//@ run-crash +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x1 + +#[allow(dead_code)] +#[repr(u16)] +#[derive(Copy, Clone)] +enum Single { + A, +} + +union Foo { + a: Single, +} + +fn main() { + let foo = Foo { a: unsafe { std::mem::transmute(1_u16) } }; + + let val: Single = unsafe { foo.a }; + println!("{}", val as u16); +} diff --git a/tests/ui/mir/enum/union_read_ok.rs b/tests/ui/mir/enum/union_read_ok.rs new file mode 100644 index 0000000000000..568194bf8aaab --- /dev/null +++ b/tests/ui/mir/enum/union_read_ok.rs @@ -0,0 +1,20 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +#[derive(Copy, Clone)] +enum Single { + A, +} + +union Foo { + a: Single, +} + +fn main() { + let foo = Foo { a: unsafe { std::mem::transmute(0_u16) } }; + + let val: Single = unsafe { foo.a }; + println!("{}", val as u16); +}