Skip to content

Commit ec30c0d

Browse files
committed
Extend the enum check to pointer and union reads
This change extends the previously added enum discriminant check to enums that are read through a union or a raw pointer. Until now, the check was only inserted when transmuting to an enum. Although I had hoped to cover it as well, the check is not yet inserted for calls to `MaybeUninit::assume_init`: the pass runs on polymorphic MIR, so at that point it does not yet know whether the type being read is an enum.
1 parent 6c0a912 commit ec30c0d

9 files changed

+350
-63
lines changed

compiler/rustc_mir_transform/src/check_enums.rs

Lines changed: 198 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
22
use rustc_hir::LangItem;
33
use rustc_index::IndexVec;
44
use rustc_middle::bug;
5-
use rustc_middle::mir::visit::Visitor;
5+
use rustc_middle::mir::visit::{NonMutatingUseContext, PlaceContext, Visitor};
66
use rustc_middle::mir::*;
77
use rustc_middle::ty::layout::PrimitiveExt;
8-
use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
8+
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypingEnv};
99
use rustc_session::Session;
1010
use tracing::debug;
1111

@@ -148,79 +148,197 @@ impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
148148
fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
149149
self.enums
150150
}
151+
152+
/// Registers a new enum check in the finder.
153+
fn register_new_check(
154+
&mut self,
155+
enum_ty: Ty<'tcx>,
156+
enum_def: AdtDef<'tcx>,
157+
source_op: Operand<'tcx>,
158+
) {
159+
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(enum_ty)) else {
160+
return;
161+
};
162+
// If the operand is a pointer, we want to pass on the size of the operand to the check,
163+
// as we will dereference the pointer and look at the value directly.
164+
let Ok(op_layout) = (if let ty::RawPtr(pointee_ty, _) =
165+
source_op.ty(self.local_decls, self.tcx).kind()
166+
{
167+
self.tcx.layout_of(self.typing_env.as_query_input(*pointee_ty))
168+
} else {
169+
self.tcx
170+
.layout_of(self.typing_env.as_query_input(source_op.ty(self.local_decls, self.tcx)))
171+
}) else {
172+
return;
173+
};
174+
175+
match enum_layout.variants {
176+
Variants::Empty if op_layout.is_uninhabited() => return,
177+
// An empty enum that tries to be constructed from an inhabited value, this
178+
// is never correct.
179+
Variants::Empty => {
180+
// The enum layout is uninhabited but we construct it from sth inhabited.
181+
// This is always UB.
182+
self.enums.push(EnumCheckType::Uninhabited);
183+
}
184+
// Construction of Single value enums is always fine.
185+
Variants::Single { .. } => {}
186+
// Construction of an enum with multiple variants but no niche optimizations.
187+
Variants::Multiple {
188+
tag_encoding: TagEncoding::Direct,
189+
tag: Scalar::Initialized { value, .. },
190+
..
191+
} => {
192+
let valid_discrs =
193+
enum_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
194+
195+
let discr =
196+
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
197+
self.enums.push(EnumCheckType::Direct {
198+
source_op: source_op.to_copy(),
199+
discr,
200+
op_size: op_layout.size,
201+
valid_discrs,
202+
});
203+
}
204+
// Construction of an enum with multiple variants and niche optimizations.
205+
Variants::Multiple {
206+
tag_encoding: TagEncoding::Niche { .. },
207+
tag: Scalar::Initialized { value, valid_range, .. },
208+
tag_field,
209+
..
210+
} => {
211+
let discr =
212+
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
213+
self.enums.push(EnumCheckType::WithNiche {
214+
source_op: source_op.to_copy(),
215+
discr,
216+
op_size: op_layout.size,
217+
offset: enum_layout.fields.offset(tag_field.as_usize()),
218+
valid_range,
219+
});
220+
}
221+
_ => return,
222+
}
223+
}
151224
}
152225

153226
impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
154-
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
155-
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
156-
let ty::Adt(adt_def, _) = ty.kind() else {
157-
return;
158-
};
159-
if !adt_def.is_enum() {
160-
return;
161-
}
227+
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, ___location: Location) {
228+
self.super_place(place, context, ___location);
229+
// We only want to emit this check on pointer reads.
230+
match context {
231+
PlaceContext::NonMutatingUse(
232+
NonMutatingUseContext::Copy
233+
| NonMutatingUseContext::Move
234+
| NonMutatingUseContext::SharedBorrow,
235+
) => (),
236+
_ => return,
237+
}
238+
// Get the place and type we visit.
239+
let pointer = Place::from(place.local);
240+
let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
162241

163-
let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
242+
// We only want to check places based on raw pointers to enums or ManuallyDrop<Enum>.
243+
let &ty::RawPtr(pointee_ty, _) = pointer_ty.kind() else {
244+
return;
245+
};
246+
let ty::Adt(enum_adt_def, _) = pointee_ty.kind() else {
247+
return;
248+
};
249+
250+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
251+
(pointee_ty, enum_adt_def)
252+
} else if enum_adt_def.is_manually_drop() {
253+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
254+
let Some((manual_drop_arg, adt_def)) =
255+
pointee_ty.walk().skip(1).next().map_or(None, |arg| {
256+
if let Some(ty) = arg.as_type()
257+
&& let ty::Adt(adt_def, _) = ty.kind()
258+
{
259+
Some((ty, adt_def))
260+
} else {
261+
None
262+
}
263+
})
264+
else {
164265
return;
165266
};
166-
let Ok(op_layout) = self
167-
.tcx
168-
.layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
267+
268+
(manual_drop_arg, adt_def)
269+
} else {
270+
return;
271+
};
272+
// Exclude c_void.
273+
if enum_ty.is_c_void(self.tcx) {
274+
return;
275+
}
276+
self.register_new_check(enum_ty, *enum_adt_def, Operand::Copy(*place));
277+
}
278+
279+
fn visit_projection_elem(
280+
&mut self,
281+
place_ref: PlaceRef<'tcx>,
282+
elem: PlaceElem<'tcx>,
283+
context: visit::PlaceContext,
284+
___location: Location,
285+
) {
286+
self.super_projection_elem(place_ref, elem, context, ___location);
287+
// Check whether we are reading an enum or a ManuallyDrop<Enum> from a union.
288+
let ty::Adt(union_adt_def, _) = place_ref.ty(self.local_decls, self.tcx).ty.kind() else {
289+
return;
290+
};
291+
if !union_adt_def.is_union() {
292+
return;
293+
}
294+
let PlaceElem::Field(_, extracted_ty) = elem else {
295+
return;
296+
};
297+
let ty::Adt(enum_adt_def, _) = extracted_ty.kind() else {
298+
return;
299+
};
300+
let (enum_ty, enum_adt_def) = if enum_adt_def.is_enum() {
301+
(extracted_ty, enum_adt_def)
302+
} else if enum_adt_def.is_manually_drop() {
303+
// Find the type contained in the ManuallyDrop and check whether it is an enum.
304+
let Some((manual_drop_arg, adt_def)) =
305+
extracted_ty.walk().skip(1).next().map_or(None, |arg| {
306+
if let Some(ty) = arg.as_type()
307+
&& let ty::Adt(adt_def, _) = ty.kind()
308+
{
309+
Some((ty, adt_def))
310+
} else {
311+
None
312+
}
313+
})
169314
else {
170315
return;
171316
};
172317

173-
match enum_layout.variants {
174-
Variants::Empty if op_layout.is_uninhabited() => return,
175-
// An empty enum that tries to be constructed from an inhabited value, this
176-
// is never correct.
177-
Variants::Empty => {
178-
// The enum layout is uninhabited but we construct it from sth inhabited.
179-
// This is always UB.
180-
self.enums.push(EnumCheckType::Uninhabited);
181-
}
182-
// Construction of Single value enums is always fine.
183-
Variants::Single { .. } => {}
184-
// Construction of an enum with multiple variants but no niche optimizations.
185-
Variants::Multiple {
186-
tag_encoding: TagEncoding::Direct,
187-
tag: Scalar::Initialized { value, .. },
188-
..
189-
} => {
190-
let valid_discrs =
191-
adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
192-
193-
let discr =
194-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
195-
self.enums.push(EnumCheckType::Direct {
196-
source_op: op.to_copy(),
197-
discr,
198-
op_size: op_layout.size,
199-
valid_discrs,
200-
});
201-
}
202-
// Construction of an enum with multiple variants and niche optimizations.
203-
Variants::Multiple {
204-
tag_encoding: TagEncoding::Niche { .. },
205-
tag: Scalar::Initialized { value, valid_range, .. },
206-
tag_field,
207-
..
208-
} => {
209-
let discr =
210-
TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
211-
self.enums.push(EnumCheckType::WithNiche {
212-
source_op: op.to_copy(),
213-
discr,
214-
op_size: op_layout.size,
215-
offset: enum_layout.fields.offset(tag_field.as_usize()),
216-
valid_range,
217-
});
218-
}
219-
_ => return,
318+
(manual_drop_arg, adt_def)
319+
} else {
320+
return;
321+
};
322+
323+
self.register_new_check(
324+
enum_ty,
325+
*enum_adt_def,
326+
Operand::Copy(place_ref.to_place(self.tcx)),
327+
);
328+
}
329+
330+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, ___location: Location) {
331+
if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
332+
let ty::Adt(adt_def, _) = ty.kind() else {
333+
return;
334+
};
335+
if !adt_def.is_enum() {
336+
return;
220337
}
221338

222-
self.super_rvalue(rvalue, ___location);
339+
self.register_new_check(*ty, *adt_def, op.to_copy());
223340
}
341+
self.super_rvalue(rvalue, ___location);
224342
}
225343
}
226344

@@ -262,6 +380,22 @@ fn insert_discr_cast_to_u128<'tcx>(
262380
}
263381
};
264382

383+
// If the enum is behind a pointer, cast it to a *[const|mut] MaybeUninit<T> and then extract the discriminant through that.
384+
let source_op = if let ty::RawPtr(pointee_ty, mutbl) = source_op.ty(local_decls, tcx).kind() {
385+
let mu_ptr_ty = Ty::new_ptr(tcx, Ty::new_maybe_uninit(tcx, *pointee_ty), *mutbl);
386+
let mu_ptr_decl =
387+
local_decls.push(LocalDecl::with_source_info(mu_ptr_ty, source_info)).into();
388+
let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_ptr_ty);
389+
block_data.statements.push(Statement::new(
390+
source_info,
391+
StatementKind::Assign(Box::new((mu_ptr_decl, rvalue))),
392+
));
393+
394+
Operand::Copy(mu_ptr_decl.project_deeper(&[ProjectionElem::Deref], tcx))
395+
} else {
396+
source_op
397+
};
398+
265399
let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
266400
// The discriminant is less wide than the operand, cast the operand into
267401
// [MaybeUninit; N] and then index into it.
@@ -335,7 +469,8 @@ fn insert_direct_enum_check<'tcx>(
335469
new_block: BasicBlock,
336470
) {
337471
// Insert a new target block that is branched to in case of an invalid discriminant.
338-
let invalid_discr_block_data = BasicBlockData::new(None, false);
472+
let invalid_discr_block_data =
473+
BasicBlockData::new(None, basic_blocks[current_block].is_cleanup);
339474
let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
340475
let block_data = &mut basic_blocks[current_block];
341476
let discr_place = insert_discr_cast_to_u128(
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
enum Single {
8+
A,
9+
}
10+
11+
fn main() {
12+
let illegal_val: u16 = 1;
13+
let illegal_val_ptr = &raw const illegal_val;
14+
let foo: *const std::mem::ManuallyDrop<Single> =
15+
unsafe { std::mem::transmute(illegal_val_ptr) };
16+
17+
let val: Single = unsafe { foo.cast::<Single>().read() };
18+
println!("{}", val as u16);
19+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[repr(u16)]
6+
enum Single {
7+
A,
8+
}
9+
10+
fn main() {
11+
let illegal_val: u16 = 0;
12+
let illegal_val_ptr = &raw const illegal_val;
13+
let foo: *const std::mem::ManuallyDrop<Single> =
14+
unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.cast::<Single>().read() };
17+
println!("{}", val as u16);
18+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//@ run-fail
2+
//@ compile-flags: -C debug-assertions
3+
//@ error-pattern: trying to construct an enum from an invalid value 0x1
4+
5+
#[allow(dead_code)]
6+
#[repr(u16)]
7+
enum Single {
8+
A,
9+
}
10+
11+
fn main() {
12+
let illegal_val: u16 = 1;
13+
let illegal_val_ptr = &raw const illegal_val;
14+
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };
15+
16+
let val: Single = unsafe { foo.read() };
17+
println!("{}", val as u16);
18+
}

tests/ui/mir/enum/pointer_read_ok.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//@ run-pass
2+
//@ compile-flags: -C debug-assertions
3+
4+
#[allow(dead_code)]
5+
#[repr(u16)]
6+
enum Single {
7+
A,
8+
}
9+
10+
fn main() {
11+
let illegal_val: u16 = 0;
12+
let illegal_val_ptr = &raw const illegal_val;
13+
let foo: *const Single = unsafe { std::mem::transmute(illegal_val_ptr) };
14+
15+
let val: Single = unsafe { foo.read() };
16+
println!("{}", val as u16);
17+
}

0 commit comments

Comments
 (0)