Skip to content

Commit aeaa360

Browse files
committed
Extend the enum check to pointer and union reads
This change extends the previously added enum discriminant check to enums read through a union or pointer. At the moment we only insert the check when transmuting to an enum. Although I hoped for it, this check isn't yet inserted for calls to `MaybeUninit::assume_init`, because the pass is running on polymorphic MIR and thus doesn't have the information yet to know whether the type that is read is an enum.
1 parent 6c0a912 commit aeaa360

10 files changed

+359
-67
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 206 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
1+
use rustc_abi::{HasDataLayout, Scalar, Size, TagEncoding, Variants, WrappingRange};
22
use rustc_hir::LangItem;
33
use rustc_index::IndexVec;
44
use rustc_middle::bug;
5-
use rustc_middle::mir::visit::Visitor;
5+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
66
use rustc_middle::mir::*;
7-
use rustc_middle::ty::layout::PrimitiveExt;
8-
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
7+
use rustc_middle::ty::layout::{IntegerExt, PrimitiveExt};
8+
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv};
99
use rustc_session::Session;
1010
use tracing::debug;
1111

@@ -148,79 +148,195 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
148148
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
149149
self.enums
150150
}
151+
152+
/// Registers a new enum check in the finder.
153+
fn register_new_check(
154+
&mut self,
155+
enum_ty: Ty<'tcx>,
156+
enum_def: AdtDef<'tcx>,
157+
source_op: Operand<'tcx>,
158+
) {
159+
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else {
160+
return;
161+
};
162+
// If the operand is a pointer, we want to pass on the size of the operand to the check,
163+
// as we will dereference the pointer and look at the value directly.
164+
let Ok(op_layout) = (if let ty::RawPtr(pointee_ty, _) =
165+
source_op.ty(self.local_decls, self.tcx).kind()
166+
{
167+
self.tcx.layout_of(self.typing_env.as_query_input(*pointee_ty))
168+
} else {
169+
self.tcx
170+
.layout_of(self.typing_env.as_query_input(source_op.ty(self.local_decls, self.tcx)))
171+
}) else {
172+
return;
173+
};
174+
175+
match enum_layout.variants {
176+
Variants::Empty if op_layout.is_uninhabited() => return,
177+
// An empty enum that tries to be constructed from an inhabited value, this
178+
// is never correct.
179+
Variants::Empty => {
180+
// The enum layout is uninhabited but we construct it from sth inhabited.
181+
// This is always UB.
182+
self.enums.push(EnumCheckType::Uninhabited);
183+
}
184+
// Construction of Single value enums is always fine.
185+
Variants::Single { .. } => {}
186+
// Construction of an enum with multiple variants but no niche optimizations.
187+
Variants::Multiple {
188+
tag_encoding: TagEncoding::Direct,
189+
tag: Scalar::Initialized { value, .. },
190+
..
191+
} => {
192+
let valid_discrs =
193+
enum_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
194+
195+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
196+
self.enums.push(EnumCheckType::Direct {
197+
source_op: source_op.to_copy(),
198+
discr,
199+
op_size: op_layout.size,
200+
valid_discrs,
201+
});
202+
}
203+
// Construction of an enum with multiple variants and niche optimizations.
204+
Variants::Multiple {
205+
tag_encoding: TagEncoding::Niche { .. },
206+
tag: Scalar::Initialized { value, valid_range, .. },
207+
tag_field,
208+
..
209+
} => {
210+
let discr = TyAndSize { ty: value.to_ty(self.tcx), size: value.size(&self.tcx) };
211+
self.enums.push(EnumCheckType::WithNiche {
212+
source_op: source_op.to_copy(),
213+
discr,
214+
op_size: op_layout.size,
215+
offset: enum_layout.fields.offset(tag_field.as_usize()),
216+
valid_range,
217+
});
218+
}
219+
_ => return,
220+
}
221+
}
151222
}
152223

153224
impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
154-
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
155-
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
156-
let ty::Adt(adt_def, _) = ty.kind() else {
157-
return;
158-
};
159-
if !adt_def.is_enum() {
160-
return;
161-
}
225+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, ___location: Location) {
226+
self.super_place(place, context, ___location);
227+
// We only want to emit this check on pointer reads.
228+
match context {
229+
PlaceContext::NonMutatingUse(
230+
NonMutatingUseContext::Copy
231+
| NonMutatingUseContext::Move
232+
| NonMutatingUseContext::SharedBorrow,
233+
) => (),
234+
_ => return,
235+
}
236+
// Get the place and type we visit.
237+
let pointer_ty = place.ty(self.local_decls, self.tcx).ty;
162238

163-
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
239+
// We only want to check places based on raw pointers to enums or ManuallyDrop<Enum>.
240+
let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else {
241+
return;
242+
};
243+
let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else {
244+
return;
245+
};
246+
247+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
248+
(pointee_ty, enum_adt_def)
249+
} else if enum_adt_def.is_manually_drop() {
250+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
251+
let Some((manual_drop_arg, adt_def)) =
252+
pointee_ty.walk().skip(1).next().map_or(None, |arg| {
253+
if let Some(ty) = arg.as_type()
254+
&& let ty::Adt(adt_def, _) = ty.kind()
255+
{
256+
Some((ty, adt_def))
257+
} else {
258+
None
259+
}
260+
})
261+
else {
164262
return;
165263
};
166-
let Ok(op_layout) = self
167-
.tcx
168-
.layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
264+
265+
(manual_drop_arg, adt_def)
266+
} else {
267+
return;
268+
};
269+
// Exclude c_void.
270+
if enum_ty.is_c_void(self.tcx) {
271+
return;
272+
}
273+
274+
self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place));
275+
}
276+
277+
fn visit_projection_elem(
278+
&mut self,
279+
place_ref: PlaceRef<'tcx>,
280+
elem: PlaceElem<'tcx>,
281+
context: visit::PlaceContext,
282+
___location: Location,
283+
) {
284+
self.super_projection_elem(place_ref, elem, context, ___location);
285+
// Check whether we are reading an enum or a ManuallyDrop<Enum> from a union.
286+
let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else {
287+
return;
288+
};
289+
if !union_adt_def.is_union() {
290+
return;
291+
}
292+
let PlaceElem::Field(_, extracted_ty) = elem else {
293+
return;
294+
};
295+
let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else {
296+
return;
297+
};
298+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
299+
(extracted_ty, enum_adt_def)
300+
} else if enum_adt_def.is_manually_drop() {
301+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
302+
let Some((manual_drop_arg, adt_def)) =
303+
extracted_ty.walk().skip(1).next().map_or(None, |arg| {
304+
if let Some(ty) = arg.as_type()
305+
&& let ty::Adt(adt_def, _) = ty.kind()
306+
{
307+
Some((ty, adt_def))
308+
} else {
309+
None
310+
}
311+
})
169312
else {
170313
return;
171314
};
172315

173-
match enum_layout.variants {
174-
Variants::Empty if op_layout.is_uninhabited() => return,
175-
// An empty enum that tries to be constructed from an inhabited value, this
176-
// is never correct.
177-
Variants::Empty => {
178-
// The enum layout is uninhabited but we construct it from sth inhabited.
179-
// This is always UB.
180-
self.enums.push(EnumCheckType::Uninhabited);
181-
}
182-
// Construction of Single value enums is always fine.
183-
Variants::Single { .. } => {}
184-
// Construction of an enum with multiple variants but no niche optimizations.
185-
Variants::Multiple {
186-
tag_encoding: TagEncoding::Direct,
187-
tag: Scalar::Initialized { value, .. },
188-
..
189-
} => {
190-
let valid_discrs =
191-
adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
192-
193-
let discr =
194-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
195-
self.enums.push(EnumCheckType::Direct {
196-
source_op: op.to_copy(),
197-
discr,
198-
op_size: op_layout.size,
199-
valid_discrs,
200-
});
201-
}
202-
// Construction of an enum with multiple variants and niche optimizations.
203-
Variants::Multiple {
204-
tag_encoding: TagEncoding::Niche { .. },
205-
tag: Scalar::Initialized { value, valid_range, .. },
206-
tag_field,
207-
..
208-
} => {
209-
let discr =
210-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
211-
self.enums.push(EnumCheckType::WithNiche {
212-
source_op: op.to_copy(),
213-
discr,
214-
op_size: op_layout.size,
215-
offset: enum_layout.fields.offset(tag_field.as_usize()),
216-
valid_range,
217-
});
218-
}
219-
_ => return,
316+
(manual_drop_arg, adt_def)
317+
} else {
318+
return;
319+
};
320+
321+
self.register_new_check(
322+
enum_ty,
323+
*enum_adt_def,
324+
Operand::Copy(place_ref.to_place(self.tcx)),
325+
);
326+
}
327+
328+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
329+
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
330+
let ty::Adt(adt_def, _) = ty.kind() else {
331+
return;
332+
};
333+
if !adt_def.is_enum() {
334+
return;
220335
}
221336

222-
self.super_rvalue(rvalue, ___location);
337+
self.register_new_check(*ty, *adt_def, op.to_copy());
223338
}
339+
self.super_rvalue(rvalue, ___location);
224340
}
225341
}
226342

@@ -246,7 +362,7 @@ fn insert_discr_cast_to_u128<'tcx>(
246362
local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
247363
block_data: &mut BasicBlockData<'tcx>,
248364
source_op: Operand<'tcx>,
249-
discr: TyAndSize<'tcx>,
365+
mut discr: TyAndSize<'tcx>,
250366
op_size: Size,
251367
offset: Option<Size>,
252368
source_info: SourceInfo,
@@ -262,6 +378,29 @@ fn insert_discr_cast_to_u128<'tcx>(
262378
}
263379
};
264380

381+
// If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit<T> and then extract the discriminant through that.
382+
let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind()
383+
&& !discr.ty.is_raw_ptr()
384+
{
385+
let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl);
386+
let mu_ptr_decl =
387+
local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into();
388+
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty);
389+
block_data.statements.push(Statement::new(
390+
source_info,
391+
StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))),
392+
));
393+
394+
Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx))
395+
} else {
396+
source_op
397+
};
398+
399+
// Correct the discriminant ty to an integer, to not screw up our casts to the discriminant ty.
400+
if discr.ty.is_raw_ptr() {
401+
discr.ty = tcx.data_layout().ptr_sized_integer().to_ty(tcx, false);
402+
}
403+
265404
let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
266405
// The discriminant is less wide than the operand, cast the operand into
267406
// [MaybeUninit; N] and then index into it.
@@ -335,7 +474,8 @@ fn insert_direct_enum_check<'tcx>(
335474
new_block: BasicBlock,
336475
) {
337476
// Insert a new target block that is branched to in case of an invalid discriminant.
338-
let invalid_discr_block_data = BasicBlockData::new(None, false);
477+
let invalid_discr_block_data =
478+
BasicBlockData::new(None, basic_blocks[current_block].is_cleanup);
339479
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340480
let block_data = &mut basic_blocks[current_block];
341481
let discr_place = insert_discr_cast_to_u128(

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,9 +665,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
665665
body,
666666
&[
667667
// Add some UB checks before any UB gets optimized away.
668+
&check_enums::CheckEnums,
668669
&check_alignment::CheckAlignment,
669670
&check_null::CheckNull,
670-
&check_enums::CheckEnums,
671671
// Before inlining: trim down MIR with passes to reduce inlining work.
672672

673673
// Has to be done before inlining, otherwise actual call will be almost always inlined.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
enum Single {
8+
A,
9+
}
10+
11+
fn main() {
12+
let illegal_val: u16 = 1;
13+
let illegal_val_ptr = &raw const illegal_val;
14+
let foo: *const std::mem::ManuallyDrop<Single> =
15+
unsafe { std::mem::transmute(illegal_val_ptr) };
16+
17+
let val: Single = unsafe { foo.cast::<Single>().read() };
18+
println!("{}", val as u16);
19+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[repr(u16)]
6+
enum Single {
7+
A,
8+
}
9+
10+
fn main() {
11+
let illegal_val: u16 = 0;
12+
let illegal_val_ptr = &raw const illegal_val;
13+
let foo: *const std::mem::ManuallyDrop<Single> =
14+
unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.cast::<Single>().read() };
17+
println!("{}", val as u16);
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
enum Single {
8+
A,
9+
}
10+
11+
fn main() {
12+
let illegal_val: u16 = 1;
13+
let illegal_val_ptr = &raw const illegal_val;
14+
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.read() };
17+
println!("{}", val as u16);
18+
}

0 commit comments

Comments
 (0)