diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp index dcef01a5eb6d3..0cfeb9ec26a02 100644 --- a/clang/lib/CodeGen/CGPointerAuth.cpp +++ b/clang/lib/CodeGen/CGPointerAuth.cpp @@ -440,9 +440,9 @@ CodeGenModule::getConstantSignedPointer(llvm::Constant *Pointer, unsigned Key, IntegerDiscriminator = llvm::ConstantInt::get(Int64Ty, 0); } - return llvm::ConstantPtrAuth::get(Pointer, - llvm::ConstantInt::get(Int32Ty, Key), - IntegerDiscriminator, AddressDiscriminator); + return llvm::ConstantPtrAuth::get( + Pointer, llvm::ConstantInt::get(Int32Ty, Key), IntegerDiscriminator, + AddressDiscriminator, llvm::Constant::getNullValue(UnqualPtrTy)); } /// Does a given PointerAuthScheme require us to sign a value diff --git a/libcxx/include/__config b/libcxx/include/__config index 77a71b6cf1cae..7222cf2d71505 100644 --- a/libcxx/include/__config +++ b/libcxx/include/__config @@ -484,8 +484,21 @@ typedef __char32_t char32_t; # define _LIBCPP_EXCEPTIONS_SIG e # endif +# if !_LIBCPP_HAS_EXCEPTIONS +# define _LIBCPP_EXCEPTIONS_SIG n +# else +# define _LIBCPP_EXCEPTIONS_SIG e +# endif + +# if __has_extension(pointer_field_protection) +# define _LIBCPP_PFP_SIG p +# else +# define _LIBCPP_PFP_SIG +# endif + # define _LIBCPP_ODR_SIGNATURE \ - _LIBCPP_CONCAT(_LIBCPP_CONCAT(_LIBCPP_HARDENING_SIG, _LIBCPP_EXCEPTIONS_SIG), _LIBCPP_VERSION) + _LIBCPP_CONCAT(_LIBCPP_CONCAT(_LIBCPP_CONCAT(_LIBCPP_HARDENING_SIG, _LIBCPP_EXCEPTIONS_SIG), _LIBCPP_PFP_SIG), \ + _LIBCPP_VERSION) // This macro marks a symbol as being hidden from libc++'s ABI. This is achieved // on two levels: @@ -1262,6 +1275,12 @@ typedef __char32_t char32_t; # define _LIBCPP_HAS_EXPLICIT_THIS_PARAMETER 0 # endif +# if __has_extension(pointer_field_protection) +# define _LIBCPP_NO_PFP [[clang::no_field_protection]] +# else +# define _LIBCPP_NO_PFP +# endif + #endif // __cplusplus #endif // _LIBCPP___CONFIG diff --git a/libcxx/include/__type_traits/is_trivially_relocatable.h b/libcxx/include/__type_traits/is_trivially_relocatable.h index 9b0e240de55f4..dc8b0e23360fd 100644 --- a/libcxx/include/__type_traits/is_trivially_relocatable.h +++ b/libcxx/include/__type_traits/is_trivially_relocatable.h @@ -34,10 +34,13 @@ template struct __libcpp_is_trivially_relocatable : is_trivially_copyable<_Tp> {}; #endif +// __trivially_relocatable on libc++'s builtin types does not currently return the right answer with PFP. +#if !__has_feature(pointer_field_protection) template struct __libcpp_is_trivially_relocatable<_Tp, __enable_if_t::value> > : true_type {}; +#endif _LIBCPP_END_NAMESPACE_STD diff --git a/libcxx/include/typeinfo b/libcxx/include/typeinfo index 24aaabf0a87df..a1ee4155bcb3d 100644 --- a/libcxx/include/typeinfo +++ b/libcxx/include/typeinfo @@ -300,7 +300,7 @@ class _LIBCPP_EXPORTED_FROM_ABI _LIBCPP_TYPE_INFO_VTABLE_POINTER_AUTH type_info protected: typedef __type_info_implementations::__impl __impl; - __impl::__type_name_t __type_name; + _LIBCPP_NO_PFP __impl::__type_name_t __type_name; _LIBCPP_HIDE_FROM_ABI explicit type_info(const char* __n) : __type_name(__impl::__string_to_type_name(__n)) {} diff --git a/libcxx/test/libcxx/gdb/gdb_pretty_printer_test.sh.cpp b/libcxx/test/libcxx/gdb/gdb_pretty_printer_test.sh.cpp index f125cc9adc491..402c5ad72bd09 100644 --- a/libcxx/test/libcxx/gdb/gdb_pretty_printer_test.sh.cpp +++ b/libcxx/test/libcxx/gdb/gdb_pretty_printer_test.sh.cpp @@ -256,9 +256,12 @@ void unique_ptr_test() { ComparePrettyPrintToRegex(std::move(forty_two), R"(std::unique_ptr containing = {__ptr_ = 0x[a-f0-9]+})"); +#if !__has_extension(pointer_field_protection) + // GDB doesn't know how to read PFP fields correctly yet. std::unique_ptr this_is_null; ComparePrettyPrintToChars(std::move(this_is_null), R"(std::unique_ptr is nullptr)"); +#endif } void bitset_test() { @@ -476,10 +479,13 @@ void vector_test() { "std::vector of length " "3, capacity 3 = {5, 6, 7}"); +#if !__has_extension(pointer_field_protection) + // GDB doesn't know how to read PFP fields correctly yet. std::vector> test3({7, 8}); ComparePrettyPrintToChars(std::move(test3), "std::vector of length " "2, capacity 2 = {7, 8}"); +#endif } void set_iterator_test() { @@ -650,8 +656,11 @@ void shared_ptr_test() { test0, R"(std::shared_ptr count [3\?], weak [0\?]( \(libc\+\+ missing debug info\))? containing = {__ptr_ = 0x[a-f0-9]+})"); +#if !__has_extension(pointer_field_protection) + // GDB doesn't know how to read PFP fields correctly yet. std::shared_ptr test3; ComparePrettyPrintToChars(test3, "std::shared_ptr is nullptr"); +#endif } void streampos_test() { diff --git a/libcxx/test/libcxx/type_traits/is_trivially_relocatable.compile.pass.cpp b/libcxx/test/libcxx/type_traits/is_trivially_relocatable.compile.pass.cpp index 8066925f2900a..4149a7a12270e 100644 --- a/libcxx/test/libcxx/type_traits/is_trivially_relocatable.compile.pass.cpp +++ b/libcxx/test/libcxx/type_traits/is_trivially_relocatable.compile.pass.cpp @@ -28,6 +28,12 @@ # include #endif +#if __has_extension(pointer_field_protection) +constexpr bool pfp_disabled = false; +#else +constexpr bool pfp_disabled = true; +#endif + static_assert(std::__libcpp_is_trivially_relocatable::value, ""); static_assert(std::__libcpp_is_trivially_relocatable::value, ""); static_assert(std::__libcpp_is_trivially_relocatable::value, ""); @@ -70,8 +76,8 @@ static_assert(!std::__libcpp_is_trivially_relocatable::val // ---------------------- // __split_buffer -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable > >::value, ""); // standard library types @@ -84,7 +90,7 @@ static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable, 1> >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable, 1> >::value == pfp_disabled, ""); // basic_string #if !__has_feature(address_sanitizer) || !_LIBCPP_INSTRUMENTED_WITH_ASAN @@ -99,17 +105,17 @@ struct NotTriviallyRelocatableCharTraits : constexpr_char_traits { }; static_assert(std::__libcpp_is_trivially_relocatable< - std::basic_string, std::allocator > >::value, + std::basic_string, std::allocator > >::value == pfp_disabled, ""); static_assert(std::__libcpp_is_trivially_relocatable< - std::basic_string, std::allocator > >::value, + std::basic_string, std::allocator > >::value == pfp_disabled, ""); static_assert(std::__libcpp_is_trivially_relocatable< - std::basic_string, std::allocator > >::value, + std::basic_string, std::allocator > >::value == pfp_disabled, ""); static_assert( std::__libcpp_is_trivially_relocatable< - std::basic_string, std::allocator > >::value, + std::basic_string, std::allocator > >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable< std::basic_string, test_allocator > >::value, @@ -121,21 +127,21 @@ static_assert( #endif // deque -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable > >::value, ""); // exception_ptr #ifndef _LIBCPP_ABI_MICROSOFT // FIXME: Is this also the case on windows? -static_assert(std::__libcpp_is_trivially_relocatable::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable::value == pfp_disabled, ""); #endif // expected #if TEST_STD_VER >= 23 -static_assert(std::__libcpp_is_trivially_relocatable >::value); -static_assert(std::__libcpp_is_trivially_relocatable, int>>::value); -static_assert(std::__libcpp_is_trivially_relocatable>>::value); -static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr>>::value); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled); +static_assert(std::__libcpp_is_trivially_relocatable, int>>::value == pfp_disabled); +static_assert(std::__libcpp_is_trivially_relocatable>>::value == pfp_disabled); +static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr>>::value == pfp_disabled); static_assert(!std::__libcpp_is_trivially_relocatable>::value); static_assert(!std::__libcpp_is_trivially_relocatable>::value); @@ -145,42 +151,42 @@ static_assert( // locale #ifndef TEST_HAS_NO_LOCALIZATION -static_assert(std::__libcpp_is_trivially_relocatable::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable::value == pfp_disabled, ""); #endif // optional #if TEST_STD_VER >= 17 static_assert(std::__libcpp_is_trivially_relocatable>::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable>::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable>>::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable>>::value == pfp_disabled, ""); #endif // TEST_STD_VER >= 17 // pair -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value, +static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value == pfp_disabled, ""); // shared_ptr -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); // tuple #if TEST_STD_VER >= 11 static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable > >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable > >::value == pfp_disabled, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value, +static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value == pfp_disabled, ""); #endif // TEST_STD_VER >= 11 @@ -205,9 +211,9 @@ struct NotTriviallyRelocatablePointer { void operator()(T*); }; -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, @@ -221,23 +227,23 @@ static_assert(!std::__libcpp_is_trivially_relocatable= 17 static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable > >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable > >::value == pfp_disabled, ""); static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); static_assert(!std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value, +static_assert(std::__libcpp_is_trivially_relocatable, std::unique_ptr > >::value == pfp_disabled, ""); #endif // TEST_STD_VER >= 17 // vector -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); static_assert(!std::__libcpp_is_trivially_relocatable > >::value, ""); // weak_ptr -static_assert(std::__libcpp_is_trivially_relocatable >::value, ""); +static_assert(std::__libcpp_is_trivially_relocatable >::value == pfp_disabled, ""); // TODO: Mark all the trivially relocatable STL types as such diff --git a/libcxxabi/include/__cxxabi_config.h b/libcxxabi/include/__cxxabi_config.h index 759445dac91f9..2e8ab664e0615 100644 --- a/libcxxabi/include/__cxxabi_config.h +++ b/libcxxabi/include/__cxxabi_config.h @@ -109,4 +109,14 @@ # define _LIBCXXABI_NOEXCEPT noexcept #endif +#if defined(_LIBCXXABI_COMPILER_CLANG) +# if __has_extension(pointer_field_protection) +# define _LIBCXXABI_NO_PFP [[clang::no_field_protection]] +# else +# define _LIBCXXABI_NO_PFP +# endif +#else +# define _LIBCXXABI_NO_PFP +#endif + #endif // ____CXXABI_CONFIG_H diff --git a/libcxxabi/src/private_typeinfo.h b/libcxxabi/src/private_typeinfo.h index 328a02edef5c1..a3bc0bffd41bc 100644 --- a/libcxxabi/src/private_typeinfo.h +++ b/libcxxabi/src/private_typeinfo.h @@ -145,7 +145,7 @@ class _LIBCXXABI_TYPE_VIS __class_type_info : public __shim_type_info { // Has one non-virtual public base class at offset zero class _LIBCXXABI_TYPE_VIS __si_class_type_info : public __class_type_info { public: - const __class_type_info *__base_type; + _LIBCXXABI_NO_PFP const __class_type_info *__base_type; _LIBCXXABI_HIDDEN virtual ~__si_class_type_info(); @@ -204,7 +204,7 @@ class _LIBCXXABI_TYPE_VIS __vmi_class_type_info : public __class_type_info { class _LIBCXXABI_TYPE_VIS __pbase_type_info : public __shim_type_info { public: unsigned int __flags; - const __shim_type_info *__pointee; + _LIBCXXABI_NO_PFP const __shim_type_info *__pointee; enum __masks { __const_mask = 0x1, @@ -245,7 +245,7 @@ class _LIBCXXABI_TYPE_VIS __pointer_type_info : public __pbase_type_info { class _LIBCXXABI_TYPE_VIS __pointer_to_member_type_info : public __pbase_type_info { public: - const __class_type_info *__context; + _LIBCXXABI_NO_PFP const __class_type_info *__context; _LIBCXXABI_HIDDEN virtual ~__pointer_to_member_type_info(); _LIBCXXABI_HIDDEN virtual bool can_catch(const __shim_type_info *, diff --git a/lld/ELF/Arch/AArch64.cpp b/lld/ELF/Arch/AArch64.cpp index 1812f2af419d2..ff13773aee5d4 100644 --- a/lld/ELF/Arch/AArch64.cpp +++ b/lld/ELF/Arch/AArch64.cpp @@ -114,6 +114,7 @@ AArch64::AArch64(Ctx &ctx) : TargetInfo(ctx) { copyRel = R_AARCH64_COPY; relativeRel = R_AARCH64_RELATIVE; iRelativeRel = R_AARCH64_IRELATIVE; + iRelSymbolicRel = R_AARCH64_FUNCINIT64; gotRel = R_AARCH64_GLOB_DAT; pltRel = R_AARCH64_JUMP_SLOT; symbolicRel = R_AARCH64_ABS64; @@ -137,6 +138,7 @@ RelExpr AArch64::getRelExpr(RelType type, const Symbol &s, case R_AARCH64_ABS16: case R_AARCH64_ABS32: case R_AARCH64_ABS64: + case R_AARCH64_FUNCINIT64: case R_AARCH64_ADD_ABS_LO12_NC: case R_AARCH64_LDST128_ABS_LO12_NC: case R_AARCH64_LDST16_ABS_LO12_NC: @@ -154,6 +156,12 @@ RelExpr AArch64::getRelExpr(RelType type, const Symbol &s, case R_AARCH64_MOVW_UABS_G2_NC: case R_AARCH64_MOVW_UABS_G3: return R_ABS; + case R_AARCH64_PATCHINST: + if (!isAbsolute(s)) + Err(ctx) << getErrorLoc(ctx, loc) + << "R_AARCH64_PATCHINST relocation against non-absolute symbol " + << &s; + return R_ABS; case R_AARCH64_AUTH_ABS64: return RE_AARCH64_AUTH; case R_AARCH64_TLSDESC_ADR_PAGE21: @@ -261,7 +269,8 @@ bool AArch64::usesOnlyLowPageBits(RelType type) const { } RelType AArch64::getDynRel(RelType type) const { - if (type == R_AARCH64_ABS64 || type == R_AARCH64_AUTH_ABS64) + if (type == R_AARCH64_ABS64 || type == R_AARCH64_AUTH_ABS64 || + type == R_AARCH64_FUNCINIT64) return type; return R_AARCH64_NONE; } @@ -506,6 +515,12 @@ void AArch64::relocate(uint8_t *loc, const Relocation &rel, checkIntUInt(ctx, loc, val, 32, rel); write32(ctx, loc, val); break; + case R_AARCH64_PATCHINST: + if (!rel.sym->isUndefined()) { + checkUInt(ctx, loc, val, 32, rel); + write32le(loc, val); + } + break; case R_AARCH64_PLT32: case R_AARCH64_GOTPCREL32: checkInt(ctx, loc, val, 32, rel); diff --git a/lld/ELF/Relocations.cpp b/lld/ELF/Relocations.cpp index 32ac28d61f445..e2f594f3d2bba 100644 --- a/lld/ELF/Relocations.cpp +++ b/lld/ELF/Relocations.cpp @@ -176,7 +176,7 @@ static RelType getMipsPairType(RelType type, bool isLocal) { // True if non-preemptable symbol always has the same value regardless of where // the DSO is loaded. -static bool isAbsolute(const Symbol &sym) { +bool elf::isAbsolute(const Symbol &sym) { if (sym.isUndefined()) return true; if (const auto *dr = dyn_cast(&sym)) @@ -989,8 +989,8 @@ bool RelocationScanner::isStaticLinkTimeConstant(RelExpr e, RelType type, // only the low bits are used. if (e == R_GOT || e == R_PLT) return ctx.target->usesOnlyLowPageBits(type) || !ctx.arg.isPic; - // R_AARCH64_AUTH_ABS64 requires a dynamic relocation. - if (e == RE_AARCH64_AUTH) + // R_AARCH64_AUTH_ABS64 and iRelSymbolicRel require a dynamic relocation. + if (e == RE_AARCH64_AUTH || type == ctx.target->iRelSymbolicRel) return false; // The behavior of an undefined weak reference is implementation defined. @@ -1163,6 +1163,23 @@ void RelocationScanner::processAux(RelExpr expr, RelType type, uint64_t offset, } return; } + if (LLVM_UNLIKELY(type == ctx.target->iRelSymbolicRel)) { + if (sym.isPreemptible) { + auto diag = Err(ctx); + diag << "relocation " << type + << " cannot be used against preemptible symbol '" << &sym << "'"; + printLocation(diag, *sec, sym, offset); + } else if (isIfunc) { + auto diag = Err(ctx); + diag << "relocation " << type + << " cannot be used against ifunc symbol '" << &sym << "'"; + printLocation(diag, *sec, sym, offset); + } else { + part.relaDyn->addReloc({ctx.target->iRelativeRel, sec, offset, false, + sym, addend, R_ABS}); + return; + } + } part.relaDyn->addSymbolReloc(rel, *sec, offset, sym, addend, type); // MIPS ABI turns using of GOT and dynamic relocations inside out. diff --git a/lld/ELF/Relocations.h b/lld/ELF/Relocations.h index c1c4860ceaa92..7ea03c359bbf4 100644 --- a/lld/ELF/Relocations.h +++ b/lld/ELF/Relocations.h @@ -165,6 +165,8 @@ void addGotEntry(Ctx &ctx, Symbol &sym); void hexagonTLSSymbolUpdate(Ctx &ctx); bool hexagonNeedsTLSSymbol(ArrayRef outputSections); +bool isAbsolute(const Symbol &sym); + class ThunkSection; class Thunk; class InputSectionDescription; diff --git a/lld/ELF/Target.h b/lld/ELF/Target.h index fdc0c20f9cd02..de6040e79fa81 100644 --- a/lld/ELF/Target.h +++ b/lld/ELF/Target.h @@ -131,6 +131,7 @@ class TargetInfo { RelType relativeRel = 0; RelType iRelativeRel = 0; RelType symbolicRel = 0; + RelType iRelSymbolicRel = 0; RelType tlsDescRel = 0; RelType tlsGotRel = 0; RelType tlsModuleIndexRel = 0; diff --git a/lld/test/ELF/aarch64-funcinit64-invalid.s b/lld/test/ELF/aarch64-funcinit64-invalid.s new file mode 100644 index 0000000000000..2507c07056783 --- /dev/null +++ b/lld/test/ELF/aarch64-funcinit64-invalid.s @@ -0,0 +1,18 @@ +# REQUIRES: aarch64 + +# RUN: llvm-mc -filetype=obj -triple=aarch64 %s -o %t.o +# RUN: not ld.lld %t.o -o %t 2>&1 | FileCheck --check-prefix=ERR %s + +.rodata +# ERR: relocation R_AARCH64_FUNCINIT64 cannot be used against local symbol +.8byte func@FUNCINIT + +.data +# ERR: relocation R_AARCH64_FUNCINIT64 cannot be used against ifunc symbol 'ifunc' +.8byte ifunc@FUNCINIT + +.text +func: +.type ifunc, @gnu_indirect_function +ifunc: +ret diff --git a/lld/test/ELF/aarch64-funcinit64.s b/lld/test/ELF/aarch64-funcinit64.s new file mode 100644 index 0000000000000..5f2b863ee884b --- /dev/null +++ b/lld/test/ELF/aarch64-funcinit64.s @@ -0,0 +1,19 @@ +# REQUIRES: aarch64 + +# RUN: llvm-mc -filetype=obj -triple=aarch64 %s -o %t.o +# RUN: ld.lld %t.o -o %t +# RUN: llvm-readelf -s -r %t | FileCheck %s +# RUN: ld.lld %t.o -o %t -pie +# RUN: llvm-readelf -s -r %t | FileCheck %s +# RUN: not ld.lld %t.o -o %t -shared 2>&1 | FileCheck --check-prefix=ERR %s + +.data +# CHECK: R_AARCH64_IRELATIVE [[FOO:[0-9a-f]*]] +# ERR: relocation R_AARCH64_FUNCINIT64 cannot be used against preemptible symbol 'foo' +.8byte foo@FUNCINIT + +.text +# CHECK: {{0*}}[[FOO]] {{.*}} foo +.globl foo +foo: +ret diff --git a/lld/test/ELF/aarch64-patchinst.s b/lld/test/ELF/aarch64-patchinst.s new file mode 100644 index 0000000000000..244c8ccca8e31 --- /dev/null +++ b/lld/test/ELF/aarch64-patchinst.s @@ -0,0 +1,87 @@ +# RUN: rm -rf %t && split-file %s %t +# RUN: llvm-mc -filetype=obj -triple=aarch64 %t/use.s -o %t/use-le.o +# RUN: llvm-mc -filetype=obj -triple=aarch64 %t/def.s -o %t/def-le.o +# RUN: llvm-mc -filetype=obj -triple=aarch64 %t/rel.s -o %t/rel-le.o + +## Deactivation symbol used without being defined: instruction emitted as usual. +# RUN: ld.lld -o %t/undef-le %t/use-le.o --emit-relocs +# RUN: llvm-objdump -r %t/undef-le | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/undef-le | FileCheck --check-prefix=UNDEF %s +# RUN: ld.lld -pie -o %t/undef-le %t/use-le.o --emit-relocs +# RUN: llvm-objdump -r %t/undef-le | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/undef-le | FileCheck --check-prefix=UNDEF %s + +## Deactivation symbol defined: instructions overwritten with NOPs. +# RUN: ld.lld -o %t/def-le %t/use-le.o %t/def-le.o --emit-relocs +# RUN: llvm-objdump -r %t/def-le | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/def-le | FileCheck --check-prefix=DEF %s +# RUN: ld.lld -pie -o %t/def-le %t/use-le.o %t/def-le.o --emit-relocs +# RUN: llvm-objdump -r %t/def-le | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/def-le | FileCheck --check-prefix=DEF %s + +## Relocation pointing to a non-SHN_UNDEF non-SHN_ABS symbol is an error. +# RUN: not ld.lld -o %t/rel-le %t/use-le.o %t/rel-le.o 2>&1 | FileCheck --check-prefix=ERROR %s +# RUN: not ld.lld -pie -o %t/rel-le %t/use-le.o %t/rel-le.o 2>&1 | FileCheck --check-prefix=ERROR %s + +## Behavior unchanged by endianness: relocation always written as little endian. +# RUN: llvm-mc -filetype=obj -triple=aarch64_be %t/use.s -o %t/use-be.o +# RUN: llvm-mc -filetype=obj -triple=aarch64_be %t/def.s -o %t/def-be.o +# RUN: llvm-mc -filetype=obj -triple=aarch64_be %t/rel.s -o %t/rel-be.o +# RUN: ld.lld -o %t/undef-be %t/use-be.o --emit-relocs +# RUN: llvm-objdump -r %t/undef-be | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/undef-be | FileCheck --check-prefix=UNDEF %s +# RUN: ld.lld -pie -o %t/undef-be %t/use-be.o --emit-relocs +# RUN: llvm-objdump -r %t/undef-be | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/undef-be | FileCheck --check-prefix=UNDEF %s +# RUN: ld.lld -o %t/def-be %t/use-be.o %t/def-be.o --emit-relocs +# RUN: llvm-objdump -r %t/def-be | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/def-be | FileCheck --check-prefix=DEF %s +# RUN: ld.lld -pie -o %t/def-be %t/use-be.o %t/def-be.o --emit-relocs +# RUN: llvm-objdump -r %t/def-be | FileCheck --check-prefix=RELOC %s +# RUN: llvm-objdump -d %t/def-be | FileCheck --check-prefix=DEF %s +# RUN: not ld.lld -o %t/rel-be %t/use-be.o %t/rel-be.o 2>&1 | FileCheck --check-prefix=ERROR %s +# RUN: not ld.lld -pie -o %t/rel-be %t/use-be.o %t/rel-be.o 2>&1 | FileCheck --check-prefix=ERROR %s + +# RELOC: R_AARCH64_JUMP26 +# RELOC-NEXT: R_AARCH64_PATCHINST ds +# RELOC-NEXT: R_AARCH64_PATCHINST ds +# RELOC-NEXT: R_AARCH64_PATCHINST ds0+0xd503201f + +#--- use.s +.weak ds +.weak ds0 +# This instruction has a single relocation: the DS relocation. +# UNDEF: add x0, x1, x2 +# DEF: nop +# ERROR: R_AARCH64_PATCHINST relocation against non-absolute symbol ds +.reloc ., R_AARCH64_PATCHINST, ds +add x0, x1, x2 +# This instruction has two relocations: the DS relocation and the JUMP26 to f1. +# Make sure that the DS relocation takes precedence. +.reloc ., R_AARCH64_PATCHINST, ds +# UNDEF: b {{.*}} +# DEF: nop +# ERROR: R_AARCH64_PATCHINST relocation against non-absolute symbol ds +b f1 +# Alternative representation: instruction opcode stored in addend. +# UNDEF: add x3, x4, x5 +# DEF: nop +# ERROR: R_AARCH64_PATCHINST relocation against non-absolute symbol ds0 +.reloc ., R_AARCH64_PATCHINST, ds0 + 0xd503201f +add x3, x4, x5 + +.section .text.f1,"ax",@progbits +f1: +ret + +#--- def.s +.globl ds +ds = 0xd503201f +.globl ds0 +ds0 = 0 + +#--- rel.s +.globl ds +ds: +.globl ds0 +ds0: diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst index 28746bf9d05aa..96dcb21304787 100644 --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -3090,6 +3090,21 @@ A "convergencectrl" operand bundle is only valid on a ``convergent`` operation. When present, the operand bundle must contain exactly one value of token type. See the :doc:`ConvergentOperations` document for details. +Deactivation Symbol Operand Bundles +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A ``"deactivation-symbol"`` operand bundle is valid on the following +instructions (AArch64 only): + +- Call to a normal function with ``notail`` attribute. +- Call to ``llvm.ptrauth.sign`` or ``llvm.ptrauth.auth`` intrinsics. + +This operand bundle specifies that if the deactivation symbol is defined +to a valid value for the target, the marked instruction will return the +value of its first argument instead of calling the specified function +or intrinsic. This is achieved with ``PATCHINST`` relocations on the +target instructions (see the AArch64 psABI for details). + .. _moduleasm: Module-Level Inline Assembly @@ -31146,3 +31161,57 @@ This intrinsic is assumed to execute in the default :ref:`floating-point environment ` *except* for the rounding mode. This intrinsic is not supported on all targets. Some targets may not support all rounding modes. + +'``llvm.protected.field.ptr``' Intrinsic +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" + +:: + + declare ptr @llvm.protected.field.ptr(ptr ptr, i64 disc, i1 use_hw_encoding) + +Overview: +""""""""" + +The '``llvm.protected.field.ptr``' intrinsic returns a pointer to the +storage location of a pointer that has special properties as described +below. + +Arguments: +"""""""""" + +The first argument is the pointer specifying the location to store the +pointer. The second argument is the discriminator, which is used as an +input for the pointer encoding. The third argument specifies whether to +use a target-specific mechanism to encode the pointer. + +Semantics: +"""""""""" + +This intrinsic returns a pointer which may be used to store a +pointer at the specified address that is encoded using the specified +discriminator. Stores via the pointer will cause the stored pointer to be +blended with the second argument before being stored. The blend operation +shall be either a weak but cheap and target-independent operation (if +the third argument is 0) or a stronger target-specific operation (if the +third argument is 1). When loading from the pointer, the inverse operation +is done on the loaded pointer after it is loaded. Specifically, when the +third argument is 1, the pointer is signed (using pointer authentication +instructions or emulated PAC if not supported by the hardware) using +the discriminator before being stored, and authenticated after being +loaded. Note that it is currently unsupported to have the third argument +be 1 on targets other than AArch64. When the third argument is 0, it is +rotated left by 16 bits and the discriminator is subtracted before being +stored, and the discriminator is added and the pointer is rotated right +by 16 bits after being loaded. + +If the pointer is used other than for loading or storing (e.g. its +address escapes), that will disable all blending operations using +the deactivation symbol specified in the intrinsic's operand bundle. +The deactivation symbol operand bundle is copied onto any sign and auth +intrinsics that this intrinsic is lowered into. The intent is that the +deactivation symbol represents a field identifier. + +This intrinsic is used to implement structure protection. diff --git a/llvm/include/llvm/Analysis/PtrUseVisitor.h b/llvm/include/llvm/Analysis/PtrUseVisitor.h index 0858d8aee2186..a39f6881f24f3 100644 --- a/llvm/include/llvm/Analysis/PtrUseVisitor.h +++ b/llvm/include/llvm/Analysis/PtrUseVisitor.h @@ -134,6 +134,7 @@ class PtrUseVisitorBase { UseAndIsOffsetKnownPair UseAndIsOffsetKnown; APInt Offset; + Value *ProtectedFieldDisc; }; /// The worklist of to-visit uses. @@ -158,6 +159,10 @@ class PtrUseVisitorBase { /// The constant offset of the use if that is known. APInt Offset; + // When this access is via an llvm.protected.field.ptr intrinsic, contains + // the second argument to the intrinsic, the discriminator. + Value *ProtectedFieldDisc; + /// @} /// Note that the constructor is protected because this class must be a base @@ -230,6 +235,7 @@ class PtrUseVisitor : protected InstVisitor, IntegerType *IntIdxTy = cast(DL.getIndexType(I.getType())); IsOffsetKnown = true; Offset = APInt(IntIdxTy->getBitWidth(), 0); + ProtectedFieldDisc = nullptr; PI.reset(); // Enqueue the uses of this pointer. @@ -242,6 +248,7 @@ class PtrUseVisitor : protected InstVisitor, IsOffsetKnown = ToVisit.UseAndIsOffsetKnown.getInt(); if (IsOffsetKnown) Offset = std::move(ToVisit.Offset); + ProtectedFieldDisc = ToVisit.ProtectedFieldDisc; Instruction *I = cast(U->getUser()); static_cast(this)->visit(I); @@ -300,6 +307,14 @@ class PtrUseVisitor : protected InstVisitor, case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: return; // No-op intrinsics. + + case Intrinsic::protected_field_ptr: { + if (!IsOffsetKnown) + return Base::visitIntrinsicInst(II); + ProtectedFieldDisc = II.getArgOperand(1); + enqueueUsers(II); + break; + } } } diff --git a/llvm/include/llvm/BinaryFormat/ELFRelocs/AArch64.def b/llvm/include/llvm/BinaryFormat/ELFRelocs/AArch64.def index 05b79eae573f7..1cfcdbf67dac5 100644 --- a/llvm/include/llvm/BinaryFormat/ELFRelocs/AArch64.def +++ b/llvm/include/llvm/BinaryFormat/ELFRelocs/AArch64.def @@ -61,6 +61,8 @@ ELF_RELOC(R_AARCH64_LD64_GOT_LO12_NC, 0x138) ELF_RELOC(R_AARCH64_LD64_GOTPAGE_LO15, 0x139) ELF_RELOC(R_AARCH64_PLT32, 0x13a) ELF_RELOC(R_AARCH64_GOTPCREL32, 0x13b) +ELF_RELOC(R_AARCH64_PATCHINST, 0x13c) +ELF_RELOC(R_AARCH64_FUNCINIT64, 0x13d) // General dynamic TLS relocations ELF_RELOC(R_AARCH64_TLSGD_ADR_PREL21, 0x200) ELF_RELOC(R_AARCH64_TLSGD_ADR_PAGE21, 0x201) diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h index dc78eb4164acf..b02891a7b8aa4 100644 --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -437,6 +437,8 @@ enum ConstantsCodes { CST_CODE_CE_GEP_WITH_INRANGE = 31, // [opty, flags, range, n x operands] CST_CODE_CE_GEP = 32, // [opty, flags, n x operands] CST_CODE_PTRAUTH = 33, // [ptr, key, disc, addrdisc] + CST_CODE_PTRAUTH2 = 34, // [ptr, key, disc, addrdisc, + // deactivation_symbol] }; /// CastOpcodes - These are values used in the bitcode files to encode which diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h index 75c051712ae43..b790786760742 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h @@ -162,6 +162,8 @@ class LLVM_ABI CallLowering { /// True if this call results in convergent operations. bool IsConvergent = true; + + GlobalValue *DeactivationSymbol = nullptr; }; /// Argument handling is mostly uniform between the four places that diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h index 99d3cd0aac85c..515196a2b0426 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h @@ -56,6 +56,7 @@ struct MachineIRBuilderState { MDNode *PCSections = nullptr; /// MMRA Metadata to be set on any instruction we create. MDNode *MMRA = nullptr; + Value *DS = nullptr; /// \name Fields describing the insertion point. /// @{ @@ -369,6 +370,7 @@ class LLVM_ABI MachineIRBuilder { State.II = MI.getIterator(); setPCSections(MI.getPCSections()); setMMRAMetadata(MI.getMMRAMetadata()); + setDeactivationSymbol(MI.getDeactivationSymbol()); } /// @} @@ -405,6 +407,9 @@ class LLVM_ABI MachineIRBuilder { /// Set the PC sections metadata to \p MD for all the next build instructions. void setMMRAMetadata(MDNode *MMRA) { State.MMRA = MMRA; } + Value *getDeactivationSymbol() { return State.DS; } + void setDeactivationSymbol(Value *DS) { State.DS = DS; } + /// Get the current instruction's MMRA metadata. MDNode *getMMRAMetadata() { return State.MMRA; } diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h index 465e4a0a9d0d8..0f432fd7fbf87 100644 --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -1563,6 +1563,10 @@ enum NodeType { // Outputs: Output Chain CLEAR_CACHE, + // Untyped node storing deactivation symbol reference + // (DeactivationSymbolSDNode). + DEACTIVATION_SYMBOL, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h index 06c4daf245fa0..16227b6b1a32e 100644 --- a/llvm/include/llvm/CodeGen/MachineFunction.h +++ b/llvm/include/llvm/CodeGen/MachineFunction.h @@ -1207,7 +1207,7 @@ class LLVM_ABI MachineFunction { ArrayRef MMOs, MCSymbol *PreInstrSymbol = nullptr, MCSymbol *PostInstrSymbol = nullptr, MDNode *HeapAllocMarker = nullptr, MDNode *PCSections = nullptr, uint32_t CFIType = 0, - MDNode *MMRAs = nullptr); + MDNode *MMRAs = nullptr, Value *DS = nullptr); /// Allocate a string and populate it with the given external symbol name. const char *createExternalSymbolName(StringRef Name); diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h index 10a9b1ff1411d..2d462051e8d4c 100644 --- a/llvm/include/llvm/CodeGen/MachineInstr.h +++ b/llvm/include/llvm/CodeGen/MachineInstr.h @@ -160,8 +160,9 @@ class MachineInstr /// /// This has to be defined eagerly due to the implementation constraints of /// `PointerSumType` where it is used. - class ExtraInfo final : TrailingObjects { + class ExtraInfo final + : TrailingObjects { public: static ExtraInfo *create(BumpPtrAllocator &Allocator, ArrayRef MMOs, @@ -169,20 +170,23 @@ class MachineInstr MCSymbol *PostInstrSymbol = nullptr, MDNode *HeapAllocMarker = nullptr, MDNode *PCSections = nullptr, uint32_t CFIType = 0, - MDNode *MMRAs = nullptr) { + MDNode *MMRAs = nullptr, Value *DS = nullptr) { bool HasPreInstrSymbol = PreInstrSymbol != nullptr; bool HasPostInstrSymbol = PostInstrSymbol != nullptr; bool HasHeapAllocMarker = HeapAllocMarker != nullptr; bool HasMMRAs = MMRAs != nullptr; bool HasCFIType = CFIType != 0; bool HasPCSections = PCSections != nullptr; + bool HasDS = DS != nullptr; auto *Result = new (Allocator.Allocate( - totalSizeToAlloc( + totalSizeToAlloc( MMOs.size(), HasPreInstrSymbol + HasPostInstrSymbol, - HasHeapAllocMarker + HasPCSections + HasMMRAs, HasCFIType), + HasHeapAllocMarker + HasPCSections + HasMMRAs, HasCFIType, HasDS), alignof(ExtraInfo))) ExtraInfo(MMOs.size(), HasPreInstrSymbol, HasPostInstrSymbol, - HasHeapAllocMarker, HasPCSections, HasCFIType, HasMMRAs); + HasHeapAllocMarker, HasPCSections, HasCFIType, HasMMRAs, + HasDS); // Copy the actual data into the trailing objects. std::copy(MMOs.begin(), MMOs.end(), @@ -203,6 +207,8 @@ class MachineInstr Result->getTrailingObjects()[0] = CFIType; if (HasMMRAs) Result->getTrailingObjects()[MDNodeIdx++] = MMRAs; + if (HasDS) + Result->getTrailingObjects()[0] = DS; return Result; } @@ -241,6 +247,10 @@ class MachineInstr : nullptr; } + Value *getDeactivationSymbol() const { + return HasDS ? getTrailingObjects()[0] : 0; + } + private: friend TrailingObjects; @@ -256,6 +266,7 @@ class MachineInstr const bool HasPCSections; const bool HasCFIType; const bool HasMMRAs; + const bool HasDS; // Implement the `TrailingObjects` internal API. size_t numTrailingObjects(OverloadToken) const { @@ -270,16 +281,19 @@ class MachineInstr size_t numTrailingObjects(OverloadToken) const { return HasCFIType; } + size_t numTrailingObjects(OverloadToken) const { + return HasDS; + } // Just a boring constructor to allow us to initialize the sizes. Always use // the `create` routine above. ExtraInfo(int NumMMOs, bool HasPreInstrSymbol, bool HasPostInstrSymbol, bool HasHeapAllocMarker, bool HasPCSections, bool HasCFIType, - bool HasMMRAs) + bool HasMMRAs, bool HasDS) : NumMMOs(NumMMOs), HasPreInstrSymbol(HasPreInstrSymbol), HasPostInstrSymbol(HasPostInstrSymbol), HasHeapAllocMarker(HasHeapAllocMarker), HasPCSections(HasPCSections), - HasCFIType(HasCFIType), HasMMRAs(HasMMRAs) {} + HasCFIType(HasCFIType), HasMMRAs(HasMMRAs), HasDS(HasDS) {} }; /// Enumeration of the kinds of inline extra info available. It is important @@ -868,6 +882,14 @@ class MachineInstr return nullptr; } + Value *getDeactivationSymbol() const { + if (!Info) + return nullptr; + if (ExtraInfo *EI = Info.get()) + return EI->getDeactivationSymbol(); + return nullptr; + } + /// Helper to extract a CFI type hash if one has been added. uint32_t getCFIType() const { if (!Info) @@ -1970,6 +1992,8 @@ class MachineInstr /// Set the CFI type for the instruction. LLVM_ABI void setCFIType(MachineFunction &MF, uint32_t Type); + LLVM_ABI void setDeactivationSymbol(MachineFunction &MF, Value *DS); + /// Return the MIFlags which represent both MachineInstrs. This /// should be used when merging two MachineInstrs into one. This routine does /// not modify the MIFlags of this MachineInstr. @@ -2080,7 +2104,7 @@ class MachineInstr void setExtraInfo(MachineFunction &MF, ArrayRef MMOs, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs); + uint32_t CFIType, MDNode *MMRAs, Value *DS); }; /// Special DenseMapInfo traits to compare MachineInstr* by *value* of the diff --git a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h index e705d7d99544c..caeb430d6fd1c 100644 --- a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h +++ b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h @@ -70,29 +70,44 @@ enum { } // end namespace RegState /// Set of metadata that should be preserved when using BuildMI(). This provides -/// a more convenient way of preserving DebugLoc, PCSections and MMRA. +/// a more convenient way of preserving certain data from the original +/// instruction. class MIMetadata { public: MIMetadata() = default; - MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) - : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA) {} + MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr, + Value *DeactivationSymbol = nullptr) + : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA), + DeactivationSymbol(DeactivationSymbol) {} MIMetadata(const DILocation *DI, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) : DL(DI), PCSections(PCSections), MMRA(MMRA) {} explicit MIMetadata(const Instruction &From) : DL(From.getDebugLoc()), - PCSections(From.getMetadata(LLVMContext::MD_pcsections)) {} + PCSections(From.getMetadata(LLVMContext::MD_pcsections)), + DeactivationSymbol(getDeactivationSymbol(&From)) {} explicit MIMetadata(const MachineInstr &From) - : DL(From.getDebugLoc()), PCSections(From.getPCSections()) {} + : DL(From.getDebugLoc()), PCSections(From.getPCSections()), + DeactivationSymbol(From.getDeactivationSymbol()) {} const DebugLoc &getDL() const { return DL; } MDNode *getPCSections() const { return PCSections; } MDNode *getMMRAMetadata() const { return MMRA; } + Value *getDeactivationSymbol() const { return DeactivationSymbol; } private: DebugLoc DL; MDNode *PCSections = nullptr; MDNode *MMRA = nullptr; + Value *DeactivationSymbol = nullptr; + + static inline Value *getDeactivationSymbol(const Instruction *I) { + if (auto *CB = dyn_cast(I)) + if (auto Bundle = + CB->getOperandBundle(llvm::LLVMContext::OB_deactivation_symbol)) + return Bundle->Inputs[0].get(); + return nullptr; + } }; class MachineInstrBuilder { @@ -348,6 +363,8 @@ class MachineInstrBuilder { MI->setPCSections(*MF, MIMD.getPCSections()); if (MIMD.getMMRAMetadata()) MI->setMMRAMetadata(*MF, MIMD.getMMRAMetadata()); + if (MIMD.getDeactivationSymbol()) + MI->setDeactivationSymbol(*MF, MIMD.getDeactivationSymbol()); return *this; } diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index e5644a5ef206a..b37a663fe44c5 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -763,6 +763,7 @@ class SelectionDAG { int64_t offset = 0, unsigned TargetFlags = 0) { return getGlobalAddress(GV, DL, VT, offset, true, TargetFlags); } + LLVM_ABI SDValue getDeactivationSymbol(const GlobalValue *GV); LLVM_ABI SDValue getFrameIndex(int FI, EVT VT, bool isTarget = false); SDValue getTargetFrameIndex(int FI, EVT VT) { return getFrameIndex(FI, VT, true); diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h index a6a3928230c3d..824cae23a7ff4 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h @@ -150,6 +150,7 @@ class SelectionDAGISel { OPC_RecordChild7, OPC_RecordMemRef, OPC_CaptureGlueInput, + OPC_CaptureDeactivationSymbol, OPC_MoveChild, OPC_MoveChild0, OPC_MoveChild1, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 11ae8cd5eb77a..01ee392925b60 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1980,6 +1980,23 @@ class GlobalAddressSDNode : public SDNode { } }; +class DeactivationSymbolSDNode : public SDNode { + friend class SelectionDAG; + + const GlobalValue *TheGlobal; + + DeactivationSymbolSDNode(const GlobalValue *GV, SDVTList VTs) + : SDNode(ISD::DEACTIVATION_SYMBOL, 0, DebugLoc(), VTs), + TheGlobal(GV) {} + +public: + const GlobalValue *getGlobal() const { return TheGlobal; } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::DEACTIVATION_SYMBOL; + } +}; + class FrameIndexSDNode : public SDNode { friend class SelectionDAG; diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index cbdc1b6031680..c9d253d941ddc 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4707,6 +4707,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { SmallVector InVals; const ConstantInt *CFIType = nullptr; SDValue ConvergenceControlToken; + GlobalValue *DeactivationSymbol = nullptr; std::optional PAI; @@ -4852,6 +4853,11 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { return *this; } + CallLoweringInfo &setDeactivationSymbol(GlobalValue *Sym) { + DeactivationSymbol = Sym; + return *this; + } + ArgListTy &getArgs() { return Args; } diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index 9c9fc8892bdbf..8862ef916656b 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -1033,10 +1033,10 @@ class ConstantPtrAuth final : public Constant { friend struct ConstantPtrAuthKeyType; friend class Constant; - constexpr static IntrusiveOperandsAllocMarker AllocMarker{4}; + constexpr static IntrusiveOperandsAllocMarker AllocMarker{5}; ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, ConstantInt *Disc, - Constant *AddrDisc); + Constant *AddrDisc, Constant *DeactivationSymbol); void *operator new(size_t s) { return User::operator new(s, AllocMarker); } @@ -1046,7 +1046,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// Produce a new ptrauth expression signing the given value using /// the same schema as is stored in one. @@ -1078,6 +1079,10 @@ class ConstantPtrAuth final : public Constant { return !getAddrDiscriminator()->isNullValue(); } + Constant *getDeactivationSymbol() const { + return cast(Op<4>().get()); + } + /// A constant value for the address discriminator which has special /// significance to ctors/dtors lowering. Regular address discrimination can't /// be applied for them since uses of llvm.global_{c|d}tors are disallowed @@ -1106,7 +1111,7 @@ class ConstantPtrAuth final : public Constant { template <> struct OperandTraits - : public FixedNumOperandTraits {}; + : public FixedNumOperandTraits {}; DEFINE_TRANSPARENT_OPERAND_ACCESSORS(ConstantPtrAuth, Constant) diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index bd6f94ac1286c..914424831b8a8 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -2850,6 +2850,12 @@ def int_experimental_convergence_anchor def int_experimental_convergence_loop : DefaultAttrsIntrinsic<[llvm_token_ty], [], [IntrNoMem, IntrConvergent]>; +//===----------------- Structure Protection Intrinsics --------------------===// + +def int_protected_field_ptr : + DefaultAttrsIntrinsic<[llvm_ptr_ty], [llvm_ptr_ty, llvm_i64_ty, llvm_i1_ty], + [IntrNoMem, ImmArg>]>; + //===----------------------------------------------------------------------===// // Target-specific intrinsics //===----------------------------------------------------------------------===// diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h index 852a3a4e2f638..c2d2d46da6ddd 100644 --- a/llvm/include/llvm/IR/LLVMContext.h +++ b/llvm/include/llvm/IR/LLVMContext.h @@ -97,6 +97,7 @@ class LLVMContext { OB_ptrauth = 7, // "ptrauth" OB_kcfi = 8, // "kcfi" OB_convergencectrl = 9, // "convergencectrl" + OB_deactivation_symbol = 10, // "deactivation-symbol" }; /// getMDKindID - Return a unique non-zero ID for the specified metadata kind. diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h index 6f682a7059d10..2fe923f6c3866 100644 --- a/llvm/include/llvm/SandboxIR/Constant.h +++ b/llvm/include/llvm/SandboxIR/Constant.h @@ -1363,7 +1363,8 @@ class ConstantPtrAuth final : public Constant { public: /// Return a pointer signed with the specified parameters. LLVM_ABI static ConstantPtrAuth *get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc); + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol); /// The pointer that is signed in this ptrauth signed pointer. LLVM_ABI Constant *getPointer() const; @@ -1378,6 +1379,8 @@ class ConstantPtrAuth final : public Constant { /// the only global-initializer user of the ptrauth signed pointer. LLVM_ABI Constant *getAddrDiscriminator() const; + Constant *getDeactivationSymbol() const; + /// Whether there is any non-null address discriminator. bool hasAddressDiscriminator() const { return cast(Val)->hasAddressDiscriminator(); diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td index 4c83f8a580aa0..77589e4664746 100644 --- a/llvm/include/llvm/Target/Target.td +++ b/llvm/include/llvm/Target/Target.td @@ -682,6 +682,7 @@ class Instruction : InstructionEncoding { // If so, make sure to override // TargetInstrInfo::getInsertSubregLikeInputs. bit variadicOpsAreDefs = false; // Are variadic operands definitions? + bit supportsDeactivationSymbol = false; // Does the instruction have side effects that are not captured by any // operands of the instruction or other flags? diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h index 3f5f4278a2766..926cb73d7ab07 100644 --- a/llvm/include/llvm/Transforms/Utils/Local.h +++ b/llvm/include/llvm/Transforms/Utils/Local.h @@ -182,6 +182,11 @@ LLVM_ABI bool EliminateDuplicatePHINodes(BasicBlock *BB); LLVM_ABI bool EliminateDuplicatePHINodes(BasicBlock *BB, SmallPtrSetImpl &ToRemove); +/// Returns whether it is allowed and beneficial for optimizations to fold this +/// operand through a phi, for example when transforming phi(load(ptr)) into +/// load(phi(ptr)). +bool shouldFoldOperandThroughPhi(const Value *Ptr); + /// This function is used to do simplification of a CFG. For example, it /// adjusts branches to branches to eliminate the extra hop, it eliminates /// unreachable basic blocks, and does other peephole optimization of the CFG. diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp index 9c79546f491ef..59a09c4ea8721 100644 --- a/llvm/lib/Analysis/PtrUseVisitor.cpp +++ b/llvm/lib/Analysis/PtrUseVisitor.cpp @@ -22,7 +22,8 @@ void detail::PtrUseVisitorBase::enqueueUsers(Value &I) { if (VisitedUses.insert(&U).second) { UseToVisit NewU = { UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown), - Offset + Offset, + ProtectedFieldDisc, }; Worklist.push_back(std::move(NewU)); } diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 13bef1f62f1a9..f2ef8749510b6 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -4216,11 +4216,12 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { } case lltok::kw_ptrauth: { // ValID ::= 'ptrauth' '(' ptr @foo ',' i32 - // (',' i64 (',' ptr addrdisc)? )? ')' + // (',' i64 (',' ptr addrdisc (',' ptr ds)? )? )? ')' Lex.Lex(); Constant *Ptr, *Key; - Constant *Disc = nullptr, *AddrDisc = nullptr; + Constant *Disc = nullptr, *AddrDisc = nullptr, + *DeactivationSymbol = nullptr; if (parseToken(lltok::lparen, "expected '(' in constant ptrauth expression") || @@ -4229,11 +4230,14 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { "expected comma in constant ptrauth expression") || parseGlobalTypeAndValue(Key)) return true; - // If present, parse the optional disc/addrdisc. - if (EatIfPresent(lltok::comma)) - if (parseGlobalTypeAndValue(Disc) || - (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc))) - return true; + // If present, parse the optional disc/addrdisc/ds. + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(Disc)) + return true; + if (EatIfPresent(lltok::comma) && parseGlobalTypeAndValue(AddrDisc)) + return true; + if (EatIfPresent(lltok::comma) && + parseGlobalTypeAndValue(DeactivationSymbol)) + return true; if (parseToken(lltok::rparen, "expected ')' in constant ptrauth expression")) return true; @@ -4264,7 +4268,16 @@ bool LLParser::parseValID(ValID &ID, PerFunctionState *PFS, Type *ExpectedTy) { AddrDisc = ConstantPointerNull::get(PointerType::get(Context, 0)); } - ID.ConstantVal = ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc); + if (DeactivationSymbol) { + if (!DeactivationSymbol->getType()->isPointerTy()) + return error( + ID.Loc, "constant ptrauth deactivation symbol must be a pointer"); + } else { + DeactivationSymbol = ConstantPointerNull::get(PointerType::get(Context, 0)); + } + + ID.ConstantVal = + ConstantPtrAuth::get(Ptr, KeyC, DiscC, AddrDisc, DeactivationSymbol); ID.Kind = ValID::t_Constant; return false; } diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index 290d873c632c9..03df10410a8d9 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1608,7 +1608,13 @@ Expected BitcodeReader::materializeValue(unsigned StartValID, if (!Disc) return error("ptrauth disc operand must be ConstantInt"); - C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3]); + auto *DeactivationSymbol = + ConstOps.size() > 4 ? ConstOps[4] + : ConstantPointerNull::get(cast( + ConstOps[3]->getType())); + + C = ConstantPtrAuth::get(ConstOps[0], Key, Disc, ConstOps[3], + DeactivationSymbol); break; } case BitcodeConstant::NoCFIOpcode: { @@ -3808,6 +3814,16 @@ Error BitcodeReader::parseConstants() { (unsigned)Record[2], (unsigned)Record[3]}); break; } + case bitc::CST_CODE_PTRAUTH2: { + if (Record.size() < 4) + return error("Invalid ptrauth record"); + // Ptr, Key, Disc, AddrDisc, DeactivationSymbol + V = BitcodeConstant::create( + Alloc, CurTy, BitcodeConstant::ConstantPtrAuthOpcode, + {(unsigned)Record[0], (unsigned)Record[1], (unsigned)Record[2], + (unsigned)Record[3], (unsigned)Record[4]}); + break; + } } assert(V->getType() == getTypeByID(CurTyID) && "Incorrect result type ID"); diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 7e0d81ff4b196..adf79b96da943 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -3010,11 +3010,12 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, Record.push_back(VE.getTypeID(NC->getGlobalValue()->getType())); Record.push_back(VE.getValueID(NC->getGlobalValue())); } else if (const auto *CPA = dyn_cast(C)) { - Code = bitc::CST_CODE_PTRAUTH; + Code = bitc::CST_CODE_PTRAUTH2; Record.push_back(VE.getValueID(CPA->getPointer())); Record.push_back(VE.getValueID(CPA->getKey())); Record.push_back(VE.getValueID(CPA->getDiscriminator())); Record.push_back(VE.getValueID(CPA->getAddrDiscriminator())); + Record.push_back(VE.getValueID(CPA->getDeactivationSymbol())); } else { #ifndef NDEBUG C->dump(); diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp index 9ba17829d2929..0256b51b611f5 100644 --- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp @@ -195,6 +195,10 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB, assert(Info.CFIType->getType()->isIntegerTy(32) && "Invalid CFI type"); } + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + Info.DeactivationSymbol = cast(Bundle->Inputs[0]); + } + Info.CB = &CB; Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees); Info.CallConv = CallConv; diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index fd38c308003b7..1e8bd121cee6a 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2848,6 +2848,9 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { } } + if (auto Bundle = CI.getOperandBundle(LLVMContext::OB_deactivation_symbol)) + MIB->setDeactivationSymbol(*MF, Bundle->Inputs[0].get()); + return true; } diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp index 27df7e369436a..aa10d909e04e4 100644 --- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp +++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp @@ -38,8 +38,10 @@ void MachineIRBuilder::setMF(MachineFunction &MF) { //------------------------------------------------------------------------------ MachineInstrBuilder MachineIRBuilder::buildInstrNoInsert(unsigned Opcode) { - return BuildMI(getMF(), {getDL(), getPCSections(), getMMRAMetadata()}, - getTII().get(Opcode)); + return BuildMI( + getMF(), + {getDL(), getPCSections(), getMMRAMetadata(), getDeactivationSymbol()}, + getTII().get(Opcode)); } MachineInstrBuilder MachineIRBuilder::insertInstr(MachineInstrBuilder MIB) { diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp index 8b72c295416a2..dbd56c7414f38 100644 --- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp +++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp @@ -281,6 +281,7 @@ static MIToken::TokenKind getIdentifierKind(StringRef Identifier) { .Case("heap-alloc-marker", MIToken::kw_heap_alloc_marker) .Case("pcsections", MIToken::kw_pcsections) .Case("cfi-type", MIToken::kw_cfi_type) + .Case("deactivation-symbol", MIToken::kw_deactivation_symbol) .Case("bbsections", MIToken::kw_bbsections) .Case("bb_id", MIToken::kw_bb_id) .Case("unknown-size", MIToken::kw_unknown_size) diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.h b/llvm/lib/CodeGen/MIRParser/MILexer.h index 0627f176b9e00..0407a0e7540d7 100644 --- a/llvm/lib/CodeGen/MIRParser/MILexer.h +++ b/llvm/lib/CodeGen/MIRParser/MILexer.h @@ -136,6 +136,7 @@ struct MIToken { kw_heap_alloc_marker, kw_pcsections, kw_cfi_type, + kw_deactivation_symbol, kw_bbsections, kw_bb_id, kw_unknown_size, diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp index 6a464d9dd6886..530b892fb0203 100644 --- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp +++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp @@ -1072,6 +1072,7 @@ bool MIParser::parse(MachineInstr *&MI) { Token.isNot(MIToken::kw_heap_alloc_marker) && Token.isNot(MIToken::kw_pcsections) && Token.isNot(MIToken::kw_cfi_type) && + Token.isNot(MIToken::kw_deactivation_symbol) && Token.isNot(MIToken::kw_debug_location) && Token.isNot(MIToken::kw_debug_instr_number) && Token.isNot(MIToken::coloncolon) && Token.isNot(MIToken::lbrace)) { @@ -1120,6 +1121,14 @@ bool MIParser::parse(MachineInstr *&MI) { lex(); } + GlobalValue *DS = nullptr; + if (Token.is(MIToken::kw_deactivation_symbol)) { + lex(); + if (parseGlobalValue(DS)) + return true; + lex(); + } + unsigned InstrNum = 0; if (Token.is(MIToken::kw_debug_instr_number)) { lex(); @@ -1194,6 +1203,8 @@ bool MIParser::parse(MachineInstr *&MI) { MI->setPCSections(MF, PCSections); if (CFIType) MI->setCFIType(MF, CFIType); + if (DS) + MI->setDeactivationSymbol(MF, DS); if (!MemOperands.empty()) MI->setMemRefs(MF, MemOperands); if (InstrNum) diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp index ce1834a90ca54..b526df95da28f 100644 --- a/llvm/lib/CodeGen/MIRPrinter.cpp +++ b/llvm/lib/CodeGen/MIRPrinter.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/MIRFormatter.h" #include "llvm/CodeGen/MIRYamlMapping.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineConstantPool.h" @@ -863,6 +864,10 @@ static void printMI(raw_ostream &OS, MFPrintState &State, } if (uint32_t CFIType = MI.getCFIType()) OS << LS << " cfi-type " << CFIType; + if (Value *DS = MI.getDeactivationSymbol()) { + OS << LS << "deactivation-symbol "; + MIRFormatter::printIRValue(OS, *DS, State.MST); + } if (auto Num = MI.peekDebugInstrNum()) OS << LS << " debug-instr-number " << Num; diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp index ec40f6af3caae..93b00eca453f0 100644 --- a/llvm/lib/CodeGen/MachineFunction.cpp +++ b/llvm/lib/CodeGen/MachineFunction.cpp @@ -608,10 +608,10 @@ MachineFunction::getMachineMemOperand(const MachineMemOperand *MMO, MachineInstr::ExtraInfo *MachineFunction::createMIExtraInfo( ArrayRef MMOs, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs) { + uint32_t CFIType, MDNode *MMRAs, Value *DS) { return MachineInstr::ExtraInfo::create(Allocator, MMOs, PreInstrSymbol, PostInstrSymbol, HeapAllocMarker, - PCSections, CFIType, MMRAs); + PCSections, CFIType, MMRAs, DS); } const char *MachineFunction::createExternalSymbolName(StringRef Name) { diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp index 79047f732808a..d833aeab38e61 100644 --- a/llvm/lib/CodeGen/MachineInstr.cpp +++ b/llvm/lib/CodeGen/MachineInstr.cpp @@ -322,15 +322,17 @@ void MachineInstr::setExtraInfo(MachineFunction &MF, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs) { + uint32_t CFIType, MDNode *MMRAs, Value *DS) { bool HasPreInstrSymbol = PreInstrSymbol != nullptr; bool HasPostInstrSymbol = PostInstrSymbol != nullptr; bool HasHeapAllocMarker = HeapAllocMarker != nullptr; bool HasPCSections = PCSections != nullptr; bool HasCFIType = CFIType != 0; bool HasMMRAs = MMRAs != nullptr; + bool HasDS = DS != nullptr; int NumPointers = MMOs.size() + HasPreInstrSymbol + HasPostInstrSymbol + - HasHeapAllocMarker + HasPCSections + HasCFIType + HasMMRAs; + HasHeapAllocMarker + HasPCSections + HasCFIType + HasMMRAs + + HasDS; // Drop all extra info if there is none. if (NumPointers <= 0) { @@ -343,10 +345,10 @@ void MachineInstr::setExtraInfo(MachineFunction &MF, // 32-bit pointers. // FIXME: Maybe we should make the symbols in the extra info mutable? else if (NumPointers > 1 || HasMMRAs || HasHeapAllocMarker || HasPCSections || - HasCFIType) { + HasCFIType || HasDS) { Info.set( MF.createMIExtraInfo(MMOs, PreInstrSymbol, PostInstrSymbol, - HeapAllocMarker, PCSections, CFIType, MMRAs)); + HeapAllocMarker, PCSections, CFIType, MMRAs, DS)); return; } @@ -365,7 +367,7 @@ void MachineInstr::dropMemRefs(MachineFunction &MF) { setExtraInfo(MF, {}, getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setMemRefs(MachineFunction &MF, @@ -377,7 +379,7 @@ void MachineInstr::setMemRefs(MachineFunction &MF, setExtraInfo(MF, MMOs, getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::addMemOperand(MachineFunction &MF, @@ -488,7 +490,7 @@ void MachineInstr::setPreInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { setExtraInfo(MF, memoperands(), Symbol, getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setPostInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { @@ -504,7 +506,7 @@ void MachineInstr::setPostInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { setExtraInfo(MF, memoperands(), getPreInstrSymbol(), Symbol, getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setHeapAllocMarker(MachineFunction &MF, MDNode *Marker) { @@ -513,7 +515,8 @@ void MachineInstr::setHeapAllocMarker(MachineFunction &MF, MDNode *Marker) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - Marker, getPCSections(), getCFIType(), getMMRAMetadata()); + Marker, getPCSections(), getCFIType(), getMMRAMetadata(), + getDeactivationSymbol()); } void MachineInstr::setPCSections(MachineFunction &MF, MDNode *PCSections) { @@ -523,7 +526,7 @@ void MachineInstr::setPCSections(MachineFunction &MF, MDNode *PCSections) { setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), PCSections, getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setCFIType(MachineFunction &MF, uint32_t Type) { @@ -532,7 +535,8 @@ void MachineInstr::setCFIType(MachineFunction &MF, uint32_t Type) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - getHeapAllocMarker(), getPCSections(), Type, getMMRAMetadata()); + getHeapAllocMarker(), getPCSections(), Type, getMMRAMetadata(), + getDeactivationSymbol()); } void MachineInstr::setMMRAMetadata(MachineFunction &MF, MDNode *MMRAs) { @@ -541,7 +545,18 @@ void MachineInstr::setMMRAMetadata(MachineFunction &MF, MDNode *MMRAs) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - getHeapAllocMarker(), getPCSections(), getCFIType(), MMRAs); + getHeapAllocMarker(), getPCSections(), getCFIType(), MMRAs, + getDeactivationSymbol()); +} + +void MachineInstr::setDeactivationSymbol(MachineFunction &MF, Value *DS) { + // Do nothing if old and new symbols are the same. + if (DS == getDeactivationSymbol()) + return; + + setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), + getHeapAllocMarker(), getPCSections(), getCFIType(), + getMMRAMetadata(), DS); } void MachineInstr::cloneInstrSymbols(MachineFunction &MF, @@ -730,6 +745,8 @@ bool MachineInstr::isIdenticalTo(const MachineInstr &Other, // Call instructions with different CFI types are not identical. if (isCall() && getCFIType() != Other.getCFIType()) return false; + if (getDeactivationSymbol() != Other.getDeactivationSymbol()) + return false; return true; } @@ -2035,6 +2052,8 @@ void MachineInstr::print(raw_ostream &OS, ModuleSlotTracker &MST, OS << ','; OS << " cfi-type " << CFIType; } + if (getDeactivationSymbol()) + OS << ", deactivation-symbol " << getDeactivationSymbol()->getName(); if (DebugInstrNum) { if (!FirstOp) diff --git a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp index 8de2c48581a1e..5d7129ae0cb87 100644 --- a/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp +++ b/llvm/lib/CodeGen/PreISelIntrinsicLowering.cpp @@ -21,9 +21,11 @@ #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/RuntimeLibcalls.h" #include "llvm/IR/Type.h" @@ -461,6 +463,162 @@ bool PreISelIntrinsicLowering::expandMemIntrinsicUses( return Changed; } +namespace { + +enum class PointerEncoding { + Rotate, + PACCopyable, + PACNonCopyable, +}; + +bool expandProtectedFieldPtr(Function &Intr) { + Module &M = *Intr.getParent(); + + SmallPtrSet DSsToDeactivate; + SmallPtrSet LoadsStores; + + Type *Int8Ty = Type::getInt8Ty(M.getContext()); + Type *Int64Ty = Type::getInt64Ty(M.getContext()); + PointerType *PtrTy = PointerType::get(M.getContext(), 0); + + Function *SignIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_sign, {}); + Function *AuthIntr = + Intrinsic::getOrInsertDeclaration(&M, Intrinsic::ptrauth_auth, {}); + + auto *EmuFnTy = FunctionType::get(Int64Ty, {Int64Ty, Int64Ty}, false); + FunctionCallee EmuSignIntr = M.getOrInsertFunction("__emupac_pacda", EmuFnTy); + FunctionCallee EmuAuthIntr = M.getOrInsertFunction("__emupac_autda", EmuFnTy); + + auto CreateSign = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(SignIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuSignIntr, {Val, Disc}, DSBundle); + }; + + auto CreateAuth = [&](IRBuilder<> &B, Value *Val, Value *Disc, + OperandBundleDef DSBundle) { + Function *F = B.GetInsertBlock()->getParent(); + Attribute FSAttr = F->getFnAttribute("target-features"); + if (FSAttr.isValid() && FSAttr.getValueAsString().contains("+pauth")) + return B.CreateCall(AuthIntr, {Val, B.getInt32(2), Disc}, DSBundle); + return B.CreateCall(EmuAuthIntr, {Val, Disc}, DSBundle); + }; + + auto GetDeactivationSymbol = [&](CallInst *Call) -> GlobalValue * { + if (auto Bundle = + Call->getOperandBundle(LLVMContext::OB_deactivation_symbol)) + return cast(Bundle->Inputs[0]); + return nullptr; + }; + + for (User *U : Intr.users()) { + auto *Call = cast(U); + auto *DS = GetDeactivationSymbol(Call); + + for (Use &U : Call->uses()) { + if (auto *LI = dyn_cast(U.getUser())) { + if (isa(LI->getType())) { + LoadsStores.insert(LI); + continue; + } + } + if (auto *SI = dyn_cast(U.getUser())) { + if (U.getOperandNo() == 1 && + isa(SI->getValueOperand()->getType())) { + LoadsStores.insert(SI); + continue; + } + } + // Comparisons against null cannot be used to recover the original + // pointer so we allow them. + if (auto *CI = dyn_cast(U.getUser())) { + if (auto *Op = dyn_cast(CI->getOperand(0))) + if (Op->isNullValue()) + continue; + if (auto *Op = dyn_cast(CI->getOperand(1))) + if (Op->isNullValue()) + continue; + } + if (DS) + DSsToDeactivate.insert(DS); + } + } + + for (Instruction *I : LoadsStores) { + auto *PointerOperand = isa(I) + ? cast(I)->getPointerOperand() + : cast(I)->getPointerOperand(); + auto *Call = cast(PointerOperand); + + auto *Disc = Call->getArgOperand(1); + bool UseHWEncoding = cast(Call->getArgOperand(2))->getZExtValue(); + + GlobalValue *DS = GetDeactivationSymbol(Call); + OperandBundleDef DSBundle("deactivation-symbol", DS); + + if (auto *LI = dyn_cast(I)) { + IRBuilder<> B(LI->getNextNode()); + auto *LIInt = cast(B.CreatePtrToInt(LI, B.getInt64Ty())); + Value *Auth; + if (UseHWEncoding) { + Auth = CreateAuth(B, LIInt, Disc, DSBundle); + } else { + Auth = B.CreateAdd(LIInt, Disc); + Auth = B.CreateIntrinsic( + Auth->getType(), Intrinsic::fshr, + {Auth, Auth, ConstantInt::get(Auth->getType(), 16)}); + } + LI->replaceAllUsesWith(B.CreateIntToPtr(Auth, B.getPtrTy())); + LIInt->setOperand(0, LI); + } else { + auto *SI = cast(I); + IRBuilder<> B(SI); + auto *SIValInt = + B.CreatePtrToInt(SI->getValueOperand(), B.getInt64Ty()); + Value *Sign; + if (UseHWEncoding) { + Sign = CreateSign(B, SIValInt, Disc, DSBundle); + } else { + Sign = B.CreateIntrinsic( + SIValInt->getType(), Intrinsic::fshl, + {SIValInt, SIValInt, ConstantInt::get(SIValInt->getType(), 16)}); + Sign = B.CreateSub(Sign, Disc); + } + SI->setOperand(0, B.CreateIntToPtr(Sign, B.getPtrTy())); + } + } + + for (User *U : llvm::make_early_inc_range(Intr.users())) { + auto *Call = cast(U); + auto *Pointer = Call->getArgOperand(0); + + Call->replaceAllUsesWith(Pointer); + Call->eraseFromParent(); + } + + if (!DSsToDeactivate.empty()) { + Constant *Nop = + ConstantExpr::getIntToPtr(ConstantInt::get(Int64Ty, 0xd503201f), PtrTy); + for (GlobalValue *OldDS : DSsToDeactivate) { + GlobalValue *DS = GlobalAlias::create( + Int8Ty, 0, GlobalValue::ExternalLinkage, OldDS->getName(), Nop, &M); + DS->setVisibility(GlobalValue::HiddenVisibility); + if (OldDS) { + DS->takeName(OldDS); + OldDS->replaceAllUsesWith(DS); + OldDS->eraseFromParent(); + } + } + } + return true; +} + +} + bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const { // Map unique constants to globals. DenseMap CMap; @@ -598,6 +756,9 @@ bool PreISelIntrinsicLowering::lowerIntrinsics(Module &M) const { return lowerUnaryVectorIntrinsicAsLoop(M, CI); }); break; + case Intrinsic::protected_field_ptr: + Changed |= expandProtectedFieldPtr(F); + break; } } return Changed; diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp index 8c8daef6dccd4..b0301a08c3384 100644 --- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp @@ -15,10 +15,12 @@ #include "InstrEmitter.h" #include "SDNodeDbgValue.h" #include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetLowering.h" @@ -61,6 +63,8 @@ static unsigned countOperands(SDNode *Node, unsigned NumExpUses, unsigned N = Node->getNumOperands(); while (N && Node->getOperand(N - 1).getValueType() == MVT::Glue) --N; + if (N && Node->getOperand(N - 1).getOpcode() == ISD::DEACTIVATION_SYMBOL) + --N; // Ignore deactivation symbol if it exists. if (N && Node->getOperand(N - 1).getValueType() == MVT::Other) --N; // Ignore chain if it exists. @@ -1224,15 +1228,23 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned, } } - if (SDNode *GluedNode = Node->getGluedNode()) { - // FIXME: Possibly iterate over multiple glue nodes? - if (GluedNode->getOpcode() == - ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { - Register VReg = getVR(GluedNode->getOperand(0), VRBaseMap); - MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, - /*isImp=*/true); - MIB->addOperand(MO); - } + unsigned Op = Node->getNumOperands(); + if (Op != 0 && Node->getOperand(Op - 1)->getOpcode() == + ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { + Register VReg = getVR(Node->getOperand(Op - 1)->getOperand(0), VRBaseMap); + MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, + /*isImp=*/true); + MIB->addOperand(MO); + Op--; + } + + if (Op != 0 && + Node->getOperand(Op - 1)->getOpcode() == ISD::DEACTIVATION_SYMBOL) { + MI->setDeactivationSymbol( + *MF, const_cast( + cast(Node->getOperand(Op - 1)) + ->getGlobal())); + Op--; } // Run post-isel target hook to adjust this instruction if needed. @@ -1253,7 +1265,8 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned, llvm_unreachable("This target-independent node should have been selected!"); case ISD::EntryToken: case ISD::MERGE_VALUES: - case ISD::TokenFactor: // fall thru + case ISD::TokenFactor: + case ISD::DEACTIVATION_SYMBOL: break; case ISD::CopyToReg: { Register DestReg = cast(Node->getOperand(1))->getReg(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index f41b6eb26bbda..a373ba200bbee 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1910,6 +1910,21 @@ SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL, return SDValue(N, 0); } +SDValue SelectionDAG::getDeactivationSymbol(const GlobalValue *GV) { + SDVTList VTs = getVTList(MVT::Untyped); + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::DEACTIVATION_SYMBOL, VTs, {}); + ID.AddPointer(GV); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) + return SDValue(E, 0); + + auto *N = newSDNode(GV, VTs); + CSEMap.InsertNode(N, IP); + InsertNode(N); + return SDValue(N, 0); +} + SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) { unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex; SDVTList VTs = getVTList(VT); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 306e068f1c1da..91ef834818454 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -46,6 +46,7 @@ #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/SelectionDAGTargetInfo.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/SwiftErrorValueTracking.h" @@ -5379,6 +5380,13 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, // Create the node. SDValue Result; + if (auto Bundle = I.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + auto *Sym = Bundle->Inputs[0].get(); + SDValue SDSym = getValue(Sym); + SDSym = DAG.getDeactivationSymbol(cast(Sym)); + Ops.push_back(SDSym); + } + if (auto Bundle = I.getOperandBundle(LLVMContext::OB_convergencectrl)) { auto *Token = Bundle->Inputs[0].get(); SDValue ConvControlToken = getValue(Token); @@ -8949,6 +8957,11 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, ConvControlToken = getValue(Token); } + GlobalValue *DeactivationSymbol = nullptr; + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + DeactivationSymbol = cast(Bundle->Inputs[0].get()); + } + TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(getCurSDLoc()) .setChain(getRoot()) @@ -8958,7 +8971,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, .setIsPreallocated( CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0) .setCFIType(CFIType) - .setConvergenceControlToken(ConvControlToken); + .setConvergenceControlToken(ConvControlToken) + .setDeactivationSymbol(DeactivationSymbol); // Set the pointer authentication info if we have it. if (PAI) { @@ -9575,7 +9589,7 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) { {LLVMContext::OB_deopt, LLVMContext::OB_funclet, LLVMContext::OB_cfguardtarget, LLVMContext::OB_preallocated, LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_kcfi, - LLVMContext::OB_convergencectrl}); + LLVMContext::OB_convergencectrl, LLVMContext::OB_deactivation_symbol}); SDValue Callee = getValue(I.getCalledOperand()); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index 26071ed70c9db..ff01019182e73 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3268,6 +3268,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case ISD::LIFETIME_START: case ISD::LIFETIME_END: case ISD::PSEUDO_PROBE: + case ISD::DEACTIVATION_SYMBOL: NodeToMatch->setNodeId(-1); // Mark selected. return; case ISD::AssertSext: @@ -3346,7 +3347,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // These are the current input chain and glue for use when generating nodes. // Various Emit operations change these. For example, emitting a copytoreg // uses and updates these. - SDValue InputChain, InputGlue; + SDValue InputChain, InputGlue, DeactivationSymbol; // ChainNodesMatched - If a pattern matches nodes that have input/output // chains, the OPC_EmitMergeInputChains operation is emitted which indicates @@ -3499,6 +3500,15 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, InputGlue = N->getOperand(N->getNumOperands()-1); continue; + case OPC_CaptureDeactivationSymbol: + // If the current node has a deactivation symbol, capture it in + // DeactivationSymbol. + if (N->getNumOperands() != 0 && + N->getOperand(N->getNumOperands() - 1).getOpcode() == + ISD::DEACTIVATION_SYMBOL) + DeactivationSymbol = N->getOperand(N->getNumOperands()-1); + continue; + case OPC_MoveChild: { unsigned ChildNo = MatcherTable[MatcherIndex++]; if (ChildNo >= N.getNumOperands()) @@ -4180,6 +4190,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // If this has chain/glue inputs, add them. if (EmitNodeInfo & OPFL_Chain) Ops.push_back(InputChain); + if (DeactivationSymbol.getNode() != nullptr) + Ops.push_back(DeactivationSymbol); if ((EmitNodeInfo & OPFL_GlueInput) && InputGlue.getNode() != nullptr) Ops.push_back(InputGlue); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index e5a4e1e6b7ce5..a1facc9ba1e19 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -1667,12 +1667,14 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, if (const ConstantPtrAuth *CPA = dyn_cast(CV)) { Out << "ptrauth ("; - // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC]?]?) + // ptrauth (ptr CST, i32 KEY[, i64 DISC[, ptr ADDRDISC[, ptr DS]?]?]?) unsigned NumOpsToWrite = 2; if (!CPA->getOperand(2)->isNullValue()) NumOpsToWrite = 3; if (!CPA->getOperand(3)->isNullValue()) NumOpsToWrite = 4; + if (!CPA->getOperand(4)->isNullValue()) + NumOpsToWrite = 5; ListSeparator LS; for (unsigned i = 0, e = NumOpsToWrite; i != e; ++i) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index a3c725b2af62a..d87bc6c11400e 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2060,19 +2060,22 @@ Value *NoCFIValue::handleOperandChangeImpl(Value *From, Value *To) { // ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { - Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc}; + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { + Constant *ArgVec[] = {Ptr, Key, Disc, AddrDisc, DeactivationSymbol}; ConstantPtrAuthKeyType MapKey(ArgVec); LLVMContextImpl *pImpl = Ptr->getContext().pImpl; return pImpl->ConstantPtrAuths.getOrCreate(Ptr->getType(), MapKey); } ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { - return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator()); + return get(Pointer, getKey(), getDiscriminator(), getAddrDiscriminator(), + getDeactivationSymbol()); } ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) : Constant(Ptr->getType(), Value::ConstantPtrAuthVal, AllocMarker) { assert(Ptr->getType()->isPointerTy()); assert(Key->getBitWidth() == 32); @@ -2082,6 +2085,7 @@ ConstantPtrAuth::ConstantPtrAuth(Constant *Ptr, ConstantInt *Key, setOperand(1, Key); setOperand(2, Disc); setOperand(3, AddrDisc); + setOperand(4, DeactivationSymbol); } /// Remove the constant from the constant table. diff --git a/llvm/lib/IR/ConstantsContext.h b/llvm/lib/IR/ConstantsContext.h index 51fb40bad201d..584391b53800a 100644 --- a/llvm/lib/IR/ConstantsContext.h +++ b/llvm/lib/IR/ConstantsContext.h @@ -539,7 +539,8 @@ struct ConstantPtrAuthKeyType { ConstantPtrAuth *create(TypeClass *Ty) const { return new ConstantPtrAuth(Operands[0], cast(Operands[1]), - cast(Operands[2]), Operands[3]); + cast(Operands[2]), Operands[3], + Operands[4]); } }; diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp index f7ef4aa473ef5..904111e742be0 100644 --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -1699,7 +1699,9 @@ LLVMValueRef LLVMConstantPtrAuth(LLVMValueRef Ptr, LLVMValueRef Key, LLVMValueRef Disc, LLVMValueRef AddrDisc) { return wrap(ConstantPtrAuth::get( unwrap(Ptr), unwrap(Key), - unwrap(Disc), unwrap(AddrDisc))); + unwrap(Disc), unwrap(AddrDisc), + ConstantPointerNull::get( + cast(unwrap(AddrDisc)->getType())))); } /*-- Opcode mapping */ diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp index 57532cd491dd6..647c1661083a4 100644 --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -53,6 +53,8 @@ static StringRef knownBundleName(unsigned BundleTagID) { return "kcfi"; case LLVMContext::OB_convergencectrl: return "convergencectrl"; + case LLVMContext::OB_deactivation_symbol: + return "deactivation-symbol"; default: llvm_unreachable("unknown bundle id"); } @@ -76,7 +78,7 @@ LLVMContext::LLVMContext() : pImpl(new LLVMContextImpl(*this)) { } for (unsigned BundleTagID = LLVMContext::OB_deopt; - BundleTagID <= LLVMContext::OB_convergencectrl; ++BundleTagID) { + BundleTagID <= LLVMContext::OB_deactivation_symbol; ++BundleTagID) { [[maybe_unused]] const auto *Entry = pImpl->getOrInsertBundleTag(knownBundleName(BundleTagID)); assert(Entry->second == BundleTagID && "operand bundle id drifted!"); diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 3ff9895e161c4..3478c2c450ae7 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2627,6 +2627,11 @@ void Verifier::visitConstantPtrAuth(const ConstantPtrAuth *CPA) { Check(CPA->getDiscriminator()->getBitWidth() == 64, "signed ptrauth constant discriminator must be i64 constant integer"); + + Check(isa(CPA->getDeactivationSymbol()) || + CPA->getDeactivationSymbol()->isNullValue(), + "signed ptrauth constant deactivation symbol must be a global value " + "or null"); } bool Verifier::verifyAttributeCount(AttributeList Attrs, unsigned Params) { diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp index 9de88ef2cf0a0..eb14797af081c 100644 --- a/llvm/lib/SandboxIR/Constant.cpp +++ b/llvm/lib/SandboxIR/Constant.cpp @@ -412,10 +412,12 @@ PointerType *NoCFIValue::getType() const { } ConstantPtrAuth *ConstantPtrAuth::get(Constant *Ptr, ConstantInt *Key, - ConstantInt *Disc, Constant *AddrDisc) { + ConstantInt *Disc, Constant *AddrDisc, + Constant *DeactivationSymbol) { auto *LLVMC = llvm::ConstantPtrAuth::get( cast(Ptr->Val), cast(Key->Val), - cast(Disc->Val), cast(AddrDisc->Val)); + cast(Disc->Val), cast(AddrDisc->Val), + cast(DeactivationSymbol->Val)); return cast(Ptr->getContext().getOrCreateConstant(LLVMC)); } @@ -439,6 +441,11 @@ Constant *ConstantPtrAuth::getAddrDiscriminator() const { cast(Val)->getAddrDiscriminator()); } +Constant *ConstantPtrAuth::getDeactivationSymbol() const { + return Ctx.getOrCreateConstant( + cast(Val)->getDeactivationSymbol()); +} + ConstantPtrAuth *ConstantPtrAuth::getWithSameSchema(Constant *Pointer) const { auto *LLVMC = cast(Val)->getWithSameSchema( cast(Pointer->Val)); diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index c52487ab8a79a..b5922c97dafed 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -49,12 +49,14 @@ #include "llvm/IR/Module.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCContext.h" +#include "llvm/MC/MCExpr.h" #include "llvm/MC/MCInst.h" #include "llvm/MC/MCInstBuilder.h" #include "llvm/MC/MCSectionELF.h" #include "llvm/MC/MCSectionMachO.h" #include "llvm/MC/MCStreamer.h" #include "llvm/MC/MCSymbol.h" +#include "llvm/MC/MCValue.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -95,6 +97,7 @@ class AArch64AsmPrinter : public AsmPrinter { bool EnableImportCallOptimization = false; DenseMap>> SectionToImportedFunctionCalls; + unsigned PAuthIFuncNextUniqueID = 1; public: static char ID; @@ -174,7 +177,12 @@ class AArch64AsmPrinter : public AsmPrinter { const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, - uint64_t PACDisc, Register PACAddrDisc); + uint64_t PACDisc, Register PACAddrDisc, Value *DS); + + // Emit R_AARCH64_PATCHINST, the deactivation symbol relocation. Returns true + // if no instruction should be emitted because the deactivation symbol is + // defined in the current module so this function emitted a NOP instead. + bool emitDeactivationSymbolRelocation(Value *DS); // Emit the sequence for PAC. void emitPtrauthSign(const MachineInstr *MI); @@ -212,6 +220,10 @@ class AArch64AsmPrinter : public AsmPrinter { // authenticating) void LowerLOADgotAUTH(const MachineInstr &MI); + const MCExpr *emitPAuthRelocationAsIRelative( + const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr); + /// tblgen'erated driver function for lowering simple MI->MC /// pseudo instructions. bool lowerPseudoInstExpansion(const MachineInstr *MI, MCInst &Inst); @@ -2074,11 +2086,31 @@ void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) { /*ShouldTrap=*/true, /*OnFailure=*/nullptr); } +bool AArch64AsmPrinter::emitDeactivationSymbolRelocation(Value *DS) { + if (!DS) + return false; + + if (isa(DS)) { + // Just emit the nop directly. + EmitToStreamer(MCInstBuilder(AArch64::HINT).addImm(0)); + return true; + } + MCSymbol *Dot = OutContext.createTempSymbol(); + OutStreamer->emitLabel(Dot); + const MCExpr *DeactDotExpr = MCSymbolRefExpr::create(Dot, OutContext); + + const MCExpr *DSExpr = MCSymbolRefExpr::create( + OutContext.getOrCreateSymbol(DS->getName()), OutContext); + OutStreamer->emitRelocDirective(*DeactDotExpr, "R_AARCH64_PATCHINST", DSExpr, + SMLoc()); + return false; +} + void AArch64AsmPrinter::emitPtrauthAuthResign( Register AUTVal, AArch64PACKey::ID AUTKey, uint64_t AUTDisc, const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, uint64_t PACDisc, - Register PACAddrDisc) { + Register PACAddrDisc, Value *DS) { const bool IsAUTPAC = PACKey.has_value(); // We expand AUT/AUTPAC into a sequence of the form @@ -2125,15 +2157,17 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( bool AUTZero = AUTDiscReg == AArch64::XZR; unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero); - // autiza x16 ; if AUTZero - // autia x16, x17 ; if !AUTZero - MCInst AUTInst; - AUTInst.setOpcode(AUTOpc); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - if (!AUTZero) - AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); - EmitToStreamer(*OutStreamer, AUTInst); + if (!emitDeactivationSymbolRelocation(DS)) { + // autiza x16 ; if AUTZero + // autia x16, x17 ; if !AUTZero + MCInst AUTInst; + AUTInst.setOpcode(AUTOpc); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + if (!AUTZero) + AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); + EmitToStreamer(*OutStreamer, AUTInst); + } // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done. if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) @@ -2198,6 +2232,9 @@ void AArch64AsmPrinter::emitPtrauthSign(const MachineInstr *MI) { bool IsZeroDisc = DiscReg == AArch64::XZR; unsigned Opc = getPACOpcodeForKey(Key, IsZeroDisc); + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // paciza x16 ; if IsZeroDisc // pacia x16, x17 ; if !IsZeroDisc MCInst PACInst; @@ -2259,6 +2296,161 @@ void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) { EmitToStreamer(*OutStreamer, BRInst); } +static void emitAddress(MCStreamer &Streamer, MCRegister Reg, + const MCExpr *Expr, bool DSOLocal, + const MCSubtargetInfo &STI) { + MCValue Val; + if (!Expr->evaluateAsRelocatable(Val, nullptr)) + report_fatal_error("emitAddress could not evaluate"); + if (DSOLocal) { + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADRP) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(Expr, AArch64::S_ABS_PAGE, + Streamer.getContext())), + STI); + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADDXri) + .addReg(Reg) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(Expr, AArch64::S_LO12, + Streamer.getContext())) + .addImm(0), + STI); + } else { + auto *SymRef = MCSymbolRefExpr::create(Val.getAddSym(), Streamer.getContext()); + Streamer.emitInstruction( + MCInstBuilder(AArch64::ADRP) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(SymRef, AArch64::S_GOT_PAGE, + Streamer.getContext())), + STI); + Streamer.emitInstruction( + MCInstBuilder(AArch64::LDRXui) + .addReg(Reg) + .addReg(Reg) + .addExpr(MCSpecifierExpr::create(SymRef, AArch64::S_GOT_LO12, + Streamer.getContext())), + STI); + if (Val.getConstant()) + Streamer.emitInstruction(MCInstBuilder(AArch64::ADDXri) + .addReg(Reg) + .addReg(Reg) + .addImm(Val.getConstant()) + .addImm(0), + STI); + } +} + +static bool targetSupportsPAuthRelocation(const Triple &TT, + const MCExpr *Target, + const MCExpr *DSExpr) { + // No released version of glibc supports PAuth relocations. + if (TT.isOSGlibc()) + return false; + + // We emit PAuth constants as IRELATIVE relocations in cases where the + // constant cannot be represented as a PAuth relocation: + // 1) There is a deactivation symbol. + // 2) The signed value is not a symbol. + return !DSExpr && !isa(Target); +} + +static bool targetSupportsIRelativeRelocation(const Triple &TT) { + // IFUNCs are ELF-only. + if (!TT.isOSBinFormatELF()) + return false; + + // musl doesn't support IFUNCs. + if (TT.isMusl()) + return false; + + return true; +} + +const MCExpr *AArch64AsmPrinter::emitPAuthRelocationAsIRelative( + const MCExpr *Target, uint16_t Disc, AArch64PACKey::ID KeyID, + bool HasAddressDiversity, bool IsDSOLocal, const MCExpr *DSExpr) { + const Triple &TT = TM.getTargetTriple(); + + // We only emit an IRELATIVE relocation if the target supports IRELATIVE and + // does not support the kind of PAuth relocation that we are trying to emit. + if (targetSupportsPAuthRelocation(TT, Target, DSExpr) || + !targetSupportsIRelativeRelocation(TT)) + return nullptr; + + // For now, only the DA key is supported. + if (KeyID != AArch64PACKey::DA) + return nullptr; + + std::unique_ptr STI( + TM.getTarget().createMCSubtargetInfo(TT.str(), "", "")); + assert(STI && "Unable to create subtarget info"); + + MCSymbol *Place = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(Place); + OutStreamer->pushSection(); + + OutStreamer->switchSection(OutStreamer->getContext().getELFSection( + ".text.startup", ELF::SHT_PROGBITS, ELF::SHF_ALLOC | ELF::SHF_EXECINSTR, + 0, "", true, PAuthIFuncNextUniqueID++, nullptr)); + + MCSymbol *IRelativeSym = + OutStreamer->getContext().createLinkerPrivateSymbol("pauth_ifunc"); + OutStreamer->emitLabel(IRelativeSym); + if (isa(Target)) { + OutStreamer->emitInstruction(MCInstBuilder(AArch64::MOVZXi) + .addReg(AArch64::X0) + .addExpr(Target) + .addImm(0), + *STI); + } else { + emitAddress(*OutStreamer, AArch64::X0, Target, IsDSOLocal, *STI); + } + if (HasAddressDiversity) { + auto *PlacePlusDisc = MCBinaryExpr::createAdd( + MCSymbolRefExpr::create(Place, OutStreamer->getContext()), + MCConstantExpr::create(static_cast(Disc), + OutStreamer->getContext()), + OutStreamer->getContext()); + emitAddress(*OutStreamer, AArch64::X1, PlacePlusDisc, /*IsDSOLocal=*/true, + *STI); + } else { + emitMOVZ(AArch64::X1, Disc, 0); + } + + if (DSExpr) { + MCSymbol *PrePACInst = OutStreamer->getContext().createTempSymbol(); + OutStreamer->emitLabel(PrePACInst); + + auto *PrePACInstExpr = + MCSymbolRefExpr::create(PrePACInst, OutStreamer->getContext()); + OutStreamer->emitRelocDirective(*PrePACInstExpr, "R_AARCH64_PATCHINST", + DSExpr, SMLoc()); + } + + // We don't know the subtarget because this is being emitted for a global + // initializer. Because the performance of IFUNC resolvers is unimportant, we + // always call the EmuPAC runtime, which will end up using the PAC instruction + // if the target supports PAC. + MCSymbol *EmuPAC = + OutStreamer->getContext().getOrCreateSymbol("__emupac_pacda"); + const MCSymbolRefExpr *EmuPACRef = + MCSymbolRefExpr::create(EmuPAC, OutStreamer->getContext()); + OutStreamer->emitInstruction(MCInstBuilder(AArch64::B).addExpr(EmuPACRef), + *STI); + + // We need a RET despite the above tail call because the deactivation symbol + // may replace it with a NOP. + if (DSExpr) + OutStreamer->emitInstruction( + MCInstBuilder(AArch64::RET).addReg(AArch64::LR), *STI); + OutStreamer->popSection(); + + return MCSymbolRefExpr::create(IRelativeSym, AArch64::S_FUNCINIT, + OutStreamer->getContext()); +} + const MCExpr * AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { MCContext &Ctx = OutContext; @@ -2270,22 +2462,26 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { auto *BaseGVB = dyn_cast(BaseGV); - // If we can't understand the referenced ConstantExpr, there's nothing - // else we can do: emit an error. - if (!BaseGVB) { - BaseGV->getContext().emitError( - "cannot resolve target base/addend of ptrauth constant"); - return nullptr; + const MCExpr *Sym; + if (BaseGVB) { + // If there is an addend, turn that into the appropriate MCExpr. + Sym = MCSymbolRefExpr::create(getSymbol(BaseGVB), Ctx); + if (Offset.sgt(0)) + Sym = MCBinaryExpr::createAdd( + Sym, MCConstantExpr::create(Offset.getSExtValue(), Ctx), Ctx); + else if (Offset.slt(0)) + Sym = MCBinaryExpr::createSub( + Sym, MCConstantExpr::create((-Offset).getSExtValue(), Ctx), Ctx); + } else { + Sym = MCConstantExpr::create(Offset.getSExtValue(), Ctx); } - // If there is an addend, turn that into the appropriate MCExpr. - const MCExpr *Sym = MCSymbolRefExpr::create(getSymbol(BaseGVB), Ctx); - if (Offset.sgt(0)) - Sym = MCBinaryExpr::createAdd( - Sym, MCConstantExpr::create(Offset.getSExtValue(), Ctx), Ctx); - else if (Offset.slt(0)) - Sym = MCBinaryExpr::createSub( - Sym, MCConstantExpr::create((-Offset).getSExtValue(), Ctx), Ctx); + const MCExpr *DSExpr = nullptr; + if (auto *DS = dyn_cast(CPA.getDeactivationSymbol())) { + if (isa(DS)) + return Sym; + DSExpr = MCSymbolRefExpr::create(getSymbol(DS), Ctx); + } uint64_t KeyID = CPA.getKey()->getZExtValue(); // We later rely on valid KeyID value in AArch64PACKeyIDToString call from @@ -2304,6 +2500,16 @@ AArch64AsmPrinter::lowerConstantPtrAuth(const ConstantPtrAuth &CPA) { Disc = 0; } + // Check if we need to represent this with an IRELATIVE and emit it if so. + if (auto *IFuncSym = emitPAuthRelocationAsIRelative( + Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), + BaseGVB && BaseGVB->isDSOLocal(), DSExpr)) + return IFuncSym; + + if (DSExpr) + report_fatal_error("deactivation symbols unsupported in constant " + "expressions on this target"); + // Finally build the complete @AUTH expr. return AArch64AuthMCExpr::create(Sym, Disc, AArch64PACKey::ID(KeyID), CPA.hasAddressDiscriminator(), Ctx); @@ -2903,17 +3109,18 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { } case AArch64::AUTx16x17: - emitPtrauthAuthResign(AArch64::X16, - (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), - AArch64::X17, std::nullopt, 0, 0); + emitPtrauthAuthResign( + AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, + std::nullopt, 0, 0, MI->getDeactivationSymbol()); return; case AArch64::AUTxMxN: emitPtrauthAuthResign(MI->getOperand(0).getReg(), (AArch64PACKey::ID)MI->getOperand(3).getImm(), MI->getOperand(4).getImm(), &MI->getOperand(5), - MI->getOperand(1).getReg(), std::nullopt, 0, 0); + MI->getOperand(1).getReg(), std::nullopt, 0, 0, + MI->getDeactivationSymbol()); return; case AArch64::AUTPAC: @@ -2921,7 +3128,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), MI->getOperand(5).getReg()); + MI->getOperand(4).getImm(), MI->getOperand(5).getReg(), + MI->getDeactivationSymbol()); return; case AArch64::PAC: @@ -3377,6 +3585,9 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { return; } + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // Finally, do the automated lowerings for everything else. MCInst TmpInst; MCInstLowering.Lower(MI, TmpInst); diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index ad42f4b56caf2..bfe1d8d214747 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -1535,7 +1535,10 @@ void AArch64DAGToDAGISel::SelectPtrauthAuth(SDNode *N) { extractPtrauthBlendDiscriminators(AUTDisc, CurDAG); if (!Subtarget->isX16X17Safer()) { - SDValue Ops[] = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + std::vector Ops = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + // Copy deactivation symbol if present. + if (N->getNumOperands() > 4) + Ops.push_back(N->getOperand(4)); SDNode *AUT = CurDAG->getMachineNode(AArch64::AUTxMxN, DL, MVT::i64, MVT::i64, Ops); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 4f6e3ddd18def..49b65f73901ba 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -9584,6 +9584,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (InGlue.getNode()) Ops.push_back(InGlue); + if (CLI.DeactivationSymbol) + Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol)); + // If we're doing a tall call, use a TC_RETURN here rather than an // actual call instruction. if (IsTailCall) { diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index ba7cbccc0bcd6..3cc4380138b1a 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -2402,6 +2402,7 @@ class BImm pattern> let Inst{25-0} = addr; let DecoderMethod = "DecodeUnconditionalBranch"; + let supportsDeactivationSymbol = true; } class BranchImm pattern> @@ -2459,6 +2460,7 @@ class SignAuthOneData opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = Rn; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthZero opcode_prefix, bits<2> opcode, string asm, @@ -2472,6 +2474,7 @@ class SignAuthZero opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = 0b11111; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthTwoOperand opc, string asm, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 251fd44b6ea31..58257851c1437 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -2167,6 +2167,7 @@ let Predicates = [HasPAuth] in { let Size = 12; let Defs = [X16, X17]; let usesCustomInserter = 1; + let supportsDeactivationSymbol = 1; } // A standalone pattern is used, so that literal 0 can be passed as $Disc. diff --git a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp index 1ca61f5c6b349..cda83b898a54f 100644 --- a/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp +++ b/llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp @@ -8193,6 +8193,8 @@ bool AArch64AsmParser::parseDataExpr(const MCExpr *&Res) { Spec = AArch64::S_GOTPCREL; else if (Identifier == "plt") Spec = AArch64::S_PLT; + else if (Identifier == "funcinit") + Spec = AArch64::S_FUNCINIT; } if (Spec == AArch64::S_None) return Error(Loc, "invalid relocation specifier"); diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 010d0aaa46e7f..c44cbea010ec6 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -1421,6 +1421,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } else if (Info.CFIType) { MIB->setCFIType(MF, Info.CFIType->getZExtValue()); } + MIB->setDeactivationSymbol(MF, Info.DeactivationSymbol); MIB.add(Info.Callee); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp index 7618a57691868..0f1a513b09511 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64ELFObjectWriter.cpp @@ -40,6 +40,7 @@ class AArch64ELFObjectWriter : public MCELFObjectTargetWriter { bool IsPCRel) const override; bool needsRelocateWithSymbol(const MCValue &, unsigned Type) const override; bool isNonILP32reloc(const MCFixup &Fixup, AArch64::Specifier RefKind) const; + void sortRelocs(std::vector &Relocs) override; bool IsILP32; }; @@ -231,6 +232,8 @@ unsigned AArch64ELFObjectWriter::getRelocType(const MCFixup &Fixup, } if (RefKind == AArch64::S_AUTH || RefKind == AArch64::S_AUTHADDR) return ELF::R_AARCH64_AUTH_ABS64; + if (RefKind == AArch64::S_FUNCINIT) + return ELF::R_AARCH64_FUNCINIT64; return ELF::R_AARCH64_ABS64; } case AArch64::fixup_aarch64_add_imm12: @@ -497,6 +500,17 @@ bool AArch64ELFObjectWriter::needsRelocateWithSymbol(const MCValue &Val, Val.getSpecifier()); } +void AArch64ELFObjectWriter::sortRelocs( + std::vector &Relocs) { + // PATCHINST relocations should be applied last because they may overwrite the + // whole instruction and so should take precedence over other relocations that + // modify operands of the original instruction. + std::stable_partition(Relocs.begin(), Relocs.end(), + [](const ELFRelocationEntry &R) { + return R.Type != ELF::R_AARCH64_PATCHINST; + }); +} + std::unique_ptr llvm::createAArch64ELFObjectWriter(uint8_t OSABI, bool IsILP32) { return std::make_unique(OSABI, IsILP32); diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp index 828c5c5462407..4238561804437 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.cpp @@ -40,6 +40,7 @@ const MCAsmInfo::AtSpecifier ELFAtSpecifiers[] = { {AArch64::S_GOT, "GOT"}, {AArch64::S_GOTPCREL, "GOTPCREL"}, {AArch64::S_PLT, "PLT"}, + {AArch64::S_FUNCINIT, "FUNCINIT"}, }; const MCAsmInfo::AtSpecifier MachOAtSpecifiers[] = { diff --git a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h index c28e925d77e2b..a6fe7367605d3 100644 --- a/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h +++ b/llvm/lib/Target/AArch64/MCTargetDesc/AArch64MCAsmInfo.h @@ -164,6 +164,7 @@ enum { // ELF relocation specifiers in data directives: S_PLT = 0x400, S_GOTPCREL, + S_FUNCINIT, // Mach-O @ relocation specifiers: S_MACHO_GOT, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 47e017e17092b..20bb1bc43e598 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -3042,9 +3042,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (NeedSign && isa(II->getArgOperand(4))) { auto *SignKey = cast(II->getArgOperand(3)); auto *SignDisc = cast(II->getArgOperand(4)); - auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *Null = ConstantPointerNull::get(Builder.getPtrTy()); auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, - SignDisc, SignAddrDisc); + SignDisc, Null, Null); replaceInstUsesWith( *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); return eraseInstFromFunction(*II); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 6477141ab095f..316e5b278f3cb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -697,8 +697,7 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) { Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LoadInst *FirstLI = cast(PN.getIncomingValue(0)); - // Can't forward swifterror through a phi. - if (FirstLI->getOperand(0)->isSwiftError()) + if (!shouldFoldOperandThroughPhi(FirstLI->getOperand(0))) return nullptr; // FIXME: This is overconservative; this transform is allowed in some cases @@ -737,8 +736,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LI->getPointerAddressSpace() != LoadAddrSpace) return nullptr; - // Can't forward swifterror through a phi. - if (LI->getOperand(0)->isSwiftError()) + if (!shouldFoldOperandThroughPhi(LI->getOperand(0))) return nullptr; // We can't sink the load if the loaded value could be modified between diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 23256cf2acbd2..c212c4f45dc37 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -62,6 +62,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" @@ -523,9 +524,10 @@ class Slice { public: Slice() = default; - Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable) + Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable, + Value *ProtectedFieldDisc) : BeginOffset(BeginOffset), EndOffset(EndOffset), - UseAndIsSplittable(U, IsSplittable) {} + UseAndIsSplittable(U, IsSplittable), ProtectedFieldDisc(ProtectedFieldDisc) {} uint64_t beginOffset() const { return BeginOffset; } uint64_t endOffset() const { return EndOffset; } @@ -538,6 +540,10 @@ class Slice { bool isDead() const { return getUse() == nullptr; } void kill() { UseAndIsSplittable.setPointer(nullptr); } + // When this access is via an llvm.protected.field.ptr intrinsic, contains + // the second argument to the intrinsic, the discriminator. + Value *ProtectedFieldDisc; + /// Support for ordering ranges. /// /// This provides an ordering over ranges such that start offsets are @@ -631,6 +637,9 @@ class AllocaSlices { /// Access the dead users for this alloca. ArrayRef getDeadUsers() const { return DeadUsers; } + /// Access the PFP users for this alloca. + ArrayRef getPFPUsers() const { return PFPUsers; } + /// Access Uses that should be dropped if the alloca is promotable. ArrayRef getDeadUsesIfPromotable() const { return DeadUseIfPromotable; @@ -691,6 +700,10 @@ class AllocaSlices { /// they come from outside of the allocated space. SmallVector DeadUsers; + /// Users that are llvm.protected.field.ptr intrinsics. These will be RAUW'd + /// to their first argument if we rewrite the alloca. + SmallVector PFPUsers; + /// Uses which will become dead if can promote the alloca. SmallVector DeadUseIfPromotable; @@ -1064,7 +1077,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor { EndOffset = AllocSize; } - AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable)); + AS.Slices.push_back( + Slice(BeginOffset, EndOffset, U, IsSplittable, ProtectedFieldDisc)); } void visitBitCastInst(BitCastInst &BC) { @@ -1274,6 +1288,9 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor { return; } + if (II.getIntrinsicID() == Intrinsic::protected_field_ptr) + AS.PFPUsers.push_back(&II); + Base::visitIntrinsicInst(II); } @@ -4682,7 +4699,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { NewSlices.push_back( Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PLoad->getOperandUse(PLoad->getPointerOperandIndex()), - /*IsSplittable*/ false)); + /*IsSplittable*/ false, nullptr)); LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() << ", " << NewSlices.back().endOffset() << "): " << *PLoad << "\n"); @@ -4838,10 +4855,12 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { LLVMContext::MD_access_group}); // Now build a new slice for the alloca. + // ProtectedFieldDisc==nullptr is a lie, but it doesn't matter because we + // already determined that all accesses are consistent. NewSlices.push_back( Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize, &PStore->getOperandUse(PStore->getPointerOperandIndex()), - /*IsSplittable*/ false)); + /*IsSplittable*/ false, nullptr)); LLVM_DEBUG(dbgs() << " new slice [" << NewSlices.back().beginOffset() << ", " << NewSlices.back().endOffset() << "): " << *PStore << "\n"); @@ -5618,6 +5637,32 @@ SROA::runOnAlloca(AllocaInst &AI) { return {Changed, CFGChanged}; } + for (auto &P : AS.partitions()) { + std::optional ProtectedFieldDisc; + // For now, we can't split if a field is accessed both via protected + // field and not. + for (Slice &S : P) { + if (auto *II = dyn_cast(S.getUse()->getUser())) + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + if (!ProtectedFieldDisc) + ProtectedFieldDisc = S.ProtectedFieldDisc; + if (*ProtectedFieldDisc != S.ProtectedFieldDisc) + return {Changed, CFGChanged}; + } + for (Slice *S : P.splitSliceTails()) { + if (auto *II = dyn_cast(S->getUse()->getUser())) + if (II->getIntrinsicID() == Intrinsic::lifetime_start || + II->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + if (!ProtectedFieldDisc) + ProtectedFieldDisc = S->ProtectedFieldDisc; + if (*ProtectedFieldDisc != S->ProtectedFieldDisc) + return {Changed, CFGChanged}; + } + } + // Delete all the dead users of this alloca before splitting and rewriting it. for (Instruction *DeadUser : AS.getDeadUsers()) { // Free up everything used by this instruction. @@ -5635,6 +5680,12 @@ SROA::runOnAlloca(AllocaInst &AI) { clobberUse(*DeadOp); Changed = true; } + for (IntrinsicInst *PFPUser : AS.getPFPUsers()) { + PFPUser->replaceAllUsesWith(PFPUser->getArgOperand(0)); + + DeadInsts.push_back(PFPUser); + Changed = true; + } // No slices to split. Leave the dead alloca for a later pass to clean up. if (AS.begin() == AS.end()) diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index babd7f6b3a058..f9aa8455c30cd 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -3846,10 +3846,7 @@ bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { if (Op->getType()->isMetadataTy()) return false; - // swifterror pointers can only be used by a load, store, or as a swifterror - // argument; swifterror pointers are not allowed to be used in select or phi - // instructions. - if (Op->isSwiftError()) + if (!shouldFoldOperandThroughPhi(Op)) return false; // Cannot replace alloca argument with phi/select. diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 674de573f7919..a9c8898732be9 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -2134,6 +2134,22 @@ static bool replacingOperandWithVariableIsCheap(const Instruction *I, return !isa(I); } +bool llvm::shouldFoldOperandThroughPhi(const Value *Ptr) { + // swifterror pointers can only be used by a load or store; sinking a load + // or store would require introducing a select for the pointer operand, + // which isn't allowed for swifterror pointers. + if (Ptr->isSwiftError()) + return false; + + // Protected pointer field loads/stores should be paired with the intrinsic + // to avoid unnecessary address escapes. + if (auto *II = dyn_cast(Ptr)) + if (II->getIntrinsicID() == Intrinsic::protected_field_ptr) + return false; + + return true; +} + // All instructions in Insts belong to different blocks that all unconditionally // branch to a common successor. Analyze each instruction and return true if it // would be possible to sink them into their successor, creating one common diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 8d8a60b6918fe..fca17893f47f4 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) { if (isa(C)) return getVM()[V] = ConstantVector::get(Ops); if (isa(C)) - return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast(Ops[1]), - cast(Ops[2]), Ops[3]); + return getVM()[V] = + ConstantPtrAuth::get(Ops[0], cast(Ops[1]), + cast(Ops[2]), Ops[3], Ops[4]); // If this is a no-operand constant, it must be because the type was remapped. if (isa(C)) return getVM()[V] = PoisonValue::get(NewTy); diff --git a/llvm/test/Bitcode/compatibility.ll b/llvm/test/Bitcode/compatibility.ll index 0b5ce08c00a23..9b490f1105912 100644 --- a/llvm/test/Bitcode/compatibility.ll +++ b/llvm/test/Bitcode/compatibility.ll @@ -217,9 +217,13 @@ declare void @g.f1() ; CHECK: @g.sanitize_address_dyninit = global i32 0, sanitize_address_dyninit ; CHECK: @g.sanitize_multiple = global i32 0, sanitize_memtag, sanitize_address_dyninit +@ds = external global i32 + ; ptrauth constant @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null) ; CHECK: @auth_var = global ptr ptrauth (ptr @g1, i32 0, i64 65535) +@auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) +; CHECK: @auth_var.ds = global ptr ptrauth (ptr @g1, i32 0, i64 65535, ptr null, ptr @ds) ;; Aliases ; Format: @ = [Linkage] [Visibility] [DLLStorageClass] [ThreadLocal] diff --git a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll index d860104b9cb3d..5628e17b4936e 100644 --- a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll +++ b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll @@ -13,6 +13,7 @@ ; CHECK-NEXT: &1 | FileCheck %s + +@g = external global i8 + +; CHECK: signed ptrauth constant deactivation symbol must be a global variable or null +@ptr = global ptr ptrauth (ptr @g, i32 0, i64 65535, ptr null, ptr inttoptr (i64 16 to ptr)) diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 33928ac118e0c..168502f89cbf8 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -1385,7 +1385,7 @@ define ptr @foo() { // Check get(), getKey(), getDiscriminator(), getAddrDiscriminator(). auto *NewPtrAuth = sandboxir::ConstantPtrAuth::get( &F, PtrAuth->getKey(), PtrAuth->getDiscriminator(), - PtrAuth->getAddrDiscriminator()); + PtrAuth->getAddrDiscriminator(), PtrAuth->getDeactivationSymbol()); EXPECT_EQ(NewPtrAuth, PtrAuth); // Check hasAddressDiscriminator(). EXPECT_EQ(PtrAuth->hasAddressDiscriminator(), diff --git a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp index 86ad41fa7ad50..90b58236060d2 100644 --- a/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp +++ b/llvm/unittests/Transforms/Utils/ValueMapperTest.cpp @@ -450,6 +450,10 @@ TEST(ValueMapperTest, mapValuePtrAuth) { PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage0"); std::unique_ptr Storage1 = std::make_unique( PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "Storage1"); + std::unique_ptr DS0 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS0"); + std::unique_ptr DS1 = std::make_unique( + PtrTy, false, GlobalValue::ExternalLinkage, nullptr, "DS1"); ConstantInt *ConstKey = ConstantInt::get(Int32Ty, 1); ConstantInt *ConstDisc = ConstantInt::get(Int64Ty, 1234); @@ -457,11 +461,12 @@ TEST(ValueMapperTest, mapValuePtrAuth) { ValueToValueMapTy VM; VM[Var0.get()] = Var1.get(); VM[Storage0.get()] = Storage1.get(); + VM[DS0.get()] = DS1.get(); - ConstantPtrAuth *Value = - ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, Storage0.get()); - ConstantPtrAuth *MappedValue = - ConstantPtrAuth::get(Var1.get(), ConstKey, ConstDisc, Storage1.get()); + ConstantPtrAuth *Value = ConstantPtrAuth::get(Var0.get(), ConstKey, ConstDisc, + Storage0.get(), DS0.get()); + ConstantPtrAuth *MappedValue = ConstantPtrAuth::get( + Var1.get(), ConstKey, ConstDisc, Storage1.get(), DS1.get()); EXPECT_EQ(ValueMapper(VM).mapValue(*Value), MappedValue); } diff --git a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp index d3fba6a92357a..477f25e6c0688 100644 --- a/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp @@ -957,6 +957,13 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, } } const EmitNodeMatcherCommon *EN = cast(N); + bool SupportsDeactivationSymbol = + EN->getInstruction().TheDef->getValueAsBit( + "supportsDeactivationSymbol"); + if (SupportsDeactivationSymbol) { + OS << "OPC_CaptureDeactivationSymbol,\n"; + OS.indent(FullIndexWidth + Indent); + } bool IsEmitNode = isa(EN); OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo"); bool CompressVTs = EN->getNumVTs() < 3; @@ -1049,8 +1056,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, OS << '\n'; } - return 4 + !CompressVTs + !CompressNodeInfo + NumTypeBytes + - NumOperandBytes + NumCoveredBytes; + return 4 + SupportsDeactivationSymbol + !CompressVTs + !CompressNodeInfo + + NumTypeBytes + NumOperandBytes + NumCoveredBytes; } case Matcher::CompleteMatch: { const CompleteMatchMatcher *CM = cast(N);