diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index de18275db1..eb8783049c 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -5,11 +5,11 @@ use super::CodegenCx; use crate::abi::ConvSpirvType; use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind}; use crate::spirv_type::SpirvType; +use itertools::Itertools as _; use rspirv::spirv::Word; use rustc_abi::{self as abi, AddressSpace, Float, HasDataLayout, Integer, Primitive, Size}; use rustc_codegen_ssa::traits::{ConstCodegenMethods, MiscCodegenMethods, StaticCodegenMethods}; -use rustc_middle::bug; -use rustc_middle::mir::interpret::{ConstAllocation, GlobalAlloc, Scalar, alloc_range}; +use rustc_middle::mir::interpret::{AllocError, ConstAllocation, GlobalAlloc, Scalar, alloc_range}; use rustc_middle::ty::layout::LayoutOf; use rustc_span::{DUMMY_SP, Span}; @@ -255,7 +255,11 @@ impl ConstCodegenMethods for CodegenCx<'_> { other.debug(ty, self) )), }; - let init = self.create_const_alloc(alloc, pointee); + // FIXME(eddyb) always use `const_data_from_alloc`, and + // defer the actual `try_read_from_const_alloc` step. + let init = self + .try_read_from_const_alloc(alloc, pointee) + .unwrap_or_else(|| self.const_data_from_alloc(alloc)); let value = self.static_addr_of(init, alloc.inner().align, None); (value, AddressSpace::DATA) } @@ -280,7 +284,11 @@ impl ConstCodegenMethods for CodegenCx<'_> { other.debug(ty, self) )), }; - let init = self.create_const_alloc(alloc, pointee); + // FIXME(eddyb) always use `const_data_from_alloc`, and + // defer the actual `try_read_from_const_alloc` step. 
+ let init = self + .try_read_from_const_alloc(alloc, pointee) + .unwrap_or_else(|| self.const_data_from_alloc(alloc)); let value = self.static_addr_of(init, alloc.inner().align, None); (value, AddressSpace::DATA) } @@ -290,24 +298,7 @@ impl ConstCodegenMethods for CodegenCx<'_> { (self.get_static(def_id), AddressSpace::DATA) } }; - let value = if offset.bytes() == 0 { - base_addr - } else { - self.tcx - .dcx() - .fatal("Non-zero scalar_to_backend ptr.offset not supported") - // let offset = self.constant_bit64(ptr.offset.bytes()); - // self.gep(base_addr, once(offset)) - }; - if let Primitive::Pointer(_) = layout.primitive() { - assert_ty_eq!(self, value.ty, ty); - value - } else { - self.tcx - .dcx() - .fatal("Non-pointer-typed scalar_to_backend Scalar::Ptr not supported"); - // unsafe { llvm::LLVMConstPtrToInt(llval, llty) } - } + self.const_bitcast(self.const_ptr_byte_offset(base_addr, offset), ty) } } } @@ -348,9 +339,8 @@ impl<'tcx> CodegenCx<'tcx> { && let Some(SpirvConst::ConstDataFromAlloc(alloc)) = self.builder.lookup_const_by_id(pointee) && let SpirvType::Pointer { pointee } = self.lookup_type(ty) + && let Some(init) = self.try_read_from_const_alloc(alloc, pointee) { - let mut offset = Size::ZERO; - let init = self.read_from_const_alloc(alloc, &mut offset, pointee); return self.static_addr_of(init, alloc.inner().align, None); } @@ -379,44 +369,38 @@ impl<'tcx> CodegenCx<'tcx> { } } - pub fn create_const_alloc(&self, alloc: ConstAllocation<'tcx>, ty: Word) -> SpirvValue { - tracing::trace!( - "Creating const alloc of type {} with {} bytes", - self.debug_type(ty), - alloc.inner().len() - ); - let mut offset = Size::ZERO; - let result = self.read_from_const_alloc(alloc, &mut offset, ty); - assert_eq!( - offset.bytes_usize(), - alloc.inner().len(), - "create_const_alloc must consume all bytes of an Allocation" - ); - tracing::trace!("Done creating alloc of type {}", self.debug_type(ty)); - result - } - - fn read_from_const_alloc( + /// Attempt to read a 
whole constant of type `ty` from `alloc`, but only + /// returning that constant if its size covers the entirety of `alloc`. + // + // FIXME(eddyb) should this use something like `Result<_, PartialRead>`? + pub fn try_read_from_const_alloc( + &self, + alloc: ConstAllocation<'tcx>, + ty: Word, + ) -> Option<SpirvValue> { + let (result, read_size) = self.read_from_const_alloc_at(alloc, ty, Size::ZERO); + (read_size == alloc.inner().size()).then_some(result) + } + + // HACK(eddyb) the `Size` returned is the equivalent of `size_of_val` on + // the returned constant, i.e. `ty.sizeof()` can be either `Some(read_size)`, + // or `None` - i.e. unsized, in which case only the returned `Size` records + // how much was read from `alloc` to build the returned constant value. + #[tracing::instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), offset))] + fn read_from_const_alloc_at( &self, alloc: ConstAllocation<'tcx>, - offset: &mut Size, ty: Word, - ) -> SpirvValue { - let ty_concrete = self.lookup_type(ty); - *offset = offset.align_to(ty_concrete.alignof(self)); - // these print statements are really useful for debugging, so leave them easily available - // println!("const at {}: {}", offset.bytes(), self.debug_type(ty)); - match ty_concrete { - SpirvType::Void => self - .tcx - .dcx() - .fatal("cannot create const alloc of type void"), + offset: Size, + ) -> (SpirvValue, Size) { + let ty_def = self.lookup_type(ty); + match ty_def { SpirvType::Bool | SpirvType::Integer(..) | SpirvType::Float(_) | SpirvType::Pointer { .. } => { - let size = ty_concrete.sizeof(self).unwrap(); - let primitive = match ty_concrete { + let size = ty_def.sizeof(self).unwrap(); + let primitive = match ty_def { SpirvType::Bool => Primitive::Int(Integer::fit_unsigned(0), false), SpirvType::Integer(int_size, int_signedness) => Primitive::Int( match int_size { @@ -445,147 +429,196 @@ impl<'tcx> CodegenCx<'tcx> { } }), SpirvType::Pointer { .. 
} => Primitive::Pointer(AddressSpace::DATA), - unsupported_spirv_type => bug!( - "invalid spirv type internal to create_alloc_const2: {:?}", - unsupported_spirv_type - ), + _ => unreachable!(), }; - // alloc_id is not needed by read_scalar, so we just use 0. If the context - // refers to a pointer, read_scalar will find the actual alloc_id. It - // only uses the input alloc_id in the case that the scalar is uninitialized - // as part of the error output - // tldr, the pointer here is only needed for the offset - let value = match alloc.inner().read_scalar( - self, - alloc_range(*offset, size), - matches!(primitive, Primitive::Pointer(_)), - ) { + + let range = alloc_range(offset, size); + let read_provenance = matches!(primitive, Primitive::Pointer(_)); + + let mut primitive = primitive; + let mut read_result = alloc.inner().read_scalar(self, range, read_provenance); + + // HACK(eddyb) while reading a pointer as an integer will fail, + // the pointer itself can be read as a pointer, and then passed + // to `scalar_to_backend`, which will `const_bitcast` it to `ty`. + if read_result.is_err() + && !read_provenance + && let read_ptr_result @ Ok(Scalar::Ptr(ptr, _)) = alloc + .inner() + .read_scalar(self, range, /* read_provenance */ true) + { + let (prov, _offset) = ptr.into_parts(); + primitive = Primitive::Pointer( + self.tcx.global_alloc(prov.alloc_id()).address_space(self), + ); + read_result = read_ptr_result; + } + + let scalar_or_zombie = match read_result { Ok(scalar) => { - self.scalar_to_backend(scalar, self.primitive_to_scalar(primitive), ty) + Ok(self.scalar_to_backend(scalar, self.primitive_to_scalar(primitive), ty)) } - _ => self.undef(ty), + + // FIXME(eddyb) could some of these use e.g. `const_bitcast`? 
+ // (or, in general, assembling one constant out of several) + Err(err) => match err { + // The scalar is only `undef` if the entire byte range + // it covers is completely uninitialized - all other + // failure modes of `read_scalar` are various errors. + AllocError::InvalidUninitBytes(_) => { + let uninit_range = alloc + .inner() + .init_mask() + .is_range_initialized(range) + .unwrap_err(); + let uninit_size = { + let [start, end] = [uninit_range.start, uninit_range.end()] + .map(|x| x.clamp(range.start, range.end())); + end - start + }; + if uninit_size == size { + Ok(self.undef(ty)) + } else { + Err(format!( + "overlaps {} uninitialized bytes", + uninit_size.bytes() + )) + } + } + AllocError::ReadPointerAsInt(_) => Err("overlaps pointer bytes".into()), + AllocError::ReadPartialPointer(_) => { + Err("partially overlaps another pointer".into()) + } + + // HACK(eddyb) these should never happen when using + // `read_scalar`, but better not outright crash. + AllocError::ScalarSizeMismatch(_) + | AllocError::OverwritePartialPointer(_) => { + Err(format!("unrecognized `AllocError::{err:?}`")) + } + }, }; - *offset += size; - value + let result = scalar_or_zombie.unwrap_or_else(|reason| { + let result = self.undef(ty); + self.zombie_no_span( + result.def_cx(self), + &format!("unsupported `{}` constant: {reason}", self.debug_type(ty),), + ); + result + }); + (result, size) } SpirvType::Adt { - size, field_types, field_offsets, .. 
} => { - let base = *offset; - let mut values = Vec::with_capacity(field_types.len()); - let mut occupied_spaces = Vec::with_capacity(field_types.len()); - for (&ty, &field_offset) in field_types.iter().zip(field_offsets.iter()) { - let total_offset_start = base + field_offset; - let mut total_offset_end = total_offset_start; - values.push( - self.read_from_const_alloc(alloc, &mut total_offset_end, ty) - .def_cx(self), - ); - occupied_spaces.push(total_offset_start..total_offset_end); - } - if let Some(size) = size { - *offset += size; - } else { - assert_eq!( - offset.bytes_usize(), - alloc.inner().len(), - "create_const_alloc must consume all bytes of an Allocation after an unsized struct" + // HACK(eddyb) this accounts for unsized `struct`s, and allows + // detecting gaps *only* at the end of the type, but is cheap. + let mut tail_read_range = ..Size::ZERO; + let result = self.constant_composite( + ty, + field_types + .iter() + .zip_eq(field_offsets.iter()) + .map(|(&f_ty, &f_offset)| { + let (f, f_size) = + self.read_from_const_alloc_at(alloc, f_ty, offset + f_offset); + tail_read_range.end = + tail_read_range.end.max(offset + f_offset + f_size); + f.def_cx(self) + }), + ); + + let ty_size = ty_def.sizeof(self); + + // HACK(eddyb) catch non-padding holes in e.g. `enum` values. 
+ if let Some(ty_size) = ty_size + && let Some(tail_gap) = (ty_size.bytes()) + .checked_sub(tail_read_range.end.align_to(ty_def.alignof(self)).bytes()) + && tail_gap > 0 + { + self.zombie_no_span( + result.def_cx(self), + &format!( + "undersized `{}` constant (at least {tail_gap} bytes may be missing)", + self.debug_type(ty) + ), ); } - self.constant_composite(ty, values.into_iter()) - } - SpirvType::Array { element, count } => { - let count = self.builder.lookup_const_scalar(count).unwrap() as usize; - let values = (0..count).map(|_| { - self.read_from_const_alloc(alloc, offset, element) - .def_cx(self) - }); - self.constant_composite(ty, values) - } - SpirvType::Vector { element, count } => { - let total_size = ty_concrete - .sizeof(self) - .expect("create_const_alloc: Vectors must be sized"); - let final_offset = *offset + total_size; - let values = (0..count).map(|_| { - self.read_from_const_alloc(alloc, offset, element) - .def_cx(self) - }); - let result = self.constant_composite(ty, values); - assert!(*offset <= final_offset); - // Vectors sometimes have padding at the end (e.g. vec3), skip over it. - *offset = final_offset; - result - } - SpirvType::Matrix { element, count } => { - let total_size = ty_concrete - .sizeof(self) - .expect("create_const_alloc: Matrices must be sized"); - let final_offset = *offset + total_size; - let values = (0..count).map(|_| { - self.read_from_const_alloc(alloc, offset, element) - .def_cx(self) - }); - let result = self.constant_composite(ty, values); - assert!(*offset <= final_offset); - // Matrices sometimes have padding at the end (e.g. Mat4x3), skip over it. - *offset = final_offset; - result + + (result, ty_size.unwrap_or(tail_read_range.end)) } - SpirvType::RuntimeArray { element } => { - let mut values = Vec::new(); - while offset.bytes_usize() != alloc.inner().len() { - values.push( - self.read_from_const_alloc(alloc, offset, element) - .def_cx(self), - ); + SpirvType::Vector { element, .. 
} + | SpirvType::Matrix { element, .. } + | SpirvType::Array { element, .. } + | SpirvType::RuntimeArray { element } => { + let stride = self.lookup_type(element).sizeof(self).unwrap(); + + let count = match ty_def { + SpirvType::Vector { count, .. } | SpirvType::Matrix { count, .. } => { + u64::from(count) + } + SpirvType::Array { count, .. } => { + u64::try_from(self.builder.lookup_const_scalar(count).unwrap()).unwrap() + } + SpirvType::RuntimeArray { .. } => { + (alloc.inner().size() - offset).bytes() / stride.bytes() + } + _ => unreachable!(), + }; + + let result = self.constant_composite( + ty, + (0..count).map(|i| { + let (e, e_size) = + self.read_from_const_alloc_at(alloc, element, offset + i * stride); + assert_eq!(e_size, stride); + e.def_cx(self) + }), + ); + + // HACK(eddyb) `align_to` can only cause an increase for `Vector`, + // because its `size`/`align` are rounded up to a power of two + // (for now, at least, even if eventually that should go away). + let read_size = (count * stride).align_to(ty_def.alignof(self)); + + if let Some(ty_size) = ty_def.sizeof(self) { + assert_eq!(read_size, ty_size); } - let result = self.constant_composite(ty, values.into_iter()); - // TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv: - /* - __constant struct A { - float x; - int y[]; - } a = {1, {2, 3, 4}}; - - __kernel void foo(__global int* data, __constant int* c) { - __constant struct A* asdf = &a; - *data = *c + asdf->y[*c]; + + if let SpirvType::RuntimeArray { .. } = ty_def { + // FIXME(eddyb) values of this type should never be created, + // the only reasonable encoding of e.g. `&str` consts should + // be `&[u8; N]` consts, with the `static_addr_of` pointer + // (*not* the value it points to) cast to `&str`, afterwards. 
+ self.zombie_no_span( + result.def_cx(self), + &format!("unsupported unsized `{}` constant", self.debug_type(ty)), + ); } - */ - // NOTE(eddyb) the above description is a bit outdated, it's now - // clear `OpTypeRuntimeArray` does not belong in user code, and - // is only for dynamically-sized SSBOs and descriptor indexing, - // and a general solution looks similar to `union` handling, but - // for the length of a fixed-length array. - self.zombie_no_span(result.def_cx(self), "constant `OpTypeRuntimeArray` value"); - result + + (result, read_size) + } + + SpirvType::Void + | SpirvType::Function { .. } + | SpirvType::Image { .. } + | SpirvType::Sampler + | SpirvType::SampledImage { .. } + | SpirvType::InterfaceBlock { .. } + | SpirvType::AccelerationStructureKhr + | SpirvType::RayQueryKhr => { + let result = self.undef(ty); + self.zombie_no_span( + result.def_cx(self), + &format!( + "cannot reinterpret Rust constant data as a `{}` value", + self.debug_type(ty) + ), + ); + (result, ty_def.sizeof(self).unwrap_or(Size::ZERO)) } - SpirvType::Function { .. } => self - .tcx - .dcx() - .fatal("TODO: SpirvType::Function not supported yet in create_const_alloc"), - SpirvType::Image { .. } => self.tcx.dcx().fatal("cannot create a constant image value"), - SpirvType::Sampler => self - .tcx - .dcx() - .fatal("cannot create a constant sampler value"), - SpirvType::SampledImage { .. } => self - .tcx - .dcx() - .fatal("cannot create a constant sampled image value"), - SpirvType::InterfaceBlock { .. 
} => self - .tcx - .dcx() - .fatal("cannot create a constant interface block value"), - SpirvType::AccelerationStructureKhr => self - .tcx - .dcx() - .fatal("cannot create a constant acceleration structure"), - SpirvType::RayQueryKhr => self.tcx.dcx().fatal("cannot create a constant ray query"), } } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 5ca2edbd0c..947a8105cc 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -394,7 +394,7 @@ impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> { other.debug(g.ty, self) )), }; - let v = self.create_const_alloc(alloc, value_ty); + let v = self.try_read_from_const_alloc(alloc, value_ty).unwrap(); assert_ty_eq!(self, value_ty, v.ty); self.builder .set_global_initializer(g.def_cx(self), v.def_cx(self));