-
Notifications
You must be signed in to change notification settings - Fork 13.6k
TypeTree support in autodiff #144197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
TypeTree support in autodiff #144197
Changes from all commits
0a56a89
a03bfe5
aad9b6e
07685f3
d584119
cd9753d
46e89c8
2544376
94aa8bc
96ac876
340aa40
54573be
9ffdfe8
e93b3e3
dbc871a
6ff80be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ use gccjit::{ | |
use rustc_abi as abi; | ||
use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout, WrappingRange}; | ||
use rustc_apfloat::{Float, Round, Status, ieee}; | ||
use rustc_ast::expand::typetree::FncTree; | ||
use rustc_codegen_ssa::MemFlags; | ||
use rustc_codegen_ssa::common::{ | ||
AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, TypeKind, | ||
|
@@ -1368,6 +1369,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { | |
_src_align: Align, | ||
size: RValue<'gcc>, | ||
flags: MemFlags, | ||
_tt: Option<FncTree>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The gcc people would probably appreciate a commend clarifying, that these tt are an LLVM only feature and thus shouldn't be set for gcc |
||
) { | ||
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported"); | ||
let size = self.intcast(size, self.type_size_t(), false); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -682,13 +682,11 @@ pub(crate) fn run_pass_manager( | |
for function in cx.get_functions() { | ||
let enzyme_marker = "enzyme_marker"; | ||
if attributes::has_string_attr(function, enzyme_marker) { | ||
// Sanity check: Ensure 'noinline' is present before replacing it. | ||
assert!( | ||
attributes::has_attr(function, Function, llvm::AttributeKind::NoInline), | ||
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'" | ||
); | ||
// Remove 'noinline' if present (it should be there in most cases) | ||
if attributes::has_attr(function, Function, llvm::AttributeKind::NoInline) { | ||
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd probably merge the frontend rework first, then you can drop these lines. |
||
} | ||
|
||
attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline); | ||
attributes::remove_string_attr_from_llfn(function, enzyme_marker); | ||
|
||
assert!( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
use std::os::raw::{c_char, c_uint}; | ||
use std::ptr; | ||
|
||
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; | ||
use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; | ||
use rustc_codegen_ssa::ModuleCodegen; | ||
use rustc_codegen_ssa::common::TypeKind; | ||
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; | ||
|
@@ -14,7 +16,7 @@ use crate::context::SimpleCx; | |
use crate::declare::declare_simple_fn; | ||
use crate::errors::{AutoDiffWithoutEnable, LlvmError}; | ||
use crate::llvm::AttributePlace::Function; | ||
use crate::llvm::{Metadata, True}; | ||
use crate::llvm::{Metadata, True, TypeTree}; | ||
use crate::value::Value; | ||
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; | ||
|
||
|
@@ -512,3 +514,141 @@ pub(crate) fn differentiate<'ll>( | |
|
||
Ok(()) | ||
} | ||
|
||
/// Converts a Rust TypeTree to Enzyme's internal TypeTree format | ||
/// | ||
/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) | ||
/// and converts it to Enzyme's internal C++ TypeTree representation that | ||
/// Enzyme can understand during differentiation analysis. | ||
#[cfg(llvm_enzyme)] | ||
fn to_enzyme_typetree( | ||
rust_typetree: RustTypeTree, | ||
data_layout: &str, | ||
llcx: &llvm::Context, | ||
) -> TypeTree { | ||
// Start with an empty TypeTree | ||
let mut enzyme_tt = TypeTree::new(); | ||
|
||
// Convert each Type in the Rust TypeTree to Enzyme format | ||
for rust_type in rust_typetree.0 { | ||
let concrete_type = match rust_type.kind { | ||
rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, | ||
rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, | ||
rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, | ||
rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, | ||
rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, | ||
rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, | ||
rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, | ||
}; | ||
|
||
// Create a TypeTree for this specific type | ||
let type_tt = TypeTree::from_type(concrete_type, llcx); | ||
|
||
// Apply offset if specified | ||
let type_tt = if rust_type.offset == -1 { | ||
type_tt // -1 means everywhere/no specific offset | ||
} else { | ||
// Apply specific offset positioning | ||
type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) | ||
}; | ||
|
||
// Merge this type into the main TypeTree | ||
enzyme_tt = enzyme_tt.merge(type_tt); | ||
} | ||
|
||
enzyme_tt | ||
} | ||
|
||
#[cfg(not(llvm_enzyme))] | ||
#[allow(dead_code)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this need a dead_code? Are you also gating the caller side? Then you shouldn't need this dummy. |
||
fn to_enzyme_typetree( | ||
_rust_typetree: RustTypeTree, | ||
_data_layout: &str, | ||
_llcx: &llvm::Context, | ||
) -> ! { | ||
unimplemented!("TypeTree conversion not available without llvm_enzyme support") | ||
} | ||
|
||
// Attaches TypeTree information to LLVM function as enzyme_type attributes. | ||
#[cfg(llvm_enzyme)] | ||
pub(crate) fn add_tt<'ll>( | ||
llmod: &'ll llvm::Module, | ||
llcx: &'ll llvm::Context, | ||
fn_def: &'ll Value, | ||
tt: FncTree, | ||
) { | ||
let inputs = tt.args; | ||
let ret_tt: RustTypeTree = tt.ret; | ||
|
||
// Get LLVM data layout string for TypeTree conversion | ||
let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; | ||
let llvm_data_layout = | ||
std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) | ||
.expect("got a non-UTF8 data-layout from LLVM"); | ||
|
||
// Attribute name that Enzyme recognizes for TypeTree information | ||
let attr_name = "enzyme_type"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @KMJ-007 Here you mark your attribute as enzyme_type, but in the test below I can't spot any metadata with that name. Are you sure it is getting lowered and not just dropped? Maybe add a few dbg statements to see till where exists. |
||
let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); | ||
|
||
// Attach TypeTree attributes to each input parameter | ||
// Enzyme uses these to understand parameter memory layouts during differentiation | ||
for (i, input) in inputs.iter().enumerate() { | ||
unsafe { | ||
// Convert Rust TypeTree to Enzyme's internal format | ||
let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); | ||
|
||
// Serialize TypeTree to string format that Enzyme can parse | ||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); | ||
let c_str = std::ffi::CStr::from_ptr(c_str); | ||
|
||
// Create LLVM string attribute with TypeTree information | ||
let attr = llvm::LLVMCreateStringAttribute( | ||
llcx, | ||
c_attr_name.as_ptr(), | ||
c_attr_name.as_bytes().len() as c_uint, | ||
c_str.as_ptr(), | ||
c_str.to_bytes().len() as c_uint, | ||
); | ||
|
||
// Attach attribute to the specific function parameter | ||
// Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments | ||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); | ||
|
||
// Free the C string to prevent memory leaks | ||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); | ||
} | ||
} | ||
|
||
// Attach TypeTree attribute to the return type | ||
// Enzyme needs this to understand how to handle return value derivatives | ||
unsafe { | ||
let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); | ||
let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); | ||
let c_str = std::ffi::CStr::from_ptr(c_str); | ||
|
||
let ret_attr = llvm::LLVMCreateStringAttribute( | ||
llcx, | ||
c_attr_name.as_ptr(), | ||
c_attr_name.as_bytes().len() as c_uint, | ||
c_str.as_ptr(), | ||
c_str.to_bytes().len() as c_uint, | ||
); | ||
|
||
// Attach to function return type | ||
attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); | ||
|
||
// Free the C string | ||
llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); | ||
} | ||
} | ||
|
||
// Fallback implementation when Enzyme is not available | ||
#[cfg(not(llvm_enzyme))] | ||
pub(crate) fn add_tt<'ll>( | ||
_llmod: &'ll llvm::Module, | ||
_llcx: &'ll llvm::Context, | ||
_fn_def: &'ll Value, | ||
_tt: FncTree, | ||
) { | ||
unimplemented!() | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
redundant comment.