Skip to content

Extend the enum check to pointer and union reads #144353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 206 additions & 66 deletions compiler/rustc_mir_transform/src/check_enums.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
use rustc_abi::{HasDataLayout, Scalar, Size, TagEncoding, Variants, WrappingRange};
use rustc_hir::LangItem;
use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::layout::PrimitiveExt;
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
use rustc_middle::ty::layout::{IntegerExt, PrimitiveExt};
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv};
use rustc_session::Session;
use tracing::debug;

Expand Down Expand Up @@ -148,79 +148,195 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
/// Consumes the finder and hands back the list of enum checks stored in `self.enums`.
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
self.enums
}

/// Records the validity check needed when `source_op` is about to be interpreted as a
/// value of the enum type `enum_ty` (with definition `enum_def`).
///
/// Nothing is recorded when either layout query fails, when the enum has a
/// single-variant layout, or when its layout matches none of the shapes handled below.
fn register_new_check(
    &mut self,
    enum_ty: Ty<'tcx>,
    enum_def: AdtDef<'tcx>,
    source_op: Operand<'tcx>,
) {
    let tcx = self.tcx;

    let Ok(enum_layout) = tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else {
        return;
    };

    // For a raw-pointer operand the size handed to the check is the size of the pointee,
    // because the pointer is dereferenced and the pointed-to value is inspected directly.
    // NOTE(review): `visit_rvalue` also routes `Transmute` operands through here; when such
    // an operand is itself a raw pointer value the pointee size is used as well -- confirm
    // that this is intended.
    let source_ty = source_op.ty(self.local_decls, tcx);
    let checked_ty = match source_ty.kind() {
        ty::RawPtr(pointee_ty, _) => *pointee_ty,
        _ => source_ty,
    };
    let Ok(op_layout) = tcx.layout_of(self.typing_env.as_query_input(checked_ty)) else {
        return;
    };
    let op_size = op_layout.size;

    match enum_layout.variants {
        // An uninhabited enum built from an uninhabited source needs no check. Built from
        // anything inhabited it is always UB, so an unconditional failure is recorded.
        Variants::Empty => {
            if !op_layout.is_uninhabited() {
                self.enums.push(EnumCheckType::Uninhabited);
            }
        }
        // Constructing a single-variant layout is always fine.
        Variants::Single { .. } => {}
        // Several variants with a directly stored tag: the tag must be one of the
        // discriminants declared on the enum.
        Variants::Multiple {
            tag_encoding: TagEncoding::Direct,
            tag: Scalar::Initialized { value, .. },
            ..
        } => {
            let discr = TyAndSize { ty: value.to_ty(tcx), size: value.size(&tcx) };
            let valid_discrs =
                enum_def.discriminants(tcx).map(|(_, discr)| discr.val).collect();
            self.enums.push(EnumCheckType::Direct {
                source_op: source_op.to_copy(),
                discr,
                op_size,
                valid_discrs,
            });
        }
        // Several variants with a niche-encoded tag: the tag field at its offset must lie
        // within the tag's valid range.
        Variants::Multiple {
            tag_encoding: TagEncoding::Niche { .. },
            tag: Scalar::Initialized { value, valid_range, .. },
            tag_field,
            ..
        } => {
            let discr = TyAndSize { ty: value.to_ty(tcx), size: value.size(&tcx) };
            let offset = enum_layout.fields.offset(tag_field.as_usize());
            self.enums.push(EnumCheckType::WithNiche {
                source_op: source_op.to_copy(),
                discr,
                op_size,
                offset,
                valid_range,
            });
        }
        // Any other layout shape is left unchecked.
        _ => {}
    }
}
}

impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
let ty::Adt(adt_def, _) = ty.kind() else {
return;
};
if !adt_def.is_enum() {
return;
}
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, ___location: Location) {
self.super_place(place, context, ___location);
// We only want to emit this check on pointer reads.
match context {
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy
| NonMutatingUseContext::Move
| NonMutatingUseContext::SharedBorrow,
) => (),
_ => return,
}
// Get the place and type we visit.
let pointer_ty = place.ty(self.local_decls, self.tcx).ty;

let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
// We only want to check places based on raw pointers to enums or ManuallyDrop<Enum>.
let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else {
return;
};
let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else {
return;
};

let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
(pointee_ty, enum_adt_def)
} else if enum_adt_def.is_manually_drop() {
// Find the type contained in the ManuallyDrop and check whether it is an enum.
let Some((manual_drop_arg, adt_def)) =
pointee_ty.walk().skip(1).next().map_or(None, |arg| {
if let Some(ty) = arg.as_type()
&& let ty::Adt(adt_def, _) = ty.kind()
{
Some((ty, adt_def))
} else {
None
}
})
else {
return;
};
let Ok(op_layout) = self
.tcx
.layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))

(manual_drop_arg, adt_def)
} else {
return;
};
// Exclude c_void.
if enum_ty.is_c_void(self.tcx) {
return;
}

self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place));
}

fn visit_projection_elem(
&mut self,
place_ref: PlaceRef<'tcx>,
elem: PlaceElem<'tcx>,
context: visit::PlaceContext,
___location: Location,
) {
self.super_projection_elem(place_ref, elem, context, ___location);
// Check whether we are reading an enum or a ManuallyDrop<Enum> from a union.
let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else {
return;
};
if !union_adt_def.is_union() {
return;
}
let PlaceElem::Field(_, extracted_ty) = elem else {
return;
};
let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else {
return;
};
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
(extracted_ty, enum_adt_def)
} else if enum_adt_def.is_manually_drop() {
// Find the type contained in the ManuallyDrop and check whether it is an enum.
let Some((manual_drop_arg, adt_def)) =
extracted_ty.walk().skip(1).next().map_or(None, |arg| {
if let Some(ty) = arg.as_type()
&& let ty::Adt(adt_def, _) = ty.kind()
{
Some((ty, adt_def))
} else {
None
}
})
else {
return;
};

match enum_layout.variants {
Variants::Empty if op_layout.is_uninhabited() => return,
// An empty enum that tries to be constructed from an inhabited value, this
// is never correct.
Variants::Empty => {
// The enum layout is uninhabited but we construct it from sth inhabited.
// This is always UB.
self.enums.push(EnumCheckType::Uninhabited);
}
// Construction of Single value enums is always fine.
Variants::Single { .. } => {}
// Construction of an enum with multiple variants but no niche optimizations.
Variants::Multiple {
tag_encoding: TagEncoding::Direct,
tag: Scalar::Initialized { value, .. },
..
} => {
let valid_discrs =
adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();

let discr =
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
self.enums.push(EnumCheckType::Direct {
source_op: op.to_copy(),
discr,
op_size: op_layout.size,
valid_discrs,
});
}
// Construction of an enum with multiple variants and niche optimizations.
Variants::Multiple {
tag_encoding: TagEncoding::Niche { .. },
tag: Scalar::Initialized { value, valid_range, .. },
tag_field,
..
} => {
let discr =
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
self.enums.push(EnumCheckType::WithNiche {
source_op: op.to_copy(),
discr,
op_size: op_layout.size,
offset: enum_layout.fields.offset(tag_field.as_usize()),
valid_range,
});
}
_ => return,
(manual_drop_arg, adt_def)
} else {
return;
};

self.register_new_check(
enum_ty,
*enum_adt_def,
Operand::Copy(place_ref.to_place(self.tcx)),
);
}

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
let ty::Adt(adt_def, _) = ty.kind() else {
return;
};
if !adt_def.is_enum() {
return;
}

self.super_rvalue(rvalue, ___location);
self.register_new_check(*ty, *adt_def, op.to_copy());
}
self.super_rvalue(rvalue, ___location);
}
}

Expand All @@ -246,7 +362,7 @@ fn insert_discr_cast_to_u128<'tcx>(
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
block_data: &mut BasicBlockData<'tcx>,
source_op: Operand<'tcx>,
discr: TyAndSize<'tcx>,
mut discr: TyAndSize<'tcx>,
op_size: Size,
offset: Option<Size>,
source_info: SourceInfo,
Expand All @@ -262,6 +378,29 @@ fn insert_discr_cast_to_u128<'tcx>(
}
};

// If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit<T> and then extract the discriminant through that.
let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind()
&& !discr.ty.is_raw_ptr()
{
let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl);
let mu_ptr_decl =
local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into();
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty);
block_data.statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))),
));

Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx))
} else {
source_op
};

// Correct the discriminant ty to an integer, to not screw up our casts to the discriminant ty.
if discr.ty.is_raw_ptr() {
discr.ty = tcx.data_layout().ptr_sized_integer().to_ty(tcx, false);
}

let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
// The discriminant is less wide than the operand, cast the operand into
// [MaybeUninit; N] and then index into it.
Expand Down Expand Up @@ -335,7 +474,8 @@ fn insert_direct_enum_check<'tcx>(
new_block: BasicBlock,
) {
// Insert a new target block that is branched to in case of an invalid discriminant.
let invalid_discr_block_data = BasicBlockData::new(None, false);
let invalid_discr_block_data =
BasicBlockData::new(None, basic_blocks[current_block].is_cleanup);
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
let block_data = &mut basic_blocks[current_block];
let discr_place = insert_discr_cast_to_u128(
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,9 +665,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
body,
&[
// Add some UB checks before any UB gets optimized away.
&check_enums::CheckEnums,
&check_alignment::CheckAlignment,
&check_null::CheckNull,
&check_enums::CheckEnums,
// Before inlining: trim down MIR with passes to reduce inlining work.

// Has to be done before inlining, otherwise actual call will be almost always inlined.
Expand Down
19 changes: 19 additions & 0 deletions tests/ui/mir/enum/pointer_manually_drop_read_break.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ run-crash
//@ compile-flags: -C debug-assertions
//@ error-pattern: trying to construct an enum from an invalid value 0x1

// Enum under test: `u16` representation with a single variant `A` and no explicit
// discriminant value.
#[allow(dead_code)]
#[repr(u16)]
enum Single {
A,
}

// Reads a `Single` through a raw pointer that actually points at the `u16` value 1.
// Per the directives at the top of this file, the run is expected to crash with the
// "invalid value 0x1" message.
fn main() {
// Backing storage holding the value 1.
let illegal_val: u16 = 1;
let illegal_val_ptr = &raw const illegal_val;
// Reinterpret the `*const u16` as a pointer to `ManuallyDrop<Single>`.
let foo: *const std::mem::ManuallyDrop<Single> =
unsafe { std::mem::transmute(illegal_val_ptr) };

// Cast to `*const Single` and read the enum value through the pointer.
let val: Single = unsafe { foo.cast::<Single>().read() };
println!("{}", val as u16);
}
18 changes: 18 additions & 0 deletions tests/ui/mir/enum/pointer_manually_drop_read_ok.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//@ run-pass
//@ compile-flags: -C debug-assertions

// Enum under test: `u16` representation with a single variant `A`, whose implicit
// discriminant is 0.
#[allow(dead_code)]
#[repr(u16)]
enum Single {
    A,
}

// Reads a `Single` through a raw pointer that points at the `u16` value 0, which is the
// discriminant of `Single::A`. Per the `run-pass` directive at the top of this file, the
// program is expected to run to completion.
fn main() {
    // Backing storage holding 0, a valid bit pattern for `Single`.
    let legal_val: u16 = 0;
    let legal_val_ptr = &raw const legal_val;
    // SAFETY: only reinterprets one thin raw pointer as another; nothing is dereferenced.
    let foo: *const std::mem::ManuallyDrop<Single> =
        unsafe { std::mem::transmute(legal_val_ptr) };

    // SAFETY: the pointer refers to the live, aligned local `legal_val`, and its value 0
    // is the discriminant of `Single::A`.
    let val: Single = unsafe { foo.cast::<Single>().read() };
    println!("{}", val as u16);
}
18 changes: 18 additions & 0 deletions tests/ui/mir/enum/pointer_read_break.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//@ run-crash
//@ compile-flags: -C debug-assertions
//@ error-pattern: trying to construct an enum from an invalid value 0x1

// Enum under test: `u16` representation with a single variant `A` and no explicit
// discriminant value.
#[allow(dead_code)]
#[repr(u16)]
enum Single {
A,
}

// Reads a `Single` through a `*const Single` that actually points at the `u16` value 1.
// Per the directives at the top of this file, the run is expected to crash with the
// "invalid value 0x1" message.
fn main() {
// Backing storage holding the value 1.
let illegal_val: u16 = 1;
let illegal_val_ptr = &raw const illegal_val;
// Reinterpret the `*const u16` as a pointer to the enum.
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };

// Read the enum value through the pointer.
let val: Single = unsafe { foo.read() };
println!("{}", val as u16);
}
Loading
Loading